Compare commits

..

38 Commits

Author SHA1 Message Date
Evan Lohn
d04128b8b1 fix: sharepoint unquote (#8786) 2026-02-26 03:38:46 +00:00
Nikolas Garza
bbebdf8f78 feat(scim): Entra ID enterprise extension support [3/3] (#8747) 2026-02-26 02:32:04 +00:00
Nikolas Garza
161279a2d5 feat(scim): field round-tripping for IdP attribute preservation [2/3] (#8746) 2026-02-26 02:01:13 +00:00
Jamison Lahman
e5ebb45a20 chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 01:57:25 +00:00
Evan Lohn
320ba9cb1b refactor: filter by persona id during search (#8683) 2026-02-26 01:51:00 +00:00
Nikolas Garza
f2e8cb3114 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 01:47:44 +00:00
Nikolas Garza
43054a28ec feat(scim): SCIM 2.0 protocol compliance fixes [1/3] (#8745) 2026-02-26 01:33:08 +00:00
Justin Tahara
dc74aa7b1f chore(llm): Add OpenAI Integration Tests (#8711) 2026-02-26 00:58:28 +00:00
Raunak Bhagat
bd773191c2 feat(opal): add more icons (#8778) 2026-02-26 00:38:54 +00:00
Evan Lohn
66dbff41e6 refactor: extend sync mechanism to persona files (#8682) 2026-02-26 00:32:30 +00:00
roshan
1dcffe38bc fix: Invoke generate_agents_md.py in K8s to populate knowledge sources (#8768)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-26 00:04:10 +00:00
Evan Lohn
c35e883564 refactor: persona id in vector db by indexing (#8681) 2026-02-25 22:51:57 +00:00
Jamison Lahman
fefcd58481 chore(devtools): ods web to run web/package.json scripts (#8766)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-25 14:05:29 -08:00
Jamison Lahman
bdc89d9e3f chore(fe): opal button implements responsiveHideText (#8764) 2026-02-25 21:05:08 +00:00
Evan Lohn
f4d777b80d refactor: persona id in vector db (#8680) 2026-02-25 20:42:38 +00:00
acaprau
da4d57b5e3 chore(devtools): Make AGENTS.md reference contributing_guides/best_practices.md (#8760)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-25 20:27:12 +00:00
Evan Lohn
dcdcd067bd fix: drive 403 rate limits (#8762) 2026-02-25 20:12:36 +00:00
Evan Lohn
8b15a29723 feat: slab connector validation (#8758) 2026-02-25 20:00:42 +00:00
Danelegend
763853674f feat(ci): Add preview modal for data types (#8752) 2026-02-25 19:52:19 +00:00
Jamison Lahman
429b6f3465 fix(fe): modal aligning with detached element after navigation (#8676) 2026-02-25 19:33:07 +00:00
Danelegend
37d5be1b40 feat: python tool not added when no code interpretter server (#8749) 2026-02-25 19:17:42 +00:00
Jamison Lahman
8ab99dbb06 chore(fe): add hover style to AgentCard (#8689) 2026-02-25 19:08:00 +00:00
Jamison Lahman
52799e9c7a fix(fe): middle align human chat message text (#8756) 2026-02-25 19:00:01 +00:00
Jamison Lahman
aef009cc97 chore(fe): foldable buttons display text via tooltip when disabled (#8735) 2026-02-25 18:39:53 +00:00
Evan Lohn
18d1ea1770 fix: sharepoint driveItem perm sync (#8698) 2026-02-25 18:29:26 +00:00
Bo-Onyx
f336ad00f4 fix(user invitation): failed but no warning. (#8731)
Co-authored-by: Bo Yang <boyang@Bos-MacBook-Pro.local>
2026-02-25 17:23:39 +00:00
SubashMohan
0558e687d9 fix: persist onboarding dismissal in localStorage with user-specific keys (#8674) 2026-02-25 06:22:17 +00:00
roshan
784a99e24a updated demo data (#8748) 2026-02-24 19:59:46 -08:00
Justin Tahara
da1f5a11f4 chore(cherry-pick): Alerting on Failed Cherry-Picks (#8744) 2026-02-25 02:09:19 +00:00
Justin Tahara
5633805890 chore(devtools): Upgrade ods from 0.6.0 -> 0.6.1 (#8743) 2026-02-25 02:01:20 +00:00
Danelegend
0817b45ae1 feat: Get code interpreter config route (#8739) 2026-02-25 01:49:30 +00:00
Justin Tahara
af0e4bdebc fix(slack): Cleaning up URL Links (#8569) 2026-02-25 01:42:12 +00:00
Justin Tahara
4cd2320732 chore(cherry-pick): Add Github Label for PRs (#8736) 2026-02-25 00:46:12 +00:00
Danelegend
90a361f0e1 feat: code interpreter routes (#8670) 2026-02-24 16:27:10 -08:00
Justin Tahara
194efde97b chore(llm): Scaffolding for Nightly LLM Tests (#8704) 2026-02-25 00:06:24 +00:00
Danelegend
d922a42262 feat: code interpreter docker default deploy (#8672) 2026-02-24 23:51:19 +00:00
Danelegend
f00c3a486e feat: default deploy code interpreter - helm & bump version 0.3.0 (#8685) 2026-02-24 23:40:46 +00:00
Danelegend
192080c9e4 feat: default deploy code interpreter - restart_script (#8686) 2026-02-24 23:40:36 +00:00
139 changed files with 5597 additions and 1317 deletions

View File

@@ -0,0 +1,73 @@
# Composite action: build the backend Docker image and push it to the
# run's ECR cache registry, reusing registry build caches keyed by commit
# SHA, PR/branch suffix, and a shared fallback tag.
name: "Build Backend Image"
description: "Builds and pushes the backend Docker image with cache reuse"

inputs:
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  ref-name:
    description: "Git ref name used for cache suffix fallback"
    required: true
  pr-number:
    description: "Optional PR number for cache suffix"
    required: false
    default: ""
  github-sha:
    description: "Commit SHA used for cache keys"
    required: true
  run-id:
    description: "GitHub run ID used in output image tag"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true
  docker-no-cache:
    description: "Set to 'true' to disable docker build cache"
    required: false
    default: "false"

runs:
  using: "composite"
  steps:
    - name: Format branch name for cache
      id: format-branch
      shell: bash
      env:
        PR_NUMBER: ${{ inputs.pr-number }}
        REF_NAME: ${{ inputs.ref-name }}
      run: |
        # Prefer the PR number when present; otherwise sanitize the ref
        # name into a valid Docker tag suffix.
        if [ -n "${PR_NUMBER}" ]; then
          CACHE_SUFFIX="${PR_NUMBER}"
        else
          # shellcheck disable=SC2001
          CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
        fi
        echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}
    - name: Build and push Backend Docker image
      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
      with:
        context: ./backend
        file: ./backend/Dockerfile
        push: true
        tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
        # Most-specific cache first (exact SHA), then the PR/branch suffix,
        # then the shared fallback, then the public image as a last resort.
        cache-from: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
          type=registry,ref=onyxdotapp/onyx-backend:latest
        cache-to: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
        no-cache: ${{ inputs.docker-no-cache == 'true' }}

View File

@@ -0,0 +1,75 @@
# Composite action: build the integration-test image (and its backend base)
# with docker buildx bake, pushing both the image and registry build caches.
name: "Build Integration Image"
description: "Builds and pushes the integration test image with docker bake"

inputs:
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  ref-name:
    description: "Git ref name used for cache suffix fallback"
    required: true
  pr-number:
    description: "Optional PR number for cache suffix"
    required: false
    default: ""
  github-sha:
    description: "Commit SHA used for cache keys"
    required: true
  run-id:
    description: "GitHub run ID used in output image tag"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true

runs:
  using: "composite"
  steps:
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}
    - name: Format branch name for cache
      id: format-branch
      shell: bash
      env:
        PR_NUMBER: ${{ inputs.pr-number }}
        REF_NAME: ${{ inputs.ref-name }}
      run: |
        # Prefer the PR number when present; otherwise sanitize the ref
        # name into a valid Docker tag suffix.
        if [ -n "${PR_NUMBER}" ]; then
          CACHE_SUFFIX="${PR_NUMBER}"
        else
          # shellcheck disable=SC2001
          CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
        fi
        echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
    - name: Build and push integration test image with Docker Bake
      shell: bash
      env:
        RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
        # TAG is read by the bake file to name the output image.
        TAG: nightly-llm-it-${{ inputs.run-id }}
        CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
        HEAD_SHA: ${{ inputs.github-sha }}
      run: |
        # Caches for both bake targets: exact SHA, branch/PR suffix, then
        # the shared fallback (backend also falls back to the public image).
        docker buildx bake --push \
          --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
          --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
          --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
          --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
          --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
          --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
          --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
          --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
          --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
          --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
          --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
          --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
          --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
          integration

View File

@@ -0,0 +1,68 @@
# Composite action: build the model-server Docker image and push it to the
# run's ECR cache registry, reusing registry build caches keyed by commit
# SHA, PR/branch suffix, and a shared fallback tag.
name: "Build Model Server Image"
description: "Builds and pushes the model server Docker image with cache reuse"

inputs:
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  ref-name:
    description: "Git ref name used for cache suffix fallback"
    required: true
  pr-number:
    description: "Optional PR number for cache suffix"
    required: false
    default: ""
  github-sha:
    description: "Commit SHA used for cache keys"
    required: true
  run-id:
    description: "GitHub run ID used in output image tag"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true

runs:
  using: "composite"
  steps:
    - name: Format branch name for cache
      id: format-branch
      shell: bash
      env:
        PR_NUMBER: ${{ inputs.pr-number }}
        REF_NAME: ${{ inputs.ref-name }}
      run: |
        # Prefer the PR number when present; otherwise sanitize the ref
        # name into a valid Docker tag suffix.
        if [ -n "${PR_NUMBER}" ]; then
          CACHE_SUFFIX="${PR_NUMBER}"
        else
          # shellcheck disable=SC2001
          CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
        fi
        echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}
    - name: Build and push Model Server Docker image
      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
      with:
        context: ./backend
        file: ./backend/Dockerfile.model_server
        push: true
        tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
        # Most-specific cache first (exact SHA), then the PR/branch suffix,
        # then the shared fallback, then the public image as a last resort.
        cache-from: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
          type=registry,ref=onyxdotapp/onyx-model-server:latest
        cache-to: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max

View File

@@ -0,0 +1,120 @@
# Composite action: bring up the compose services needed for chat tests and
# run the nightly provider integration test in a one-off container, with
# one retry on failure.
name: "Run Nightly Provider Chat Test"
description: "Starts required compose services and runs nightly provider integration test"

inputs:
  provider:
    description: "Provider slug for NIGHTLY_LLM_PROVIDER"
    required: true
  models:
    description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
    required: true
  provider-api-key:
    description: "API key for NIGHTLY_LLM_API_KEY"
    required: true
  strict:
    description: "String true/false for NIGHTLY_LLM_STRICT"
    required: true
  api-base:
    description: "Optional NIGHTLY_LLM_API_BASE"
    required: false
    default: ""
  custom-config-json:
    description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
    required: false
    default: ""
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  run-id:
    description: "GitHub run ID used in image tags"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true

runs:
  using: "composite"
  steps:
    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}
    - name: Create .env file for Docker Compose
      shell: bash
      env:
        ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
        RUN_ID: ${{ inputs.run-id }}
      run: |
        cat <<EOF2 > deployment/docker_compose/.env
        COMPOSE_PROFILES=s3-filestore
        ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
        LICENSE_ENFORCEMENT_ENABLED=false
        AUTH_TYPE=basic
        POSTGRES_POOL_PRE_PING=true
        POSTGRES_USE_NULL_POOL=true
        REQUIRE_EMAIL_VERIFICATION=false
        DISABLE_TELEMETRY=true
        INTEGRATION_TESTS_MODE=true
        AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
        ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
        ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
        EOF2
    - name: Start Docker containers
      shell: bash
      run: |
        cd deployment/docker_compose
        docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
          relational_db \
          index \
          cache \
          minio \
          api_server \
          inference_model_server
    - name: Run nightly provider integration test
      uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
      env:
        MODELS: ${{ inputs.models }}
        NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
        NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
        NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
        NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
        NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
        RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
        RUN_ID: ${{ inputs.run-id }}
      with:
        timeout_minutes: 20
        max_attempts: 2
        retry_wait_seconds: 10
        command: |
          if [ -z "${MODELS}" ]; then
            echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
            exit 1
          fi
          docker run --rm --network onyx_default \
            --name test-runner \
            -e POSTGRES_HOST=relational_db \
            -e POSTGRES_USER=postgres \
            -e POSTGRES_PASSWORD=password \
            -e POSTGRES_DB=postgres \
            -e DB_READONLY_USER=db_readonly_user \
            -e DB_READONLY_PASSWORD=password \
            -e POSTGRES_POOL_PRE_PING=true \
            -e POSTGRES_USE_NULL_POOL=true \
            -e VESPA_HOST=index \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e TEST_WEB_HOSTNAME=test-runner \
            -e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
            -e NIGHTLY_LLM_MODELS="${MODELS}" \
            -e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
            -e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
            -e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
            -e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
            ${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
            /app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py

View File

@@ -0,0 +1,44 @@
# Nightly OpenAI provider chat tests: delegates to the reusable nightly
# LLM provider workflow and alerts Slack when a scheduled run fails.
name: Nightly LLM Provider Chat Tests (OpenAI)

concurrency:
  group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true

on:
  schedule:
    # Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
    - cron: "30 10 * * *"
  workflow_dispatch:

permissions:
  contents: read

jobs:
  openai-provider-chat-test:
    uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
    with:
      provider: openai
      models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
      strict: true
    secrets:
      provider_api_key: ${{ secrets.OPENAI_API_KEY }}
      DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
      DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}

  notify-slack-on-failure:
    needs: [openai-provider-chat-test]
    # Only alert for scheduled runs; manual dispatches are watched by hand.
    if: failure() && github.event_name == 'schedule'
    runs-on: ubuntu-slim
    timeout-minutes: 5
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.SLACK_WEBHOOK }}
          failed-jobs: openai-provider-chat-test
          title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
          ref-name: ${{ github.ref_name }}

View File

@@ -11,6 +11,11 @@ permissions:
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
@@ -75,10 +80,82 @@ jobs:
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -116,7 +116,6 @@ jobs:
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF

View File

@@ -20,6 +20,7 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -423,6 +424,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -443,6 +445,7 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -701,6 +704,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \

View File

@@ -0,0 +1,206 @@
# Reusable workflow: validates inputs, builds the backend / model-server /
# integration images in parallel, then runs the nightly chat integration
# test against a single LLM provider.
name: Reusable Nightly LLM Provider Chat Tests

on:
  workflow_call:
    inputs:
      provider:
        description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
        required: true
        type: string
      models:
        description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
        required: true
        type: string
      strict:
        description: "Pass-through value for NIGHTLY_LLM_STRICT"
        required: false
        default: true
        type: boolean
      api_base:
        description: "Optional NIGHTLY_LLM_API_BASE override"
        required: false
        default: ""
        type: string
      custom_config_json:
        description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
        required: false
        default: ""
        type: string
    secrets:
      provider_api_key:
        description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
        required: true
      DOCKER_USERNAME:
        required: true
      DOCKER_TOKEN:
        required: true

permissions:
  contents: read

env:
  NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
  NIGHTLY_LLM_MODELS: ${{ inputs.models }}
  NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
  NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
  NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
  # Boolean input rendered as the string "true"/"false" for downstream env.
  NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}

jobs:
  validate-inputs:
    # NOTE: Keep this cheap and fail before image builds if required inputs are missing.
    runs-on: ubuntu-slim
    timeout-minutes: 5
    steps:
      - name: Validate required nightly provider inputs
        run: |
          if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
            echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
            exit 1
          fi

  build-backend-image:
    needs: [validate-inputs]
    runs-on:
      - runs-on
      - runner=1cpu-linux-arm64
      - "run-id=${{ github.run_id }}-build-backend-image"
      - extras=ecr-cache
    timeout-minutes: 45
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Build backend image
        uses: ./.github/actions/build-backend-image
        with:
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          ref-name: ${{ github.ref_name }}
          # Empty on schedule/dispatch runs; the action falls back to ref-name.
          pr-number: ${{ github.event.pull_request.number }}
          github-sha: ${{ github.sha }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ secrets.DOCKER_USERNAME }}
          docker-token: ${{ secrets.DOCKER_TOKEN }}
          docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}

  build-model-server-image:
    needs: [validate-inputs]
    runs-on:
      - runs-on
      - runner=1cpu-linux-arm64
      - "run-id=${{ github.run_id }}-build-model-server-image"
      - extras=ecr-cache
    timeout-minutes: 45
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Build model server image
        uses: ./.github/actions/build-model-server-image
        with:
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          ref-name: ${{ github.ref_name }}
          pr-number: ${{ github.event.pull_request.number }}
          github-sha: ${{ github.sha }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ secrets.DOCKER_USERNAME }}
          docker-token: ${{ secrets.DOCKER_TOKEN }}

  build-integration-image:
    needs: [validate-inputs]
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - "run-id=${{ github.run_id }}-build-integration-image"
      - extras=ecr-cache
    timeout-minutes: 45
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Build integration image
        uses: ./.github/actions/build-integration-image
        with:
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          ref-name: ${{ github.ref_name }}
          pr-number: ${{ github.event.pull_request.number }}
          github-sha: ${{ github.sha }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ secrets.DOCKER_USERNAME }}
          docker-token: ${{ secrets.DOCKER_TOKEN }}

  provider-chat-test:
    needs:
      - build-backend-image
      - build-model-server-image
      - build-integration-image
    runs-on:
      - runs-on
      - runner=4cpu-linux-arm64
      - "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
      - extras=ecr-cache
    timeout-minutes: 45
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Run nightly provider chat test
        uses: ./.github/actions/run-nightly-provider-chat-test
        with:
          provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
          models: ${{ env.NIGHTLY_LLM_MODELS }}
          provider-api-key: ${{ secrets.provider_api_key }}
          strict: ${{ env.NIGHTLY_LLM_STRICT }}
          api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
          custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ secrets.DOCKER_USERNAME }}
          docker-token: ${{ secrets.DOCKER_TOKEN }}
      - name: Dump API server logs
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
      - name: Dump all-container logs
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
      - name: Upload logs
        if: always()
        # NOTE(review): this pin lacks the "# ratchet:..." tag comment used
        # everywhere else in this file — confirm which release it tracks.
        uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
        with:
          name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
          path: |
            ${{ github.workspace }}/api_server.log
            ${{ github.workspace }}/docker-compose.log
      - name: Stop Docker containers
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose down -v

View File

@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
To run them:
@@ -616,3 +616,9 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -0,0 +1,33 @@
"""add needs_persona_sync to user_file
Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user_file",
sa.Column(
"needs_persona_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("user_file", "needs_persona_sync")

View File

@@ -34,6 +34,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
@@ -128,12 +129,19 @@ class ScimDAL(DAL):
external_id: str,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
department=f.department,
manager=f.manager,
given_name=f.given_name,
family_name=f.family_name,
scim_emails_json=f.scim_emails_json,
)
self._session.add(mapping)
self._session.flush()
@@ -311,8 +319,14 @@ class ScimDAL(DAL):
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user."""
"""Create, update, or delete the external ID mapping for a user.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
@@ -320,11 +334,18 @@ class ScimDAL(DAL):
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)

View File

@@ -4,7 +4,6 @@ from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
@@ -598,8 +597,12 @@ def get_external_access_from_sharepoint(
)
elif site_page:
site_url = site_page.get("webUrl")
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
# Keep percent-encoding intact so the path matches the encoding
# used by the Office365 library's SPResPath.create_relative(),
# which compares against urlparse(context.base_url).path.
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
# the site prefix in the constructed URL.
server_relative_url = urlparse(site_url).path
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)

View File

@@ -26,14 +26,14 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
@@ -41,6 +41,8 @@ from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.base import serialize_emails
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -48,15 +50,28 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
class ScimJSONResponse(JSONResponse):
    """JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
    # Class-level media-type override picked up by the response machinery;
    # SCIM endpoints return this subclass so IdPs receive the SCIM media type.
    media_type = "application/scim+json"
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
@@ -86,15 +101,39 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
def get_resource_types() -> ScimJSONResponse:
    """List available SCIM resource types (RFC 7643 §6).

    Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
    like Entra ID expect a JSON object, not a bare array.

    Returns:
        A ``ScimJSONResponse`` whose body is a ListResponse containing the
        User and Group resource-type definitions.
    """
    resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
    # Serialize with by_alias so field names match the SCIM wire format.
    return ScimJSONResponse(
        content={
            "schemas": [SCIM_LIST_RESPONSE_SCHEMA],
            "totalResults": len(resources),
            "Resources": [
                r.model_dump(exclude_none=True, by_alias=True) for r in resources
            ],
        }
    )
@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
def get_schemas() -> ScimJSONResponse:
    """Return SCIM schema definitions (RFC 7643 §7).

    Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
    like Entra ID expect a JSON object, not a bare array. The enterprise
    user extension schema is advertised alongside the core user/group
    schemas.
    """
    schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
    # NOTE(review): schema definitions are dumped without by_alias, unlike
    # resource serialization elsewhere — presumably their field names
    # already match the wire format; confirm against the model definitions.
    return ScimJSONResponse(
        content={
            "schemas": [SCIM_LIST_RESPONSE_SCHEMA],
            "totalResults": len(schemas),
            "Resources": [s.model_dump(exclude_none=True) for s in schemas],
        }
    )
# ---------------------------------------------------------------------------
@@ -102,15 +141,45 @@ def get_schemas() -> list[ScimSchemaDefinition]:
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> JSONResponse:
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
body = ScimError(status=str(status), detail=detail)
return JSONResponse(
return ScimJSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _parse_excluded_attributes(raw: str | None) -> set[str]:
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
Returns a set of lowercased attribute names to omit from responses.
"""
if not raw:
return set()
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
def _apply_exclusions(
    resource: ScimUserResource | ScimGroupResource,
    excluded: set[str],
) -> dict:
    """Serialize a SCIM resource, dropping attributes the IdP asked to omit.

    Supports RFC 7644 §3.4.2.5: the IdP may pass
    ``?excludedAttributes=groups,emails`` to shrink response payloads. The
    resource is serialized first, then every top-level key whose lowercased
    name appears in *excluded* is removed, so the rest of the pipeline
    never needs to know about the parameter.

    Args:
        resource: The user or group resource to serialize.
        excluded: Lowercased attribute names to strip from the output.

    Returns:
        The serialized resource dict minus the excluded attributes.
    """
    serialized = resource.model_dump(exclude_none=True, by_alias=True)
    # Compare case-insensitively against the camelCase wire field names.
    return {k: v for k, v in serialized.items() if k.lower() not in excluded}
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
@@ -124,7 +193,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
@@ -144,10 +213,95 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
# Build from givenName/familyName first — IdPs like Okta may send a stale
# ``formatted`` value while updating the individual name components.
# If the client explicitly provides ``formatted``, prefer it — the client
# knows what display string it wants. Otherwise build from components.
if name.formatted:
return name.formatted
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or name.formatted
return parts or None
def _scim_resource_response(
    resource: ScimUserResource | ScimGroupResource | ScimListResponse,
    status_code: int = 200,
) -> ScimJSONResponse:
    """Render *resource* as an ``application/scim+json`` response.

    Args:
        resource: A SCIM resource or list envelope to serialize.
        status_code: HTTP status for the response (defaults to 200).

    Returns:
        A ``ScimJSONResponse`` carrying the serialized resource.
    """
    # by_alias keeps camelCase / URN field names in the SCIM wire format;
    # exclude_none keeps the payload minimal.
    return ScimJSONResponse(
        status_code=status_code,
        content=resource.model_dump(exclude_none=True, by_alias=True),
    )
def _build_list_response(
    resources: list[ScimUserResource | ScimGroupResource],
    total: int,
    start_index: int,
    count: int,
    excluded: set[str] | None = None,
) -> ScimListResponse | ScimJSONResponse:
    """Build a SCIM list response, optionally applying attribute exclusions.

    RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
    the ``excludedAttributes`` query parameter. Without exclusions, the
    full ``ScimListResponse`` envelope is serialized as-is; with them, each
    resource is serialized individually so the excluded fields can be
    stripped before the envelope is returned.

    Args:
        resources: The page of serializable resources.
        total: Total number of matching resources across all pages.
        start_index: 1-based index of the first resource in this page.
        count: Page size requested by the client.
        excluded: Lowercased attribute names to omit, or ``None``.
    """
    if not excluded:
        return _scim_resource_response(
            ScimListResponse(
                totalResults=total,
                startIndex=start_index,
                itemsPerPage=count,
                Resources=resources,
            )
        )
    body = ScimListResponse(
        totalResults=total,
        startIndex=start_index,
        itemsPerPage=count,
    ).model_dump(exclude_none=True)
    body["Resources"] = [_apply_exclusions(item, excluded) for item in resources]
    return ScimJSONResponse(content=body)
def _extract_enterprise_fields(
    resource: ScimUserResource,
) -> tuple[str | None, str | None]:
    """Pull ``(department, manager)`` from the enterprise extension.

    Returns ``(None, None)`` when the resource carries no enterprise
    extension. The manager complex attribute is flattened to its ``value``
    string.
    """
    ext = resource.enterprise_extension
    if ext is None:
        return None, None
    manager_ref = ext.manager
    return ext.department, (manager_ref.value if manager_ref else None)
def _mapping_to_fields(
    mapping: ScimUserMapping | None,
) -> ScimMappingFields | None:
    """Copy the persisted round-trip attributes off a SCIM user mapping.

    Returns ``None`` when there is no mapping row, so callers can thread
    the result straight through to the provider layer.
    """
    if mapping is None:
        return None
    attrs = {
        "department": mapping.department,
        "manager": mapping.manager,
        "given_name": mapping.given_name,
        "family_name": mapping.family_name,
        "scim_emails_json": mapping.scim_emails_json,
    }
    return ScimMappingFields(**attrs)
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
    """Build the persistable round-trip fields from an incoming SCIM user.

    Department and manager come from the enterprise extension; the
    structured name parts come from ``name``; emails are serialized via
    ``serialize_emails`` so they can be echoed back on later GETs.
    """
    department, manager = _extract_enterprise_fields(resource)
    name = resource.name
    return ScimMappingFields(
        department=department,
        manager=manager,
        given_name=name.givenName if name else None,
        family_name=name.familyName if name else None,
        scim_emails_json=serialize_emails(resource.emails),
    )
# ---------------------------------------------------------------------------
@@ -158,12 +312,13 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
) -> ScimListResponse | ScimJSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -185,42 +340,54 @@ def list_users(
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
for user, mapping in users_with_mappings
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return provider.build_user_resource(
resource = provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
@@ -228,7 +395,7 @@ def create_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -270,13 +437,25 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.create_user_mapping(
external_id=external_id, user_id=user.id, scim_username=scim_username
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
return provider.build_user_resource(user, external_id, scim_username=scim_username)
return _scim_resource_response(
provider.build_user_resource(
user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -286,13 +465,13 @@ def replace_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
@@ -313,15 +492,24 @@ def replace_user(
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
fields = _fields_from_resource(user_resource)
dal.sync_user_external_id(
user.id,
new_external_id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
return provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
return _scim_resource_response(
provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
fields=fields,
)
)
@@ -332,7 +520,7 @@ def patch_user(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
@@ -342,23 +530,25 @@ def patch_user(
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current_fields = _mapping_to_fields(mapping)
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
fields=current_fields,
)
try:
patched = apply_user_patch(
patched, ent_data = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
@@ -393,17 +583,37 @@ def patch_user(
personal_name=personal_name,
)
# Build updated fields by merging PATCH enterprise data with current values
cf = current_fields or ScimMappingFields()
fields = ScimMappingFields(
department=ent_data.get("department", cf.department),
manager=ent_data.get("manager", cf.manager),
given_name=patched.name.givenName if patched.name else cf.given_name,
family_name=patched.name.familyName if patched.name else cf.family_name,
scim_emails_json=(
serialize_emails(patched.emails)
if patched.emails is not None
else cf.scim_emails_json
),
)
dal.sync_user_external_id(
user.id, patched.externalId, scim_username=new_scim_username
user.id,
patched.externalId,
scim_username=new_scim_username,
fields=fields,
)
dal.commit()
return provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
return _scim_resource_response(
provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
fields=fields,
)
)
@@ -412,25 +622,29 @@ def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
) -> Response | ScimJSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
A second DELETE returns 404 per RFC 7644 §3.6.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
dal.deactivate_user(user)
# If no SCIM mapping exists, the user was already deleted from
# SCIM's perspective — return 404 per RFC 7644 §3.6.
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
if not mapping:
return _scim_error_response(404, f"User {user_id} not found")
dal.deactivate_user(user)
dal.delete_user_mapping(mapping.id)
dal.commit()
@@ -442,7 +656,7 @@ def delete_user(
# ---------------------------------------------------------------------------
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
@@ -497,12 +711,13 @@ def _validate_and_parse_members(
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
) -> ScimListResponse | ScimJSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -522,37 +737,46 @@ def list_groups(
for group, ext_id in groups_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return provider.build_group_resource(
resource = provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
@@ -560,7 +784,7 @@ def create_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -596,7 +820,10 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return provider.build_group_resource(db_group, members, external_id)
return _scim_resource_response(
provider.build_group_resource(db_group, members, external_id),
status_code=201,
)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -606,13 +833,13 @@ def replace_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result
@@ -627,7 +854,9 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return provider.build_group_resource(group, members, group_resource.externalId)
return _scim_resource_response(
provider.build_group_resource(group, members, group_resource.externalId)
)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -637,7 +866,7 @@ def patch_group(
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
@@ -646,7 +875,7 @@ def patch_group(
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result
@@ -685,7 +914,9 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return provider.build_group_resource(group, members, patched.externalId)
return _scim_resource_response(
provider.build_group_resource(group, members, patched.externalId)
)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
@@ -693,13 +924,13 @@ def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
) -> Response | ScimJSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result

View File

@@ -7,12 +7,14 @@ SCIM protocol schemas follow the wire format defined in:
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
# ---------------------------------------------------------------------------
@@ -31,6 +33,9 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_ENTERPRISE_USER_SCHEMA = (
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
# ---------------------------------------------------------------------------
@@ -70,6 +75,36 @@ class ScimUserGroupRef(BaseModel):
display: str | None = None
class ScimManagerRef(BaseModel):
    """Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
    # The manager's identifier; this model carries only "value" — other
    # RFC 7643 manager sub-attributes ($ref, displayName) are not modeled.
    value: str | None = None
class ScimEnterpriseExtension(BaseModel):
    """Enterprise User extension attributes (RFC 7643 §4.3).

    Only the two attributes modeled here are handled; other enterprise
    extension attributes are not represented.
    """
    # Free-form department string from the IdP.
    department: str | None = None
    # Complex manager sub-attribute; flattened to its value downstream.
    manager: ScimManagerRef | None = None
@dataclass
class ScimMappingFields:
    """Stored SCIM mapping fields that need to round-trip through the IdP.

    Entra ID sends structured name components, email metadata, and enterprise
    extension attributes that must be returned verbatim in subsequent GET
    responses. These fields are persisted on ScimUserMapping and threaded
    through the DAL, provider, and endpoint layers.
    """
    # Enterprise extension department string.
    department: str | None = None
    # Enterprise extension manager identifier (flattened manager.value).
    manager: str | None = None
    # Structured name components (name.givenName / name.familyName).
    given_name: str | None = None
    family_name: str | None = None
    # The emails list serialized to JSON for verbatim echo-back.
    scim_emails_json: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -78,6 +113,8 @@ class ScimUserResource(BaseModel):
to match the SCIM wire format (not Python convention).
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
@@ -88,6 +125,10 @@ class ScimUserResource(BaseModel):
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
enterprise_extension: ScimEnterpriseExtension | None = Field(
default=None,
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
)
class ScimGroupMember(BaseModel):
@@ -165,6 +206,19 @@ class ScimPatchOperation(BaseModel):
path: str | None = None
value: ScimPatchValue = None
@field_validator("op", mode="before")
@classmethod
def normalize_operation(cls, v: object) -> object:
    """Normalize ``op`` to lowercase for case-insensitive matching.

    Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
    instead of ``"replace"``. This is safe for all providers since the
    enum values are lowercase. If a future provider requires other
    pre-processing quirks, move patch deserialization into the provider
    subclass instead of adding more special cases here.
    """
    if isinstance(v, str):
        return v.lower()
    return v
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).

View File

@@ -14,8 +14,13 @@ responsible for persisting changes.
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
@@ -24,6 +29,55 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
logger = logging.getLogger(__name__)
# Lowercased enterprise extension URN for case-insensitive matching
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
# Pattern for email filter paths, e.g.:
# emails[primary eq true].value (Okta)
# emails[type eq "work"].value (Azure AD / Entra ID)
_EMAIL_FILTER_RE = re.compile(
r"^emails\[.+\]\.value$",
re.IGNORECASE,
)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"active": ("active", "data"),
"username": ("userName", "data"),
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
}
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
"displayname": ("displayName", "data"),
}
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"displayname": ("displayName", "data"),
"externalid": ("externalId", "data"),
}
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
@@ -34,18 +88,25 @@ class ScimPatchError(Exception):
super().__init__(detail)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
@dataclass
class _UserPatchCtx:
    """Bundles the mutable state for user PATCH operations."""
    # Top-level resource dict being patched (dump of the user resource).
    data: dict[str, Any]
    # The "name" sub-object dict; written back into data["name"] by the
    # caller once all operations have been applied.
    name_data: dict[str, Any]
    # Enterprise extension values touched by this PATCH; keys (e.g.
    # "department", "manager") are present only when an op wrote them.
    ent_data: dict[str, str | None] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> ScimUserResource:
) -> tuple[ScimUserResource, dict[str, str | None]]:
"""Apply SCIM PATCH operations to a user resource.
Args:
@@ -53,79 +114,185 @@ def apply_user_patch(
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Returns:
A tuple of (modified user resource, enterprise extension data dict).
The enterprise dict has keys ``"department"`` and ``"manager"``
with values set only when a PATCH operation touched them.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
name_data = data.get("name") or {}
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data, ignored_paths)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data, ignored_paths)
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
_apply_user_replace(op, ctx, ignored_paths)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_user_remove(op, ctx, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
data["name"] = name_data
return ScimUserResource.model_validate(data)
ctx.data["name"] = ctx.name_data
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set
# No path — value is a resource dict of top-level attributes to set.
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data, ignored_paths)
_set_user_field(path, op.value, ctx, ignored_paths)
def _apply_user_remove(
    op: ScimPatchOperation,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
) -> None:
    """Apply a ``remove`` operation — nulls out the targeted field.

    Args:
        op: The PATCH operation; its path selects the field to clear.
        ctx: Mutable patch state (top-level and name sub-object dicts).
        ignored_paths: Provider-specific paths to silently skip.

    Raises:
        ScimPatchError: If the operation has no path, or targets a path
            outside the supported remove set.
    """
    normalized = (op.path or "").lower()
    if not normalized:
        raise ScimPatchError("Remove operation requires a path")
    if normalized in ignored_paths:
        return
    if normalized not in _USER_REMOVE_PATHS:
        raise ScimPatchError(f"Unsupported remove path '{normalized}' for User PATCH")
    key, target = _USER_REMOVE_PATHS[normalized]
    destination = ctx.data if target == "data" else ctx.name_data
    destination[key] = None
def _set_user_field(
path: str,
value: ScimPatchValue,
data: dict,
name_data: dict,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
*,
strict: bool = True,
) -> None:
"""Set a single field on user data by SCIM path."""
"""Set a single field on user data by SCIM path.
Args:
strict: When ``False`` (path-less replace), unknown attributes are
silently skipped. When ``True`` (explicit path), they raise.
"""
if path in ignored_paths:
return
elif path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
data["displayName"] = value
name_data["formatted"] = value
# Simple field writes handled by the dispatch table
entry = _USER_REPLACE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = value
return
# displayName sets both the top-level field and the name.formatted sub-field
if path == "displayname":
ctx.data["displayName"] = value
ctx.name_data["formatted"] = value
elif path == "name":
if isinstance(value, dict):
for k, v in value.items():
ctx.name_data[k] = v
elif path == "emails":
if isinstance(value, list):
ctx.data["emails"] = value
elif _EMAIL_FILTER_RE.match(path):
_update_primary_email(ctx.data, value)
elif path.startswith(_ENTERPRISE_URN_LOWER):
_set_enterprise_field(path, value, ctx.ent_data)
elif not strict:
return
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
"""Update the primary email entry via an email filter path."""
emails: list[dict] = data.get("emails") or []
for email_entry in emails:
if email_entry.get("primary"):
email_entry["value"] = value
break
else:
emails.append({"value": value, "type": "work", "primary": True})
data["emails"] = emails
def _to_dict(value: ScimPatchValue) -> dict | None:
"""Coerce a SCIM patch value to a plain dict if possible.
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
``extra="allow"``), so we also dump those back to a dict.
"""
if isinstance(value, dict):
return value
if isinstance(value, ScimPatchResourceValue):
return value.model_dump(exclude_unset=True)
return None
def _set_enterprise_field(
    path: str,
    value: ScimPatchValue,
    ent_data: dict[str, str | None],
) -> None:
    """Route an enterprise-extension PATCH path to the right ``ent_data`` slot.

    Two shapes are supported:
    - the bare URN as the path with a dict value (path-less PATCH), e.g.
      value={"department": "Eng", "manager": {"value": "..."}}
    - a dotted sub-attribute path such as "urn:...:user:department"
    """
    if path == _ENTERPRISE_URN_LOWER:
        payload = _to_dict(value)
        if payload is not None:
            if "department" in payload:
                ent_data["department"] = payload["department"]
            if "manager" in payload:
                manager = payload["manager"]
                if isinstance(manager, dict):
                    ent_data["manager"] = manager.get("value")
        return

    # Dotted URN path, e.g. "urn:...:user:department"
    attr = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
    if attr == "department":
        ent_data["department"] = None if value is None else str(value)
    elif attr == "manager":
        manager_dict = _to_dict(value)
        if manager_dict is not None:
            ent_data["manager"] = manager_dict.get("value")
        elif isinstance(value, str):
            ent_data["manager"] = value
    else:
        # Unknown enterprise attributes are logged and ignored rather than
        # rejected — IdPs may send attributes we don't model yet.
        logger.warning("Ignoring unknown enterprise extension attribute '%s'", attr)
# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
@@ -235,12 +402,14 @@ def _set_group_field(
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
elif path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
entry = _GROUP_REPLACE_PATHS.get(path)
if entry:
key, _ = entry
data[key] = value
return
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(

View File

@@ -2,13 +2,22 @@
from __future__ import annotations
import json
import logging
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from pydantic import ValidationError
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
@@ -17,6 +26,17 @@ from onyx.db.models import User
from onyx.db.models import UserGroup
logger = logging.getLogger(__name__)
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
{
"id",
"schemas",
"meta",
}
)
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
@@ -41,12 +61,22 @@ class ScimProvider(ABC):
"""
...
@property
def user_schemas(self) -> list[str]:
"""Schema URIs to include in User resource responses.
Override in subclasses to advertise additional schemas (e.g. the
enterprise extension for Entra ID).
"""
return [SCIM_USER_SCHEMA]
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
@@ -58,27 +88,48 @@ class ScimProvider(ABC):
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
fields: Stored mapping fields that the IdP expects round-tripped.
"""
f = fields or ScimMappingFields()
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
# Use original-case userName if stored, otherwise fall back to the
# lowercased email from the User model.
username = scim_username or user.email
return ScimUserResource(
# Build enterprise extension when at least one value is present.
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
enterprise_ext: ScimEnterpriseExtension | None = None
schemas = list(self.user_schemas)
if f.department is not None or f.manager is not None:
manager_ref = (
ScimManagerRef(value=f.manager) if f.manager is not None else None
)
enterprise_ext = ScimEnterpriseExtension(
department=f.department,
manager=manager_ref,
)
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
name = self.build_scim_name(user, f)
emails = _deserialize_emails(f.scim_emails_json, username)
resource = ScimUserResource(
schemas=schemas,
id=str(user.id),
externalId=external_id,
userName=username,
name=self._build_scim_name(user),
name=name,
displayName=user.personal_name,
emails=[ScimEmail(value=username, type="work", primary=True)],
emails=emails,
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
resource.enterprise_extension = enterprise_ext
return resource
def build_group_resource(
self,
@@ -98,9 +149,24 @@ class ScimProvider(ABC):
meta=ScimMeta(resourceType="Group"),
)
@staticmethod
def _build_scim_name(user: User) -> ScimName | None:
"""Extract SCIM name components from a user's personal name."""
def build_scim_name(
self,
user: User,
fields: ScimMappingFields,
) -> ScimName | None:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name,
familyName=fields.family_name,
formatted=user.personal_name,
)
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
@@ -111,6 +177,27 @@ class ScimProvider(ABC):
)
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
    """Rebuild stored SCIM email entries, defaulting to a single work email.

    Falls back to one primary work email built from ``username`` when nothing
    is stored, the stored JSON is corrupt, or it is not a non-empty list.
    """
    default = [ScimEmail(value=username, type="work", primary=True)]
    if not stored_json:
        return default
    try:
        parsed = json.loads(stored_json)
    except (json.JSONDecodeError, TypeError):
        logger.warning(
            "Corrupt scim_emails_json, falling back to default: %s", stored_json
        )
        return default
    if not (isinstance(parsed, list) and parsed):
        # Non-list or empty payloads fall through to the default silently,
        # matching the original control flow.
        return default
    try:
        return [ScimEmail(**entry) for entry in parsed]
    except (TypeError, ValidationError):
        logger.warning(
            "Corrupt scim_emails_json, falling back to default: %s", stored_json
        )
        return default
def serialize_emails(emails: list[ScimEmail]) -> str | None:
    """Serialize SCIM email entries to JSON for storage; ``None`` when empty."""
    if not emails:
        return None
    dumped = [email.model_dump(exclude_none=True) for email in emails]
    return json.dumps(dumped)
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.

View File

@@ -0,0 +1,36 @@
"""Entra ID (Azure AD) SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
class EntraProvider(ScimProvider):
    """SCIM provider tuned for Entra ID (Azure AD).

    Entra behavioral notes:
    - Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
      — handled by ``ScimPatchOperation.normalize_op`` validator.
    - Sends the enterprise extension URN as a key in path-less PATCH value
      dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
      store department/manager values.
    - Expects the enterprise extension schema in ``schemas`` arrays and
      ``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
    """

    @property
    def name(self) -> str:
        # Provider identifier used for lookup/logging.
        return "entra"

    @property
    def ignored_patch_paths(self) -> frozenset[str]:
        # No Entra-specific ignores beyond the common read-only paths.
        return _ENTRA_IGNORED_PATCH_PATHS

    @property
    def user_schemas(self) -> list[str]:
        # Advertise the enterprise extension alongside the core User schema.
        return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
@@ -22,4 +23,4 @@ class OktaProvider(ScimProvider):
@property
def ignored_patch_paths(self) -> frozenset[str]:
return frozenset({"id", "schemas", "meta"})
return COMMON_IGNORED_PATCH_PATHS

View File

@@ -4,6 +4,7 @@ Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
@@ -20,6 +21,9 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
"schemaExtensions": [
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
],
}
)
@@ -104,6 +108,31 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
],
)
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_ENTERPRISE_USER_SCHEMA,
name="EnterpriseUser",
description="Enterprise User extension (RFC 7643 §4.3)",
attributes=[
ScimSchemaAttribute(
name="department",
type="string",
description="Department.",
),
ScimSchemaAttribute(
name="manager",
type="complex",
description="The user's manager.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Manager user ID.",
),
],
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",

View File

@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -58,6 +59,7 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
@@ -276,6 +278,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -325,6 +328,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -12,6 +12,7 @@ from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
@@ -712,7 +713,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
db_session.execute(
select(UserFile.id).where(
sa.and_(
UserFile.needs_project_sync.is_(True),
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.status == UserFileStatus.COMPLETED,
)
)
@@ -772,7 +776,11 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
user_file = db_session.execute(
select(UserFile)
.where(UserFile.id == _as_uuid(user_file_id))
.options(selectinload(UserFile.assistants))
).scalar_one_or_none()
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
@@ -800,13 +808,17 @@ def process_single_user_file_project_sync(
]
project_ids = [project.id for project in user_file.projects]
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
user_fields=VespaDocumentUserFields(
user_projects=project_ids,
personas=persona_ids,
),
)
task_logger.info(
@@ -814,6 +826,7 @@ def process_single_user_file_project_sync(
)
user_file.needs_project_sync = False
user_file.needs_persona_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)

View File

@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
SMTP_USER = os.environ.get("SMTP_USER") or ""
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
@@ -367,7 +367,7 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# this assumes that other redis settings remain the same as the primary
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"

View File

@@ -16,6 +16,22 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
def _is_rate_limit_error(error: HttpError) -> bool:
"""Google sometimes returns rate-limit errors as 403 with reason
'userRateLimitExceeded' instead of 429. This helper detects both."""
if error.resp.status == 429:
return True
if error.resp.status != 403:
return False
error_details = getattr(error, "error_details", None) or []
for detail in error_details:
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
return True
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
@@ -57,7 +73,7 @@ def _execute_with_retry(request: Any) -> Any:
except HttpError as error:
attempt += 1
if error.resp.status == 429:
if _is_rate_limit_error(error):
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
@@ -140,16 +156,16 @@ def _execute_single_retrieval(
)
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e

View File

@@ -147,7 +147,9 @@ class DriveItemData(BaseModel):
self.id,
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
)
return DriveItem(graph_client, path)
item = DriveItem(graph_client, path)
item.set_property("id", self.id)
return item
# The office365 library's ClientContext caches the access token from its

View File

@@ -11,6 +11,7 @@ from dateutil import parser
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -258,3 +259,21 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch
def validate_connector_settings(self) -> None:
    """Basic sanity checks for the Slab connector configuration.

    Verifies the base URL scheme and that the bot token can list posts.

    Raises:
        ConnectorMissingCredentialError: re-raised when credentials are absent.
        ConnectorValidationError: on a bad base URL or any other fetch failure.
    """
    if not self.base_url.startswith(("https://", "http://")):
        raise ConnectorValidationError(
            "Base URL must start with https:// or http://"
        )
    try:
        get_all_post_ids(self.slab_bot_token)
    except ConnectorMissingCredentialError:
        # Credential problems have a dedicated error type; surface as-is.
        raise
    except Exception as e:
        raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")

View File

@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
class UserFileFilters(BaseModel):
user_file_ids: list[UUID] | None = None
project_id: int | None = None
persona_id: int | None = None
class AssistantKnowledgeFilters(BaseModel):

View File

@@ -40,6 +40,7 @@ def _build_index_filters(
user_provided_filters: BaseFilters | None,
user: User, # Used for ACLs, anonymous users only see public docs
project_id: int | None,
persona_id: int | None,
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
@@ -118,6 +119,7 @@ def _build_index_filters(
final_filters = IndexFilters(
user_file_ids=user_file_ids,
project_id=project_id,
persona_id=persona_id,
source_type=source_filter,
document_set=document_set_filter,
time_cutoff=time_filter,
@@ -265,6 +267,8 @@ def search_pipeline(
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# If a persona_id is provided, search scopes to files attached to this persona
persona_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
@@ -299,6 +303,7 @@ def search_pipeline(
user_provided_filters=chunk_search_request.user_selected_filters,
user=user,
project_id=project_id,
persona_id=persona_id,
user_file_ids=user_uploaded_persona_files,
persona_document_sets=persona_document_sets,
persona_time_cutoff=persona_time_cutoff,

View File

@@ -0,0 +1,21 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import CodeInterpreterServer
def fetch_code_interpreter_server(
    db_session: Session,
) -> CodeInterpreterServer:
    """Return the singleton code-interpreter server row.

    ``.one()`` semantics: raises if zero or more than one row exists.
    """
    return db_session.scalars(select(CodeInterpreterServer)).one()
def update_code_interpreter_server_enabled(
    db_session: Session,
    enabled: bool,
) -> CodeInterpreterServer:
    """Set the singleton server row's ``server_enabled`` flag and commit.

    Returns the updated row. ``.one()`` raises if the row count is not
    exactly one.
    """
    server_row = db_session.scalars(select(CodeInterpreterServer)).one()
    server_row.server_enabled = enabled
    db_session.commit()
    return server_row

View File

@@ -4270,6 +4270,9 @@ class UserFile(Base):
needs_project_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
needs_persona_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)

View File

@@ -765,6 +765,9 @@ def mark_persona_as_deleted(
) -> None:
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
persona.deleted = True
affected_file_ids = [uf.id for uf in persona.user_files]
if affected_file_ids:
_mark_files_need_persona_sync(db_session, affected_file_ids)
db_session.commit()
@@ -776,11 +779,13 @@ def mark_persona_as_not_deleted(
persona = get_persona_by_id(
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
)
if persona.deleted:
persona.deleted = False
db_session.commit()
else:
if not persona.deleted:
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
persona.deleted = False
affected_file_ids = [uf.id for uf in persona.user_files]
if affected_file_ids:
_mark_files_need_persona_sync(db_session, affected_file_ids)
db_session.commit()
def mark_delete_persona_by_name(
@@ -846,6 +851,20 @@ def update_personas_display_priority(
db_session.commit()
def _mark_files_need_persona_sync(
    db_session: Session,
    user_file_ids: list[UUID],
) -> None:
    """Bulk-flag UserFile rows so the background sync task re-syncs their
    persona metadata in the vector DB.

    No-op for an empty id list. Does not commit — callers own the
    transaction (they commit after calling this, per the visible call sites).
    """
    if not user_file_ids:
        return
    query = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids))
    # synchronize_session=False skips refreshing in-memory instances; the
    # flag is not read back within the same session by these callers.
    query.update({UserFile.needs_persona_sync: True}, synchronize_session=False)
def upsert_persona(
user: User | None,
name: str,
@@ -1034,8 +1053,13 @@ def upsert_persona(
existing_persona.tools = tools or []
if user_file_ids is not None:
old_file_ids = {uf.id for uf in existing_persona.user_files}
new_file_ids = {uf.id for uf in (user_files or [])}
affected_file_ids = old_file_ids | new_file_ids
existing_persona.user_files.clear()
existing_persona.user_files = user_files or []
if affected_file_ids:
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
if hierarchy_node_ids is not None:
existing_persona.hierarchy_nodes.clear()
@@ -1089,6 +1113,8 @@ def upsert_persona(
attached_documents=attached_documents or [],
)
db_session.add(new_persona)
if user_files:
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
persona = new_persona
if commit:
db_session.commit()

View File

@@ -2,6 +2,7 @@ import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from uuid import UUID
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -13,18 +14,26 @@ from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
def seed_chat_history(
num_sessions: int,
num_messages: int,
days: int,
user_id: UUID | None = None,
persona_id: int | None = None,
) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
user_id: optional user to associate with sessions
persona_id: optional persona/assistant to associate with sessions
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")

View File

@@ -3,6 +3,7 @@ from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.models import UserFile
@@ -64,6 +65,23 @@ def fetch_user_project_ids_for_user_files(
}
def fetch_persona_ids_for_user_files(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, list[int]]:
    """Map each requested user-file id to the ids of its attached personas.

    Files not present in the DB are simply absent from the returned mapping.
    Eagerly loads the ``assistants`` relationship via ``selectinload``.
    """
    query = (
        select(UserFile)
        .where(UserFile.id.in_(user_file_ids))
        .options(selectinload(UserFile.assistants))
    )
    mapping: dict[str, list[int]] = {}
    for user_file in db_session.execute(query).scalars().all():
        mapping[str(user_file.id)] = [a.id for a in user_file.assistants]
    return mapping
def update_last_accessed_at_for_user_files(
user_file_ids: list[UUID],
db_session: Session,

View File

@@ -121,6 +121,7 @@ class VespaDocumentUserFields:
"""
user_projects: list[int] | None = None
personas: list[int] | None = None
@dataclass

View File

@@ -148,6 +148,7 @@ class MetadataUpdateRequest(BaseModel):
hidden: bool | None = None
secondary_index_updated: bool | None = None
project_ids: set[int] | None = None
persona_ids: set[int] | None = None
class IndexRetrievalFilters(BaseModel):

View File

@@ -50,6 +50,7 @@ from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
@@ -215,6 +216,7 @@ def _convert_onyx_chunk_to_opensearch_document(
# OpenSearch and it will not store any data at all for this field, which
# is different from supplying an empty list.
user_projects=chunk.user_project or None,
personas=chunk.personas or None,
primary_owners=get_experts_stores_representations(
chunk.source_document.primary_owners
),
@@ -362,6 +364,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
if user_fields and user_fields.user_projects
else None
),
persona_ids=(
set(user_fields.personas)
if user_fields and user_fields.personas
else None
),
)
try:
@@ -709,6 +716,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
update_request.project_ids
)
if update_request.persona_ids is not None:
properties_to_update[PERSONAS_FIELD_NAME] = list(
update_request.persona_ids
)
if not properties_to_update:
if len(update_request.document_ids) > 1:

View File

@@ -41,6 +41,7 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
USER_PROJECTS_FIELD_NAME = "user_projects"
PERSONAS_FIELD_NAME = "personas"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
@@ -156,6 +157,7 @@ class DocumentChunk(BaseModel):
document_sets: list[str] | None = None
user_projects: list[int] | None = None
personas: list[int] | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@@ -485,6 +487,7 @@ class DocumentSchema:
# Product-specific fields.
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
PERSONAS_FIELD_NAME: {"type": "integer"},
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
# OpenSearch metadata fields.

View File

@@ -28,6 +28,7 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
@@ -144,6 +145,7 @@ class DocumentQuery:
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=min_chunk_index,
max_chunk_index=max_chunk_index,
@@ -202,6 +204,7 @@ class DocumentQuery:
document_sets=[],
user_file_ids=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
@@ -267,6 +270,7 @@ class DocumentQuery:
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -334,6 +338,7 @@ class DocumentQuery:
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -496,6 +501,7 @@ class DocumentQuery:
document_sets: list[str],
user_file_ids: list[UUID],
project_id: int | None,
persona_id: int | None,
time_cutoff: datetime | None,
min_chunk_index: int | None,
max_chunk_index: int | None,
@@ -530,6 +536,8 @@ class DocumentQuery:
retrieved.
project_id: If not None, only documents with this project ID in user
projects will be retrieved.
persona_id: If not None, only documents whose personas array
contains this persona ID will be retrieved.
time_cutoff: Time cutoff for the documents to retrieve. If not None,
Documents which were last updated before this date will not be
returned. For documents which do not have a value for their last
@@ -627,6 +635,9 @@ class DocumentQuery:
)
return user_project_filter
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
# Convert to UTC if not already so the cutoff is comparable to the
# document data.
@@ -780,6 +791,9 @@ class DocumentQuery:
# document's user projects list.
filter_clauses.append(_get_user_project_filter(project_id))
if persona_id is not None:
filter_clauses.append(_get_persona_filter(persona_id))
if time_cutoff is not None:
# If a time cutoff is provided, the caller will only retrieve
# documents where the document was last updated at or after the time

View File

@@ -181,6 +181,11 @@ schema {{ schema_name }} {
rank: filter
attribute: fast-search
}
field personas type array<int> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -689,6 +689,9 @@ class VespaIndex(DocumentIndex):
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
persona_ids: set[int] | None = None
if user_fields is not None and user_fields.personas is not None:
persona_ids = set(user_fields.personas)
update_request = MetadataUpdateRequest(
document_ids=[doc_id],
doc_id_to_chunk_cnt={
@@ -699,6 +702,7 @@ class VespaIndex(DocumentIndex):
boost=fields.boost if fields is not None else None,
hidden=fields.hidden if fields is not None else None,
project_ids=project_ids,
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])

View File

@@ -46,6 +46,7 @@ from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
@@ -218,6 +219,7 @@ def _index_vespa_chunk(
# still called `image_file_name` in Vespa for backwards compatibility
IMAGE_FILE_NAME: chunk.image_file_id,
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
PERSONAS: chunk.personas if chunk.personas is not None else [],
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -12,6 +12,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_PROJECT
@@ -149,6 +150,18 @@ def build_vespa_filters(
# Vespa YQL 'contains' expects a string literal; quote the integer
return f'({USER_PROJECT} contains "{pid}") and '
def _build_persona_filter(
    persona_id: int | None,
) -> str:
    """Return a YQL clause restricting results to documents tagged with persona_id.

    Returns an empty string (i.e. no filtering) when persona_id is None or
    cannot be coerced to an int. The clause ends with " and " so callers can
    concatenate it with further clauses.
    """
    if persona_id is None:
        return ""
    try:
        normalized_id = int(persona_id)
    except Exception:
        logger.warning(f"Invalid persona ID: {persona_id}")
        return ""
    # Vespa YQL 'contains' expects a string literal, so the integer is quoted.
    return f'({PERSONAS} contains "{normalized_id}") and '
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
@@ -192,6 +205,9 @@ def build_vespa_filters(
# User project filter (array<int> attribute membership)
filter_str += _build_user_project_filter(filters.project_id)
# Persona filter (array<int> attribute membership)
filter_str += _build_persona_filter(filters.persona_id)
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)

View File

@@ -183,6 +183,10 @@ def _update_single_chunk(
model_config = {"frozen": True}
assign: list[int]
class _Personas(BaseModel):
model_config = {"frozen": True}
assign: list[int]
class _VespaPutFields(BaseModel):
model_config = {"frozen": True}
# The names of these fields are based the Vespa schema. Changes to the
@@ -193,6 +197,7 @@ def _update_single_chunk(
access_control_list: _AccessControl | None = None
hidden: _Hidden | None = None
user_project: _UserProjects | None = None
personas: _Personas | None = None
class _VespaPutRequest(BaseModel):
model_config = {"frozen": True}
@@ -227,6 +232,11 @@ def _update_single_chunk(
if update_request.project_ids is not None
else None
)
personas_update: _Personas | None = (
_Personas(assign=list(update_request.persona_ids))
if update_request.persona_ids is not None
else None
)
vespa_put_fields = _VespaPutFields(
boost=boost_update,
@@ -234,6 +244,7 @@ def _update_single_chunk(
access_control_list=access_update,
hidden=hidden_update,
user_project=user_projects_update,
personas=personas_update,
)
vespa_put_request = _VespaPutRequest(

View File

@@ -58,6 +58,7 @@ DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
PERSONAS = "personas"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -146,6 +146,7 @@ class DocumentIndexingBatchAdapter:
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map

View File

@@ -20,6 +20,7 @@ from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.notification import create_notification
from onyx.db.user_file import fetch_chunk_counts_for_user_files
from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
@@ -119,6 +120,10 @@ class UserFileIndexingAdapter:
user_file_ids=updatable_ids,
db_session=self.db_session,
)
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
)
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -182,7 +187,7 @@ class UserFileIndexingAdapter:
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
# we are going to index userfiles only once, so we just set the boost to the default
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],

View File

@@ -112,6 +112,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess"
document_sets: set[str]
user_project: list[int]
personas: list[int]
boost: int
aggregated_chunk_boost_factor: float
# Full ancestor path from root hierarchy node to document's parent.
@@ -126,6 +127,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
user_project: list[int],
personas: list[int],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
@@ -137,6 +139,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
user_project=user_project,
personas=personas,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,

View File

@@ -97,6 +97,9 @@ from onyx.server.features.web_search.api import router as web_search_router
from onyx.server.federated.api import router as federated_router
from onyx.server.kg.api import admin_router as kg_admin_router
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.code_interpreter.api import (
admin_router as code_interpreter_admin_router,
)
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
@@ -421,6 +424,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, kg_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(
application, code_interpreter_admin_router
)
include_router_with_global_prefix_prepended(
application, image_generation_admin_router
)

View File

@@ -592,11 +592,8 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,20 +1,149 @@
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
re.IGNORECASE,
)
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
_HTML_TAG_PATTERN = re.compile(
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
    """Strip HTML tags from a text fragment.

    Block-level closing tags and <br> are converted to newlines; every other
    HTML tag is removed outright. Autolinks such as <https://...> are left
    intact because the tag pattern excludes them.
    """
    with_newlines = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
    return _HTML_TAG_PATTERN.sub("", with_newlines)
def _transform_outside_code_blocks(
    message: str, transform: Callable[[str], str]
) -> str:
    """Apply *transform* to every span of *message* outside fenced code blocks.

    The fenced blocks themselves (``` ... ```) are re-inserted verbatim
    between the transformed spans, preserving their exact contents.
    """
    outside_segments = _FENCED_CODE_BLOCK_PATTERN.split(message)
    fenced_segments = _FENCED_CODE_BLOCK_PATTERN.findall(message)
    pieces: list[str] = []
    for idx, segment in enumerate(outside_segments):
        pieces.append(transform(segment))
        # split() yields one more outside segment than there are fences,
        # so guard the fence lookup on the final iteration.
        if idx < len(fenced_segments):
            pieces.append(fenced_segments[idx])
    return "".join(pieces)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
    """Wrap markdown link URLs in angle brackets so the parser treats them literally.

    Markdown link syntax [text](url) breaks when the URL contains unescaped
    parentheses, spaces, or other special characters. Rewriting each link as
    [text](<url>) tells the parser to take everything inside the brackets as
    a literal URL. Applies to all links, not just citation links.
    """
    if "](" not in message:
        return message
    rebuilt: list[str] = []
    cursor = 0
    while True:
        match = _MARKDOWN_LINK_PATTERN.search(message, cursor)
        if match is None:
            break
        rebuilt.append(message[cursor : match.end()])
        dest_start = match.end()
        destination, close_idx = _extract_link_destination(message, dest_start)
        if close_idx is None:
            # Unterminated link: keep the remainder verbatim and stop.
            rebuilt.append(message[dest_start:])
            return "".join(rebuilt)
        wrapped = destination.startswith("<") and destination.endswith(">")
        if destination and not wrapped:
            destination = f"<{destination}>"
        rebuilt.append(destination)
        rebuilt.append(")")
        cursor = close_idx + 1
    rebuilt.append(message[cursor:])
    return "".join(rebuilt)
def _convert_slack_links_to_markdown(message: str) -> str:
    """Convert Slack-style <url|text> links to standard markdown [text](url).

    LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
    recognise it, so text() would escape the angle brackets and Slack would
    render the link as literal text instead of a clickable link. Fenced code
    blocks are left untouched.
    """
    def rewrite(fragment: str) -> str:
        return _SLACK_LINK_PATTERN.sub(r"[\2](\1)", fragment)

    return _transform_outside_code_blocks(message, rewrite)
def format_slack_message(message: str | None) -> str:
if message is None:
return ""
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(message)
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result
return result.rstrip("\n")
class SlackRenderer(HTMLRenderer):
"""Renders markdown as Slack mrkdwn format instead of HTML.
Overrides all HTMLRenderer methods that produce HTML tags to ensure
no raw HTML ever appears in Slack messages.
"""
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def escape_special(self, text: str) -> str:
@@ -23,7 +152,7 @@ class SlackRenderer(HTMLRenderer):
return text
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
return f"*{text}*\n"
return f"*{text}*\n\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
@@ -42,7 +171,7 @@ class SlackRenderer(HTMLRenderer):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
return "\n".join(lines) + "\n"
def list_item(self, text: str) -> str:
return f"li: {text}\n"
@@ -64,7 +193,30 @@ class SlackRenderer(HTMLRenderer):
return f"`{text}`"
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
return f"```\n{code}\n```\n"
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
return "\n"
def thematic_break(self) -> str:
return "---\n\n"
def block_quote(self, text: str) -> str:
    """Render a markdown block quote as Slack mrkdwn (">" line prefixes)."""
    # Strip surrounding whitespace first so leading/trailing blank lines
    # don't become bare ">" lines; then prefix every remaining line.
    lines = text.strip().split("\n")
    quoted = "\n".join(f">{line}" for line in lines)
    return quoted + "\n\n"
def block_html(self, html: str) -> str:
return _sanitize_html(html) + "\n\n"
def block_error(self, text: str) -> str:
return f"```\n{text}\n```\n\n"
def text(self, text: str) -> str:
    """Escape a plain-text node for Slack mrkdwn output."""
    # Only escape the three entities Slack recognizes: & < >
    # HTMLRenderer.text() also escapes " to &quot; which Slack renders
    # as literal &quot; text since Slack doesn't recognize that entity.
    return self.escape_special(text)
def paragraph(self, text: str) -> str:
return f"{text}\n"
return f"{text}\n\n"

View File

@@ -1,15 +1,19 @@
#!/usr/bin/env python3
"""Generate AGENTS.md by scanning the files directory and populating the template.
This script runs at container startup, AFTER the init container has synced files
from S3. It scans the /workspace/files directory to discover what knowledge sources
are available and generates appropriate documentation.
This script runs during session setup, AFTER files have been synced from S3
and the files symlink has been created. It reads an existing AGENTS.md (which
contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the
placeholder by scanning the knowledge source directory, and writes it back.
Environment variables:
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
Usage:
python3 generate_agents_md.py <agents_md_path> <files_path>
Arguments:
agents_md_path: Path to the AGENTS.md file to update in place
files_path: Path to the files directory to scan for knowledge sources
"""
import os
import sys
from pathlib import Path
@@ -189,49 +193,39 @@ def build_knowledge_sources_section(files_path: Path) -> str:
def main() -> None:
"""Main entry point for container startup script.
Is called by the container startup script to scan /workspace/files and populate
the knowledge sources section.
Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
placeholder by scanning the files directory, and writes it back.
Usage:
python3 generate_agents_md.py <agents_md_path> <files_path>
"""
# Read template from environment variable
template = os.environ.get("AGENT_INSTRUCTIONS", "")
if not template:
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
template = "# Agent Instructions\n\nNo instructions provided."
if len(sys.argv) != 3:
print(
f"Usage: {sys.argv[0]} <agents_md_path> <files_path>",
file=sys.stderr,
)
sys.exit(1)
# Scan files directory - check /workspace/files first, then /workspace/demo_data
files_path = Path("/workspace/files")
demo_data_path = Path("/workspace/demo_data")
agents_md_path = Path(sys.argv[1])
files_path = Path(sys.argv[2])
# Use demo_data if files doesn't exist or is empty
if not files_path.exists() or not any(files_path.iterdir()):
if demo_data_path.exists():
files_path = demo_data_path
if not agents_md_path.exists():
print(f"Error: {agents_md_path} not found", file=sys.stderr)
sys.exit(1)
knowledge_sources_section = build_knowledge_sources_section(files_path)
template = agents_md_path.read_text()
# Replace placeholders
content = template
content = content.replace(
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
resolved_files_path = files_path.resolve()
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
# Replace placeholder and write back
content = template.replace(
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
)
# Write AGENTS.md
output_path = Path("/workspace/AGENTS.md")
output_path.write_text(content)
# Log result
source_count = 0
if files_path.exists():
source_count = len(
[
d
for d in files_path.iterdir()
if d.is_dir() and not d.name.startswith(".")
]
)
print(
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
)
agents_md_path.write_text(content)
print(f"Populated knowledge sources in {agents_md_path}")
if __name__ == "__main__":

View File

@@ -1352,6 +1352,9 @@ fi
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
# Populate knowledge sources by scanning the files directory
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
# Write opencode config
echo "Writing opencode.json"
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
@@ -1780,6 +1783,9 @@ ln -sf {symlink_target} {session_path}/files
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
# Populate knowledge sources by scanning the files directory
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
# Write opencode config
echo "Writing opencode.json"
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json

View File

@@ -0,0 +1,47 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
from onyx.tools.tool_implementations.python.code_interpreter_client import (
CodeInterpreterClient,
)
admin_router = APIRouter(prefix="/admin/code-interpreter")
@admin_router.get("/health")
def get_code_interpreter_health(
    _: User = Depends(current_admin_user),
) -> CodeInterpreterServerHealth:
    """Admin-only health probe for the code interpreter service."""
    try:
        client = CodeInterpreterClient()
        return CodeInterpreterServerHealth(healthy=client.health())
    except ValueError:
        # Presumably raised when the client cannot be constructed (e.g. the
        # base URL is unset) — TODO confirm; report unhealthy instead of a 500.
        return CodeInterpreterServerHealth(healthy=False)
@admin_router.get("")
def get_code_interpreter(
    _: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
) -> CodeInterpreterServer:
    """Return the current enabled/disabled state of the code interpreter server."""
    ci_server = fetch_code_interpreter_server(db_session)
    return CodeInterpreterServer(enabled=ci_server.server_enabled)
@admin_router.put("")
def update_code_interpreter(
    update: CodeInterpreterServer,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist the admin-supplied enabled/disabled flag for the code interpreter."""
    update_code_interpreter_server_enabled(
        db_session=db_session,
        enabled=update.enabled,
    )

View File

@@ -0,0 +1,9 @@
from pydantic import BaseModel
class CodeInterpreterServer(BaseModel):
    """Admin-facing enabled/disabled state of the code interpreter server."""

    # Whether the code interpreter server is enabled for this deployment.
    enabled: bool
class CodeInterpreterServerHealth(BaseModel):
    """Result of the code interpreter health probe."""

    # True when the service responded and reported itself healthy.
    healthy: bool

View File

@@ -35,6 +35,18 @@ if TYPE_CHECKING:
pass
class EmailInviteStatus(str, Enum):
    """Outcome of attempting to send invite emails during a bulk user invite."""

    SENT = "SENT"
    NOT_CONFIGURED = "NOT_CONFIGURED"
    SEND_FAILED = "SEND_FAILED"
    DISABLED = "DISABLED"
class BulkInviteResponse(BaseModel):
    """Response payload for the bulk-invite endpoint."""

    # Number of users newly written to the invited-users list.
    invited_count: int
    # Whether invite emails were sent, and if not, why.
    email_invite_status: EmailInviteStatus
class VersionResponse(BaseModel):
backend_version: str

View File

@@ -36,6 +36,7 @@ from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
@@ -78,8 +79,10 @@ from onyx.server.documents.models import PaginatedReturn
from onyx.server.features.projects.models import UserFileSnapshot
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import BulkInviteResponse
from onyx.server.manage.models import ChatBackgroundRequest
from onyx.server.manage.models import DefaultAppModeRequest
from onyx.server.manage.models import EmailInviteStatus
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import PersonalizationUpdateRequest
from onyx.server.manage.models import TenantInfo
@@ -368,7 +371,7 @@ def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
) -> BulkInviteResponse:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""
tenant_id = get_current_tenant_id()
@@ -427,34 +430,41 @@ def bulk_invite_users(
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations only to new users (not already invited or existing)
if ENABLE_EMAIL_INVITES:
if not ENABLE_EMAIL_INVITES:
email_invite_status = EmailInviteStatus.DISABLED
elif not EMAIL_CONFIGURED:
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
else:
try:
for email in emails_needing_seats:
send_user_email_invite(email, current_user, AUTH_TYPE)
email_invite_status = EmailInviteStatus.SENT
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
email_invite_status = EmailInviteStatus.SEND_FAILED
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
if MULTI_TENANT and not DEV_MODE:
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_live_users_count(db_session))
return number_of_invited_users
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
raise e
return BulkInviteResponse(
invited_count=number_of_invited_users,
email_invite_status=email_invite_status,
)
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)

View File

@@ -54,6 +54,7 @@ logger = setup_logger()
class SearchToolConfig(BaseModel):
user_selected_filters: BaseFilters | None = None
project_id: int | None = None
persona_id: int | None = None
bypass_acl: bool = False
additional_context: str | None = None
slack_context: SlackContext | None = None
@@ -180,6 +181,7 @@ def construct_tools(
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id=search_tool_config.project_id,
persona_id=search_tool_config.persona_id,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
@@ -427,6 +429,7 @@ def construct_tools(
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id=search_tool_config.project_id,
persona_id=search_tool_config.persona_id,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,

View File

@@ -98,6 +98,17 @@ class CodeInterpreterClient:
payload["files"] = files
return payload
def health(self) -> bool:
    """Return True when the Code Interpreter service reports itself healthy."""
    health_url = f"{self.base_url}/health"
    try:
        response = self.session.get(health_url, timeout=5)
        response.raise_for_status()
        # Any parse/attribute error here is also treated as unhealthy.
        return response.json().get("status") == "ok"
    except Exception as e:
        logger.warning(f"Exception caught when checking health, e={e}")
        return False
def execute(
self,
code: str,

View File

@@ -12,6 +12,7 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
from onyx.configs.constants import FileOrigin
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.file_store.utils import build_full_frontend_file_url
from onyx.file_store.utils import get_default_file_store
from onyx.server.query_and_chat.placement import Placement
@@ -103,8 +104,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
@override
@classmethod
def is_available(cls, db_session: Session) -> bool:
is_available = bool(CODE_INTERPRETER_BASE_URL)
return is_available
if not CODE_INTERPRETER_BASE_URL:
return False
server = fetch_code_interpreter_server(db_session)
return server.server_enabled
def tool_definition(self) -> dict:
return {

View File

@@ -247,6 +247,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
user_selected_filters: BaseFilters | None,
# If the chat is part of a project
project_id: int | None,
# If set, search scopes to files attached to this persona
persona_id: int | None = None,
bypass_acl: bool = False,
# Slack context for federated Slack search (tokens fetched internally)
slack_context: SlackContext | None = None,
@@ -261,6 +263,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self.document_index = document_index
self.user_selected_filters = user_selected_filters
self.project_id = project_id
self.persona_id = persona_id
self.bypass_acl = bypass_acl
self.slack_context = slack_context
self.enable_slack_search = enable_slack_search
@@ -456,6 +459,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
limit=num_hits,
),
project_id=self.project_id,
persona_id=self.persona_id,
document_index=self.document_index,
user=self.user,
persona=self.persona,

View File

@@ -317,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.6.0
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via

View File

@@ -95,6 +95,7 @@ def generate_dummy_chunk(
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
user_project=[],
personas=[],
access=DocumentAccess.build(
user_emails=user_emails,
user_groups=user_groups,

View File

@@ -3,8 +3,8 @@ set -e
cleanup() {
echo "Error occurred. Cleaning up..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
}
# Trap errors and output a message, then cleanup
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
# Stop and remove the existing containers
echo "Stopping and removing existing containers..."
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
@@ -55,6 +55,10 @@ else
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
fi
# Start the Code Interpreter container
echo "Starting Code Interpreter container..."
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
# Ensure alembic runs in the correct directory (backend/)
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
PARENT_DIR="$(dirname "$SCRIPT_DIR")"

View File

@@ -144,7 +144,8 @@ def use_mock_search_pipeline(
auto_detect_filters: bool = False, # noqa: ARG001
llm: LLM | None = None, # noqa: ARG001
project_id: int | None = None, # noqa: ARG001
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
persona_id: int | None = None, # noqa: ARG001
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
acl_filters: list[str] | None = None, # noqa: ARG001
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
prefetched_federated_retrieval_infos: ( # noqa: ARG001

View File

@@ -38,6 +38,7 @@ def _get_search_filters(
tags=[],
document_sets=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,

View File

@@ -0,0 +1,53 @@
"""Tests that PythonTool.is_available() respects the server_enabled DB flag.
Uses a real DB session with CODE_INTERPRETER_BASE_URL mocked so the
environment-variable check passes and the DB flag is the deciding factor.
"""
from unittest.mock import patch
from sqlalchemy.orm import Session
from onyx.db.code_interpreter import fetch_code_interpreter_server
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
from onyx.tools.tool_implementations.python.python_tool import PythonTool
def test_python_tool_unavailable_when_server_disabled(
    db_session: Session,
) -> None:
    """With a valid base URL, the tool should be unavailable when
    server_enabled is False in the DB."""
    original_flag = fetch_code_interpreter_server(db_session).server_enabled
    try:
        update_code_interpreter_server_enabled(db_session, enabled=False)
        patched_url = patch(
            "onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
            "http://fake:8888",
        )
        with patched_url:
            assert PythonTool.is_available(db_session) is False
    finally:
        # Restore the original DB flag regardless of the assertion outcome.
        update_code_interpreter_server_enabled(db_session, enabled=original_flag)
def test_python_tool_available_when_server_enabled(
    db_session: Session,
) -> None:
    """Verify that with a valid base URL and server_enabled=True in the DB,
    PythonTool.is_available() reports True."""
    original_flag = fetch_code_interpreter_server(db_session).server_enabled
    try:
        update_code_interpreter_server_enabled(db_session, enabled=True)
        # Patch the URL so the environment-variable check passes; the DB flag
        # is then the only gate on availability.
        patched_url = patch(
            "onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
            "http://fake:8888",
        )
        with patched_url:
            assert PythonTool.is_available(db_session) is True
    finally:
        # Restore the pre-test DB state regardless of the assertion outcome.
        update_code_interpreter_server_enabled(db_session, enabled=original_flag)

View File

@@ -38,5 +38,5 @@ COPY --from=openapi-client /local/onyx_openapi_client /app/generated/onyx_openap
ENV PYTHONPATH=/app
ENTRYPOINT ["pytest", "-s"]
ENTRYPOINT ["pytest", "-s", "-rs"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from urllib.parse import urlencode
from uuid import UUID
@@ -8,8 +9,10 @@ from requests.models import CaseInsensitiveDict
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from onyx.configs.constants import QAFeedbackType
from onyx.db.enums import TaskStatus
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.test_models import DATestUser
@@ -69,9 +72,42 @@ class QueryHistoryManager:
if end_time:
query_params["end"] = end_time.isoformat()
response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
start_response = requests.post(
url=f"{API_SERVER_URL}/admin/query-history/start-export?{urlencode(query_params, doseq=True)}",
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.headers, response.content.decode()
start_response.raise_for_status()
request_id = start_response.json()["request_id"]
deadline = time.time() + MAX_DELAY
while time.time() < deadline:
status_response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history/export-status",
params={"request_id": request_id},
headers=user_performing_action.headers,
)
status_response.raise_for_status()
status = status_response.json()["status"]
if status == TaskStatus.SUCCESS:
break
if status == TaskStatus.FAILURE:
raise RuntimeError("Query history export task failed")
time.sleep(2)
else:
raise TimeoutError(
f"Query history export not completed within {MAX_DELAY} seconds"
)
download_response = requests.get(
url=f"{API_SERVER_URL}/admin/query-history/download",
params={"request_id": request_id},
headers=user_performing_action.headers,
)
download_response.raise_for_status()
if not download_response.content:
raise RuntimeError(
"Query history CSV download returned zero-length content"
)
return download_response.headers, download_response.content.decode()

View File

@@ -6,16 +6,26 @@ import pytest
from onyx.connectors.slack.models import ChannelType
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
# from tests.load_env_vars import load_env_vars
# load_env_vars()
SLACK_ADMIN_EMAIL = os.environ.get("SLACK_ADMIN_EMAIL", "evan@onyx.app")
SLACK_TEST_USER_1_EMAIL = os.environ.get("SLACK_TEST_USER_1_EMAIL", "evan+1@onyx.app")
SLACK_TEST_USER_2_EMAIL = os.environ.get("SLACK_TEST_USER_2_EMAIL", "justin@onyx.app")
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
def _provision_slack_channels(
bot_token: str,
) -> Generator[tuple[ChannelType, ChannelType], None, None]:
slack_client = SlackManager.get_slack_client(bot_token)
auth_info = slack_client.auth_test()
print(f"\nSlack workspace: {auth_info.get('team')} ({auth_info.get('url')})")
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = user_map["admin@example.com"]
if SLACK_ADMIN_EMAIL not in user_map:
raise KeyError(
f"'{SLACK_ADMIN_EMAIL}' not found in Slack workspace. "
f"Available emails: {sorted(user_map.keys())}"
)
admin_user_id = user_map[SLACK_ADMIN_EMAIL]
(
public_channel,
@@ -27,5 +37,16 @@ def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]
yield public_channel, private_channel
# This part will always run after the test, even if it fails
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
    # Provision a (public, private) channel pair in the primary Slack
    # workspace; `yield from` delegates both setup and teardown (cleanup)
    # to the shared provisioning generator.
    yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN"])
@pytest.fixture()
def slack_perm_sync_test_setup() -> (
    Generator[tuple[ChannelType, ChannelType], None, None]
):
    # Same provisioning as slack_test_setup, but against the dedicated
    # permission-sync test workspace (separate bot token) so perm-sync tests
    # do not interfere with the primary workspace.
    yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN_TEST_SPACE"])

View File

@@ -16,7 +16,6 @@ from uuid import uuid4
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from onyx.connectors.slack.connector import default_msg_filter
from onyx.connectors.slack.connector import get_channel_messages
from onyx.connectors.slack.models import ChannelType
from onyx.connectors.slack.utils import make_paginated_slack_api_call
@@ -113,9 +112,6 @@ def _delete_slack_conversation_messages(
channel_id = _get_slack_channel_id(channel)
for message_batch in get_channel_messages(slack_client, channel):
for message in message_batch:
if default_msg_filter(message):
continue
if message_to_delete and message.get("text") != message_to_delete:
continue
print(" removing message: ", message.get("text"))

View File

@@ -22,6 +22,9 @@ from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.conftest import SLACK_ADMIN_EMAIL
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_1_EMAIL
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_2_EMAIL
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
@@ -34,26 +37,24 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackMan
def test_slack_permission_sync(
reset: None, # noqa: ARG001
vespa_client: vespa_fixture, # noqa: ARG001
slack_test_setup: tuple[ChannelType, ChannelType],
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
) -> None:
public_channel, private_channel = slack_test_setup
public_channel, private_channel = slack_perm_sync_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@example.com",
email=SLACK_ADMIN_EMAIL,
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@example.com",
email=SLACK_TEST_USER_1_EMAIL,
)
# Creating a non-admin user
test_user_2: DATestUser = UserManager.create(
email="test_user_2@example.com",
email=SLACK_TEST_USER_2_EMAIL,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
slack_client = SlackManager.get_slack_client(bot_token)
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
@@ -63,7 +64,7 @@ def test_slack_permission_sync(
credential: DATestCredential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
"slack_bot_token": bot_token,
},
user_performing_action=admin_user,
)
@@ -73,6 +74,7 @@ def test_slack_permission_sync(
source=DocumentSource.SLACK,
connector_specific_config={
"channels": [public_channel["name"], private_channel["name"]],
"include_bot_messages": True,
},
access_type=AccessType.SYNC,
groups=[],
@@ -102,14 +104,11 @@ def test_slack_permission_sync(
public_message = "Steve's favorite number is 809752"
private_message = "Sara's favorite number is 346794"
# Add messages to channels
print(f"\n Adding public message to channel: {public_message}")
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=public_channel,
message=public_message,
)
print(f"\n Adding private message to channel: {private_message}")
SlackManager.add_message_to_channel(
slack_client=slack_client,
channel=private_channel,
@@ -127,7 +126,9 @@ def test_slack_permission_sync(
user_performing_action=admin_user,
)
# Run permission sync
# Run permission sync. Since initial_index_should_sync=True for Slack,
# permissions were already set during indexing above — the explicit sync
# should find no changes to apply.
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
@@ -135,59 +136,38 @@ def test_slack_permission_sync(
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=2,
number_of_updated_docs=0,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
# Search as admin with access to both channels
print("\nSearching as admin user")
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify admin can see messages from both channels
admin_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=admin_user,
)
print(
"\n documents retrieved by admin user: ",
onyx_doc_message_strings,
)
assert public_message in admin_docs
assert private_message in admin_docs
# Ensure admin user can see messages from both channels
assert public_message in onyx_doc_message_strings
assert private_message in onyx_doc_message_strings
# Search as test_user_2 with access to only the public channel
print("\n Searching as test_user_2")
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify test_user_2 can only see public channel messages
user_2_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_2,
)
print(
"\n documents retrieved by test_user_2: ",
onyx_doc_message_strings,
)
assert public_message in user_2_docs
assert private_message not in user_2_docs
# Ensure test_user_2 can only see messages from the public channel
assert public_message in onyx_doc_message_strings
assert private_message not in onyx_doc_message_strings
# Search as test_user_1 with access to both channels
print("\n Searching as test_user_1")
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify test_user_1 can see both channels (member of private channel)
user_1_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_1,
)
print(
"\n documents retrieved by test_user_1 before being removed from private channel: ",
onyx_doc_message_strings,
)
assert public_message in user_1_docs
assert private_message in user_1_docs
# Ensure test_user_1 can see messages from both channels
assert public_message in onyx_doc_message_strings
assert private_message in onyx_doc_message_strings
# ----------------------MAKE THE CHANGES--------------------------
print("\n Removing test_user_1 from the private channel")
before = datetime.now(timezone.utc)
# Remove test_user_1 from the private channel
before = datetime.now(timezone.utc)
desired_channel_members = [admin_user]
SlackManager.set_channel_members(
slack_client=slack_client,
@@ -206,24 +186,16 @@ def test_slack_permission_sync(
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
)
# ----------------------------VERIFY THE CHANGES---------------------------
# Ensure test_user_1 can no longer see messages from the private channel
# Search as test_user_1 with access to only the public channel
onyx_doc_message_strings = DocumentSearchManager.search_documents(
# Verify test_user_1 can no longer see private channel after removal
user_1_docs = DocumentSearchManager.search_documents(
query="favorite number",
user_performing_action=test_user_1,
)
print(
"\n documents retrieved by test_user_1 after being removed from private channel: ",
onyx_doc_message_strings,
)
# Ensure test_user_1 can only see messages from the public channel
assert public_message in onyx_doc_message_strings
assert private_message not in onyx_doc_message_strings
assert public_message in user_1_docs
assert private_message not in user_1_docs
# NOTE(rkuo): it isn't yet clear if the reason these were previously xfail'd
@@ -235,21 +207,19 @@ def test_slack_permission_sync(
def test_slack_group_permission_sync(
reset: None, # noqa: ARG001
vespa_client: vespa_fixture, # noqa: ARG001
slack_test_setup: tuple[ChannelType, ChannelType],
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
) -> None:
"""
This test ensures that permission sync overrides onyx group access.
"""
public_channel, private_channel = slack_test_setup
public_channel, private_channel = slack_perm_sync_test_setup
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(
email="admin@example.com",
email=SLACK_ADMIN_EMAIL,
)
# Creating a non-admin user
test_user_1: DATestUser = UserManager.create(
email="test_user_1@example.com",
email=SLACK_TEST_USER_1_EMAIL,
)
# Create a user group and adding the non-admin user to it
@@ -264,7 +234,8 @@ def test_slack_group_permission_sync(
user_performing_action=admin_user,
)
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
slack_client = SlackManager.get_slack_client(bot_token)
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = email_id_map[admin_user.email]
@@ -282,7 +253,7 @@ def test_slack_group_permission_sync(
credential = CredentialManager.create(
source=DocumentSource.SLACK,
credential_json={
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
"slack_bot_token": bot_token,
},
user_performing_action=admin_user,
)
@@ -294,6 +265,7 @@ def test_slack_group_permission_sync(
source=DocumentSource.SLACK,
connector_specific_config={
"channels": [private_channel["name"]],
"include_bot_messages": True,
},
access_type=AccessType.SYNC,
groups=[user_group.id],
@@ -326,7 +298,8 @@ def test_slack_group_permission_sync(
user_performing_action=admin_user,
)
# Run permission sync
# Run permission sync. Since initial_index_should_sync=True for Slack,
# permissions were already set during indexing — no changes expected.
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
@@ -334,8 +307,10 @@ def test_slack_group_permission_sync(
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
number_of_updated_docs=0,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
# Verify admin can see the message

View File

@@ -4,75 +4,84 @@ import time
import pytest
import requests
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser
RETENTION_SECONDS = 10
def _run_ttl_cleanup(retention_days: int) -> None:
    """Run the chat-retention (TTL) cleanup inline, bypassing Celery.

    Looks up every chat session older than ``retention_days`` and hard-deletes
    each one, opening a fresh DB session per deletion.
    """
    with get_session_with_current_tenant() as lookup_session:
        expired_sessions = get_chat_sessions_older_than(
            retention_days, lookup_session
        )
    for owner_id, chat_session_id in expired_sessions:
        # A new session per deletion keeps each transaction small and
        # mirrors how the real cleanup task processes sessions one by one.
        with get_session_with_current_tenant() as delete_session:
            delete_chat_session(
                owner_id,
                chat_session_id,
                delete_session,
                include_deleted=True,
                hard_delete=True,
            )
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Chat retention tests are enterprise only",
)
def test_chat_retention(reset: None, admin_user: DATestUser) -> None: # noqa: ARG001
def test_chat_retention(
reset: None, admin_user: DATestUser, llm_provider: DATestLLMProvider # noqa: ARG001
) -> None: # noqa: ARG001
"""Test that chat sessions are deleted after the retention period expires."""
# Set chat retention period to 10 seconds
retention_days = 10 / 86400 # 10 seconds in days (10 / 24 / 60 / 60)
retention_days = RETENTION_SECONDS // 86400
settings = DATestSettings(maximum_chat_retention_days=retention_days)
SettingsManager.update_settings(settings, user_performing_action=admin_user)
# Create a chat session
chat_session = ChatSessionManager.create(
persona_id=0,
description="Test chat retention",
user_performing_action=admin_user,
)
# Send a message
ChatSessionManager.send_message(
response = ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message="This message should be deleted soon",
user_performing_action=admin_user,
)
assert (
response.error is None
), f"Chat response should not have an error: {response.error}"
# Verify the chat session exists
chat_history = ChatSessionManager.get_chat_history(
chat_session=chat_session,
user_performing_action=admin_user,
)
assert len(chat_history) > 0, "Chat session should have messages"
# Wait for TTL task to run (give it ~60 seconds)
print("Waiting for chat retention TTL task to run...")
max_wait_time = 60 # maximum time to wait in seconds
start_time = time.time()
# Wait for the retention period to elapse, then directly run TTL cleanup
time.sleep(RETENTION_SECONDS + 2)
_run_ttl_cleanup(retention_days)
# Verify the chat session was deleted
session_deleted = False
try:
chat_history = ChatSessionManager.get_chat_history(
chat_session=chat_session,
user_performing_action=admin_user,
)
session_deleted = len(chat_history) == 0
except requests.exceptions.HTTPError as e:
if e.response.status_code in (404, 400):
session_deleted = True
else:
raise
while not session_deleted and (time.time() - start_time < max_wait_time):
# Check if chat session is deleted
try:
# Attempt to get chat history - this should 404
chat_history = ChatSessionManager.get_chat_history(
chat_session=chat_session,
user_performing_action=admin_user,
)
# If we got no messages or an empty response, session might be deleted
if not chat_history:
session_deleted = True
break
except requests.exceptions.HTTPError as e:
# If we get a 404 or other error, the session is gone
if e.response.status_code in (404, 400):
session_deleted = True
break
raise # Re-raise other errors
# Wait a bit before checking again
time.sleep(5)
print(f"Waited {time.time() - start_time:.1f} seconds for chat deletion...")
# Assert that the chat session was deleted
assert session_deleted, "Chat session was not deleted within the expected time"
assert session_deleted, "Chat session was not deleted after retention period"

View File

@@ -0,0 +1,32 @@
from collections.abc import Generator
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
@pytest.fixture
def preserve_code_interpreter_state(
    admin_user: DATestUser,
) -> Generator[None, None, None]:
    """Snapshot the code-interpreter 'enabled' flag before the test and write
    it back on teardown, so tests that toggle the setting cannot leak state."""
    snapshot = requests.get(
        CODE_INTERPRETER_URL,
        headers=admin_user.headers,
    )
    snapshot.raise_for_status()
    enabled_before_test = snapshot.json()["enabled"]

    yield

    # Teardown: restore the pre-test value (runs even when the test fails).
    put_back = requests.put(
        CODE_INTERPRETER_URL,
        json={"enabled": enabled_before_test},
        headers=admin_user.headers,
    )
    put_back.raise_for_status()

View File

@@ -0,0 +1,97 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
CODE_INTERPRETER_HEALTH_URL = f"{CODE_INTERPRETER_URL}/health"
def test_get_code_interpreter_health_as_admin(
    admin_user: DATestUser,
) -> None:
    """The health endpoint must return 200 with a boolean 'healthy' field."""
    resp = requests.get(
        CODE_INTERPRETER_HEALTH_URL,
        headers=admin_user.headers,
    )
    assert resp.status_code == 200

    payload = resp.json()
    assert "healthy" in payload
    assert isinstance(payload["healthy"], bool)
def test_get_code_interpreter_status_as_admin(
    admin_user: DATestUser,
) -> None:
    """The status endpoint must return 200 with a boolean 'enabled' field."""
    resp = requests.get(
        CODE_INTERPRETER_URL,
        headers=admin_user.headers,
    )
    assert resp.status_code == 200

    payload = resp.json()
    assert "enabled" in payload
    assert isinstance(payload["enabled"], bool)
def test_update_code_interpreter_disable_and_enable(
    admin_user: DATestUser,
    preserve_code_interpreter_state: None,  # noqa: ARG001
) -> None:
    """The PUT endpoint must persist the 'enabled' flag so later GETs see it.

    Toggles the flag off then back on, reading it back after each write.
    The preserve_code_interpreter_state fixture restores the original value.
    """

    def _write_enabled(value: bool) -> None:
        # PUT the new flag value and confirm the write is accepted.
        resp = requests.put(
            CODE_INTERPRETER_URL,
            json={"enabled": value},
            headers=admin_user.headers,
        )
        assert resp.status_code == 200

    def _read_enabled() -> bool:
        # GET the current flag value.
        resp = requests.get(
            CODE_INTERPRETER_URL,
            headers=admin_user.headers,
        )
        assert resp.status_code == 200
        return resp.json()["enabled"]

    _write_enabled(False)
    assert _read_enabled() is False

    _write_enabled(True)
    assert _read_enabled() is True
def test_code_interpreter_endpoints_require_admin(
    basic_user: DATestUser,
) -> None:
    """Every code-interpreter endpoint must reject non-admin callers with 403."""
    # Health check as a basic user.
    health = requests.get(
        CODE_INTERPRETER_HEALTH_URL,
        headers=basic_user.headers,
    )
    assert health.status_code == 403

    # Status read as a basic user.
    status = requests.get(
        CODE_INTERPRETER_URL,
        headers=basic_user.headers,
    )
    assert status.status_code == 403

    # Flag update as a basic user.
    update = requests.put(
        CODE_INTERPRETER_URL,
        json={"enabled": True},
        headers=basic_user.headers,
    )
    assert update.status_code == 403

View File

@@ -1,195 +0,0 @@
import os
import pytest
import requests
from onyx.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="/chat/send-message-simple-with-history is enterprise only",
)
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:  # noqa: ARG001
    """Multi-turn end-to-end test of /chat/send-message-simple-with-history.

    Seeds three color-fact documents, then asks three questions across a
    growing conversation history and checks that each answer cites the
    expected color and that all seeded docs appear in the final context.

    NOTE: the original per-turn checks on llm_selected_doc_indices,
    cited_documents, and top_documents were flaky (non-deterministic
    rephrasing) and are intentionally not asserted.
    """
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")

    # create connector
    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=admin_user,
    )
    api_key: DATestAPIKey = APIKeyManager.create(
        user_performing_action=admin_user,
    )
    LLMProviderManager.create(user_performing_action=admin_user)

    # SEEDING DOCUMENTS
    cc_pair_1.documents = []
    for content in (
        "Pablo's favorite color is blue",
        "Chris's favorite color is red",
        "Pika's favorite color is green",
    ):
        cc_pair_1.documents.append(
            DocumentManager.seed_doc_with_content(
                cc_pair=cc_pair_1,
                content=content,
                api_key=api_key,
            )
        )

    def _send_turn(messages: list[dict[str, str]]) -> str:
        """POST the message history, run the per-turn checks shared by every
        question, and return the answer text."""
        response = requests.post(
            f"{API_SERVER_URL}/chat/send-message-simple-with-history",
            json={
                "messages": messages,
                "persona_id": 0,
            },
            headers=admin_user.headers,
        )
        assert response.status_code == 200
        response_json = response.json()

        # The final context should contain all documents because there
        # aren't enough to exclude any.
        assert 0 in response_json["final_context_doc_indices"]
        assert 1 in response_json["final_context_doc_indices"]
        assert 2 in response_json["final_context_doc_indices"]
        return response_json["answer"]

    # TESTING RESPONSE FOR QUESTION 1
    history: list[dict[str, str]] = [
        {
            "message": "What is Pablo's favorite color?",
            "role": MessageType.USER.value,
        }
    ]
    answer_1 = _send_turn(history)
    assert "blue" in answer_1.lower()
    print("response 1/3 passed")

    # TESTING RESPONSE FOR QUESTION 2
    history += [
        {"message": answer_1, "role": MessageType.ASSISTANT.value},
        {
            "message": "What is Chris's favorite color?",
            "role": MessageType.USER.value,
        },
    ]
    answer_2 = _send_turn(history)
    assert "red" in answer_2.lower()
    print("response 2/3 passed")

    # TESTING RESPONSE FOR QUESTION 3
    history += [
        {"message": answer_2, "role": MessageType.ASSISTANT.value},
        {
            "message": "What is Pika's favorite color?",
            "role": MessageType.USER.value,
        },
    ]
    answer_3 = _send_turn(history)
    assert "green" in answer_3.lower()
    print("response 3/3 passed")

View File

@@ -1,250 +0,0 @@
import json
import os
import pytest
import requests
from onyx.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.conftest import DocumentBuilderType
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
    document_builder: DocumentBuilderType,
) -> None:
    """Seed NUM_DOCS dummy documents, send the first document's content as a
    chat message, and verify the retrieval ranking and per-doc metadata."""
    # Create NUM_DOCS documents with dummy content via the document builder.
    docs = document_builder([f"Document {i} content" for i in range(NUM_DOCS)])

    resp = requests.post(
        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
        json={
            "messages": [
                {
                    "message": docs[0].content,
                    "role": MessageType.USER.value,
                }
            ],
            "persona_id": 0,
        },
        headers=admin_user.headers,
    )
    assert resp.status_code == 200
    body = resp.json()

    # The document whose content we queried must rank first.
    assert body["top_documents"][0]["document_id"] == docs[0].id

    # Every seeded document must be retrieved with matching metadata.
    retrieved_by_id = {d["document_id"]: d for d in body["top_documents"]}
    for doc in docs:
        found_doc = retrieved_by_id.get(doc.id)
        assert found_doc
        assert found_doc["metadata"]["document_id"] == doc.id
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_using_reference_docs_with_simple_with_history_api_flow(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
    document_builder: DocumentBuilderType,
) -> None:
    """Ask once without doc pinning, then re-ask pinned via search_doc_ids to
    the first answer's top doc, and confirm the pinned doc is what comes back."""
    # SEEDING DOCUMENTS
    docs = document_builder(
        [
            "Chris's favorite color is blue",
            "Hagen's favorite color is red",
            "Pablo's favorite color is green",
        ]
    )

    question = {
        "message": "What is Pablo's favorite color?",
        "role": MessageType.USER.value,
    }

    # Turn 1: unconstrained search — capture the top document's db_doc_id.
    first = requests.post(
        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
        json={
            "messages": [question],
            "persona_id": 0,
        },
        headers=admin_user.headers,
    )
    assert first.status_code == 200
    pinned_db_doc_id = first.json()["top_documents"][0]["db_doc_id"]

    # Turn 2: same question, restricted to the captured search doc.
    second = requests.post(
        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
        json={
            "messages": [question],
            "persona_id": 0,
            "search_doc_ids": [pinned_db_doc_id],
        },
        headers=admin_user.headers,
    )
    assert second.status_code == 200
    second_json = second.json()

    # There must be an answer.
    assert second_json["answer"]
    # The document we pinned via search_doc_ids must be the document the
    # endpoint actually referenced (Pablo's doc, index 2 in the seed list).
    assert second_json["top_documents"][0]["document_id"] == docs[2].id
@pytest.mark.skip(reason="We don't support this anymore with the DR flow :(")
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history_strict_json(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """Verify `structured_response_format` (strict JSON schema) is honored.

    Sends an intentionally unrelated prompt so any valid-JSON answer must come
    from the structured-response machinery, then checks that both `answer` and
    `answer_citationless` parse as JSON objects of the requested shape.
    """
    response = requests.post(
        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
        json={
            # intentionally not relevant prompt to ensure that the
            # structured response format is actually used
            "messages": [
                {
                    "message": "What is green?",
                    "role": MessageType.USER.value,
                }
            ],
            "persona_id": 0,
            "structured_response_format": {
                "type": "json_schema",
                "json_schema": {
                    "name": "presidents",
                    "schema": {
                        "type": "object",
                        "properties": {
                            "presidents": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "List of the first three US presidents",
                            }
                        },
                        "required": ["presidents"],
                        "additionalProperties": False,
                    },
                    "strict": True,
                },
            },
        },
        headers=admin_user.headers,
    )
    assert response.status_code == 200
    response_json = response.json()

    # Check that the answer is present
    assert "answer" in response_json
    assert response_json["answer"] is not None

    # helper: models sometimes wrap JSON in a markdown code fence; strip it
    def clean_json_string(json_string: str) -> str:
        return json_string.strip().removeprefix("```json").removesuffix("```").strip()

    # Attempt to parse the answer as JSON.
    # NOTE: `pytest.fail` instead of `assert False, msg` — the latter is
    # silently removed when Python runs with -O (flake8-bugbear B011).
    try:
        parsed_answer = json.loads(clean_json_string(response_json["answer"]))
    except json.JSONDecodeError:
        pytest.fail(
            f"The answer is not a valid JSON object - '{response_json['answer']}'"
        )

    # NOTE: do not check content, just the structure
    assert isinstance(parsed_answer, dict)
    assert "presidents" in parsed_answer
    assert isinstance(parsed_answer["presidents"], list)
    for president in parsed_answer["presidents"]:
        assert isinstance(president, str)

    # Check that the answer_citationless is also valid JSON
    assert "answer_citationless" in response_json
    assert response_json["answer_citationless"] is not None
    try:
        parsed_answer_citationless = json.loads(
            clean_json_string(response_json["answer_citationless"])
        )
    except json.JSONDecodeError:
        pytest.fail("The answer_citationless is not a valid JSON object")
    assert isinstance(parsed_answer_citationless, dict)
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="/query/answer-with-citation tests are enterprise only",
)
def test_answer_with_citation_api(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
    document_builder: DocumentBuilderType,
) -> None:
    """Verify /query/answer-with-citation answers and cites the seeded document."""
    # create docs
    docs = document_builder(["Chris' favorite color is green"])

    # send a message
    response = requests.post(
        f"{API_SERVER_URL}/query/answer-with-citation",
        json={
            "messages": [
                {
                    "message": "What is Chris' favorite color? Make sure to cite the document.",
                    "role": MessageType.USER.value,
                }
            ],
            "persona_id": 0,
        },
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["answer"]
    # `any()` replaces the manual flag-and-break loop; fail with the missing id
    # for easier debugging.
    assert any(
        citation["document_id"] == docs[0].id
        for citation in response_json["citations"]
    ), f"no citation referenced seeded document {docs[0].id}"

View File

@@ -2,7 +2,6 @@ import os
import uuid
from datetime import datetime
from datetime import timezone
from unittest.mock import patch
import httpx
import pytest
@@ -12,6 +11,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_EMAILS
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_GROUP_IDS
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import InputType
from onyx.db.document import get_documents_by_ids
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -25,128 +25,16 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.test_document_utils import create_test_document
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync is enterprise only",
)
def test_mock_connector_initial_permission_sync(
def _setup_mock_connector(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that the MockConnector fetches and sets permissions during initial indexing when AccessType.SYNC is used"""
# Set up mock server behavior
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
)
assert response.status_code == 200
# Create CC Pair with SYNC access type to enable permissions during indexing
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-permissions-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
access_type=AccessType.SYNC, # This enables permissions during indexing
user_performing_action=admin_user,
)
# Wait for index attempt to start
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Wait for index attempt to finish
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Validate status
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished_index_attempt.status == IndexingStatus.SUCCESS
# Verify document was indexed
with get_session_with_current_tenant() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == test_doc.id
# Verify no errors occurred
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 0
# Verify permissions were set during indexing by checking the document in the database
with get_session_with_current_tenant() as db_session:
db_docs = get_documents_by_ids(
db_session=db_session,
document_ids=[test_doc.id],
)
assert len(db_docs) == 1
db_doc = db_docs[0]
assert db_doc.external_user_emails is not None
assert db_doc.external_user_group_ids is not None
# Check the specific permissions that MockConnector sets
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
# Verify the document is not public (as set by MockConnector)
assert db_doc.is_public is False
# Verify that the cc_pair was marked as permissions synced
updated_cc_pair_info = CCPairManager.get_single(
cc_pair.id, user_performing_action=admin_user
)
assert updated_cc_pair_info is not None
assert updated_cc_pair_info.last_full_permission_sync is not None
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
)
def test_permission_sync_attempt_tracking_integration(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture, # noqa: ARG001
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are properly tracked during real sync workflows."""
) -> tuple[DATestCCPair, Document]:
"""Common setup: create a test doc, configure mock server, create cc_pair, wait for indexing."""
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
@@ -165,7 +53,7 @@ def test_permission_sync_attempt_tracking_integration(
assert response.status_code == 200
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-attempt-tracking-{uuid.uuid4()}",
name=f"mock-connector-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
@@ -187,6 +75,95 @@ def test_permission_sync_attempt_tracking_integration(
user_performing_action=admin_user,
)
finished = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert finished.status == IndexingStatus.SUCCESS
return cc_pair, test_doc
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync is enterprise only",
)
def test_mock_connector_initial_permission_sync(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture,
admin_user: DATestUser,
) -> None:
"""Test that the MockConnector fetches and sets permissions during initial indexing
when AccessType.SYNC is used."""
cc_pair, test_doc = _setup_mock_connector(mock_server_client, admin_user)
with get_session_with_current_tenant() as db_session:
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
assert len(documents) == 1
assert documents[0].id == test_doc.id
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
assert len(errors) == 0
with get_session_with_current_tenant() as db_session:
db_docs = get_documents_by_ids(
db_session=db_session,
document_ids=[test_doc.id],
)
assert len(db_docs) == 1
db_doc = db_docs[0]
assert db_doc.external_user_emails is not None
assert db_doc.external_user_group_ids is not None
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
assert db_doc.is_public is False
# After initial indexing, the beat task detects last_time_perm_sync is None
# and triggers a doc permission sync. Explicitly trigger it to avoid
# waiting for the 30s beat interval.
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
updated_cc_pair_info = CCPairManager.get_single(
cc_pair.id, user_performing_action=admin_user
)
assert updated_cc_pair_info is not None
assert updated_cc_pair_info.last_full_permission_sync is not None
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
)
def test_permission_sync_attempt_tracking_integration(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture, # noqa: ARG001
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are properly tracked during real sync workflows."""
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
@@ -198,6 +175,8 @@ def test_permission_sync_attempt_tracking_integration(
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
with get_session_with_current_tenant() as db_session:
@@ -219,88 +198,6 @@ def test_permission_sync_attempt_tracking_integration(
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
)
def test_permission_sync_attempt_tracking_with_mocked_failure(
mock_server_client: httpx.Client,
vespa_client: vespa_fixture, # noqa: ARG001
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are properly tracked when sync fails."""
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
)
assert response.status_code == 200
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-attempt-failure-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
# Mock the permission sync to force a failure and verify attempt tracking
with patch(
"ee.onyx.background.celery.tasks.doc_permission_syncing.tasks.validate_ccpair_for_user"
) as mock_validate:
mock_validate.side_effect = Exception("Validation failed for testing")
try:
before = datetime.now(timezone.utc)
CCPairManager.sync(
cc_pair=cc_pair,
user_performing_action=admin_user,
)
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
number_of_updated_docs=0,
user_performing_action=admin_user,
)
except Exception:
pass
with get_session_with_current_tenant() as db_session:
attempt = db_session.execute(
select(DocPermissionSyncAttempt).where(
DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair.id
)
).scalar_one()
assert attempt.status == PermissionSyncStatus.FAILED
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission sync attempt tracking is enterprise only",
@@ -311,45 +208,8 @@ def test_permission_sync_attempt_status_success(
admin_user: DATestUser,
) -> None:
"""Test that permission sync attempts are marked as SUCCESS when sync completes without errors."""
doc_uuid = uuid.uuid4()
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
response = mock_server_client.post(
"/set-behavior",
json=[
{
"documents": [test_doc.model_dump(mode="json")],
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
mode="json"
),
"failures": [],
}
],
)
assert response.status_code == 200
cc_pair = CCPairManager.create_from_scratch(
name=f"mock-connector-success-{uuid.uuid4()}",
source=DocumentSource.MOCK_CONNECTOR,
input_type=InputType.POLL,
connector_specific_config={
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
},
access_type=AccessType.SYNC,
user_performing_action=admin_user,
)
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
IndexAttemptManager.wait_for_index_attempt_completion(
index_attempt_id=index_attempt.id,
cc_pair_id=cc_pair.id,
user_performing_action=admin_user,
)
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
before = datetime.now(timezone.utc)
CCPairManager.sync(
@@ -362,6 +222,8 @@ def test_permission_sync_attempt_status_success(
after=before,
number_of_updated_docs=1,
user_performing_action=admin_user,
should_wait_for_group_sync=False,
should_wait_for_vespa_sync=False,
)
with get_session_with_current_tenant() as db_session:

View File

@@ -6,11 +6,14 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMModelFlow
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
@@ -267,6 +270,24 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
provider_name=restricted_provider.name,
)
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
is_visible=True,
)
db_session.add(default_model_config)
db_session.flush()
db_session.add(
LLMModelFlow(
model_configuration_id=default_model_config.id,
llm_model_flow_type=LLMModelFlowType.CHAT,
is_default=True,
)
)
db_session.flush()
access_group = UserGroup(name="persona-group")
db_session.add(access_group)
db_session.flush()

View File

@@ -0,0 +1,322 @@
import json
import os
import time
from uuid import uuid4
import pytest
import requests
from pydantic import BaseModel
from pydantic import ConfigDict
from onyx.configs import app_configs
from onyx.configs.constants import DocumentSource
from onyx.tools.constants import SEARCH_TOOL_ID
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.tool import ToolManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import ToolName
# Environment variables that configure the nightly LLM provider regression run.
_ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"  # provider identifier, e.g. "ollama_chat"
_ENV_MODELS = "NIGHTLY_LLM_MODELS"  # comma-separated list of model names to test
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"  # API key (optional for ollama_chat)
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"  # API base URL override
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"  # JSON object of extra settings
_ENV_STRICT = "NIGHTLY_LLM_STRICT"  # truthy => missing config fails instead of skipping
class NightlyProviderConfig(BaseModel):
    """Immutable bundle of nightly-test LLM provider settings read from env vars."""

    model_config = ConfigDict(frozen=True)

    # Lowercased provider identifier, e.g. "openrouter" or "ollama_chat".
    provider: str
    # Model names to exercise; a provider is created and tested per model.
    model_names: list[str]
    # Provider API key; may be None (e.g. for a local ollama_chat server).
    api_key: str | None
    # Explicit API base URL; when None a per-provider default may apply.
    api_base: str | None
    # Extra provider key/value settings forwarded in the provider payload.
    custom_config: dict[str, str] | None
    # When True, configuration problems fail the test instead of skipping it.
    strict: bool
def _env_true(env_var: str, default: bool = False) -> bool:
value = os.environ.get(env_var)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def _split_csv_env(env_var: str) -> list[str]:
return [
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
]
def _load_provider_config() -> NightlyProviderConfig:
    """Assemble the nightly provider configuration from NIGHTLY_LLM_* env vars.

    Raises ValueError when the custom-config env var holds JSON that is not an
    object.
    """
    provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
    model_names = _split_csv_env(_ENV_MODELS)
    api_key = os.environ.get(_ENV_API_KEY) or None
    api_base = os.environ.get(_ENV_API_BASE) or None
    strict = _env_true(_ENV_STRICT, default=False)

    # Optional free-form provider settings, supplied as a JSON object whose
    # keys and values are coerced to strings.
    custom_config: dict[str, str] | None = None
    raw_custom_config = os.environ.get(_ENV_CUSTOM_CONFIG_JSON, "").strip()
    if raw_custom_config:
        decoded = json.loads(raw_custom_config)
        if not isinstance(decoded, dict):
            raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
        custom_config = {str(k): str(v) for k, v in decoded.items()}

    # For ollama_chat, surface an explicitly-given API key through the custom
    # config (OLLAMA_API_KEY) when no custom config was provided.
    if provider == "ollama_chat" and api_key and not custom_config:
        custom_config = {"OLLAMA_API_KEY": api_key}

    return NightlyProviderConfig(
        provider=provider,
        model_names=model_names,
        api_key=api_key,
        api_base=api_base,
        custom_config=custom_config,
        strict=strict,
    )
def _skip_or_fail(strict: bool, message: str) -> None:
    """In strict mode, fail the test with `message`; otherwise skip it.

    Both pytest.fail and pytest.skip raise, so this never returns normally.
    """
    if strict:
        pytest.fail(message)
    else:
        pytest.skip(message)
def _validate_provider_config(config: NightlyProviderConfig) -> None:
    """Skip (or fail, in strict mode) when required provider settings are missing.

    Because _skip_or_fail raises, the first check that trips ends the test.
    """
    if not config.provider:
        _skip_or_fail(strict=config.strict, message=f"{_ENV_PROVIDER} must be set")
    if not config.model_names:
        _skip_or_fail(
            strict=config.strict,
            message=f"{_ENV_MODELS} must include at least one model",
        )

    # Every provider except local ollama_chat needs an API key.
    needs_api_key = config.provider != "ollama_chat"
    if needs_api_key and not config.api_key:
        _skip_or_fail(
            strict=config.strict,
            message=f"{_ENV_API_KEY} is required for provider '{config.provider}'",
        )

    # ollama_chat needs an API base, either explicit or a known default.
    has_api_base = bool(
        config.api_base or _default_api_base_for_provider(config.provider)
    )
    if config.provider == "ollama_chat" and not has_api_base:
        _skip_or_fail(
            strict=config.strict,
            message=f"{_ENV_API_BASE} is required for provider '{config.provider}'",
        )
def _assert_integration_mode_enabled() -> None:
    """Guard: the target API server must run with INTEGRATION_TESTS_MODE=true."""
    message = "Integration tests require INTEGRATION_TESTS_MODE=true."
    assert app_configs.INTEGRATION_TESTS_MODE is True, message
def _seed_connector_for_search_tool(admin_user: DATestUser) -> None:
    """Create a minimal ingestion-API connector/credential pair.

    SearchTool is only exposed when at least one non-default connector exists,
    so seeding one here makes the internal search tool available to the test.
    """
    CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )
def _get_internal_search_tool_id(admin_user: DATestUser) -> int:
    """Return the id of the built-in SearchTool, failing if it is not registered."""
    search_tool_id = next(
        (
            tool.id
            for tool in ToolManager.list_tools(user_performing_action=admin_user)
            if tool.in_code_tool_id == SEARCH_TOOL_ID
        ),
        None,
    )
    if search_tool_id is None:
        raise AssertionError("SearchTool must exist for this test")
    return search_tool_id
def _default_api_base_for_provider(provider: str) -> str | None:
if provider == "openrouter":
return "https://openrouter.ai/api/v1"
if provider == "ollama_chat":
# host.docker.internal works when tests are running inside the integration test container.
return "http://host.docker.internal:11434"
return None
def _create_provider_payload(
provider: str,
provider_name: str,
model_name: str,
api_key: str | None,
api_base: str | None,
custom_config: dict[str, str] | None,
) -> dict:
return {
"name": provider_name,
"provider": provider,
"api_key": api_key,
"api_base": api_base,
"custom_config": custom_config,
"default_model_name": model_name,
"is_public": True,
"groups": [],
"personas": [],
"model_configurations": [{"name": model_name, "is_visible": True}],
"api_key_changed": bool(api_key),
"custom_config_changed": bool(custom_config),
}
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
    """Assert via the admin API that `provider_id` is the current default provider."""
    resp = requests.get(
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
    )
    resp.raise_for_status()

    # First provider flagged as default, if any.
    defaults = [entry for entry in resp.json() if entry.get("is_default_provider")]
    assert defaults, "Expected a default provider after setting provider as default"
    default_id = defaults[0]["id"]
    assert default_id == provider_id, (
        f"Expected provider {provider_id} to be default, found {default_id}"
    )
def _run_chat_assertions(
    admin_user: DATestUser,
    search_tool_id: int,
    provider: str,
    model_name: str,
) -> None:
    """Send a forced-tool chat message and assert internal_search was actually used.

    An attempt passes when there is no stream error, the internal search tool
    appears in both `used_tools` and `tool_call_debug`, and the answer text is
    non-empty. Otherwise diagnostics are recorded and the attempt is retried
    (two attempts total) before failing with the last recorded error.
    """
    last_error: str | None = None
    # Retry once to reduce transient nightly flakes due to provider-side blips.
    for attempt in range(1, 3):
        chat_session = ChatSessionManager.create(user_performing_action=admin_user)
        response = ChatSessionManager.send_message(
            chat_session_id=chat_session.id,
            message=(
                "Use internal_search to search for 'nightly-provider-regression-sentinel', "
                "then summarize the result in one short sentence."
            ),
            user_performing_action=admin_user,
            forced_tool_ids=[search_tool_id],
        )
        if response.error is None:
            # Tool usage must show up both in the high-level used-tools list
            # and in the low-level tool-call debug trace.
            used_internal_search = any(
                used_tool.tool_name == ToolName.INTERNAL_SEARCH
                for used_tool in response.used_tools
            )
            debug_has_internal_search = any(
                debug_tool_call.tool_name == "internal_search"
                for debug_tool_call in response.tool_call_debug
            )
            has_answer = bool(response.full_message.strip())
            if used_internal_search and debug_has_internal_search and has_answer:
                return
            last_error = (
                f"attempt={attempt} provider={provider} model={model_name} "
                f"used_internal_search={used_internal_search} "
                f"debug_internal_search={debug_has_internal_search} "
                f"has_answer={has_answer} "
                f"tool_call_debug={response.tool_call_debug}"
            )
        else:
            last_error = (
                f"attempt={attempt} provider={provider} model={model_name} "
                f"stream_error={response.error.error}"
            )
        # Back off slightly longer after each failed attempt.
        time.sleep(attempt)
    pytest.fail(f"Chat/tool-call assertions failed: {last_error}")
def _create_and_test_provider_for_model(
    admin_user: DATestUser,
    config: NightlyProviderConfig,
    model_name: str,
    search_tool_id: int,
) -> None:
    """Create an LLM provider for one model, set it as default, run chat checks.

    The provider is deleted in a `finally` block so a failed assertion does not
    leak providers across nightly runs.
    """
    # Unique name avoids collisions between runs and between models.
    provider_name = f"nightly-{config.provider}-{uuid4().hex[:12]}"
    resolved_api_base = config.api_base or _default_api_base_for_provider(
        config.provider
    )
    provider_payload = _create_provider_payload(
        provider=config.provider,
        provider_name=provider_name,
        model_name=model_name,
        api_key=config.api_key,
        api_base=resolved_api_base,
        custom_config=config.custom_config,
    )
    # Validate credentials/connectivity before creating the provider.
    test_response = requests.post(
        f"{API_SERVER_URL}/admin/llm/test",
        headers=admin_user.headers,
        json=provider_payload,
    )
    assert test_response.status_code == 200, (
        f"Provider test endpoint failed for provider={config.provider} "
        f"model={model_name}: {test_response.status_code} {test_response.text}"
    )
    create_response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        headers=admin_user.headers,
        json=provider_payload,
    )
    assert create_response.status_code == 200, (
        f"Provider creation failed for provider={config.provider} "
        f"model={model_name}: {create_response.status_code} {create_response.text}"
    )
    provider_id = create_response.json()["id"]
    try:
        set_default_response = requests.post(
            f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
            headers=admin_user.headers,
        )
        assert set_default_response.status_code == 200, (
            f"Setting default provider failed for provider={config.provider} "
            f"model={model_name}: {set_default_response.status_code} "
            f"{set_default_response.text}"
        )
        _ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
        _run_chat_assertions(
            admin_user=admin_user,
            search_tool_id=search_tool_id,
            provider=config.provider,
            model_name=model_name,
        )
    finally:
        # Best-effort cleanup; the delete response status is intentionally
        # not checked so cleanup problems don't mask the real failure.
        requests.delete(
            f"{API_SERVER_URL}/admin/llm/provider/{provider_id}",
            headers=admin_user.headers,
        )
def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
    """Nightly regression test for provider setup + default selection + chat tool calls."""
    _assert_integration_mode_enabled()
    config = _load_provider_config()
    # Skips (or fails, when NIGHTLY_LLM_STRICT is truthy) on missing settings.
    _validate_provider_config(config)
    # A non-default connector must exist for the internal search tool to appear.
    _seed_connector_for_search_tool(admin_user)
    search_tool_id = _get_internal_search_tool_id(admin_user)
    # Each configured model gets its own provider: create, set default,
    # run chat/tool-call assertions, then clean up.
    for model_name in config.model_names:
        _create_and_test_provider_for_model(
            admin_user=admin_user,
            config=config,
            model_name=model_name,
            search_tool_id=search_tool_id,
        )

View File

@@ -6,7 +6,7 @@ the permissions of the curator manipulating connector-credential pairs.
import os
import pytest
from requests.exceptions import HTTPError
from onyx_openapi_client.exceptions import ApiException # type: ignore[import-untyped,unused-ignore,import-not-found]
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
@@ -93,20 +93,9 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
"""Tests for things Curators should not be able to do"""
# Curators should not be able to create a public cc pair
with pytest.raises(HTTPError):
CCPairManager.create(
connector_id=connector_1.id,
credential_id=credential_1.id,
name="invalid_cc_pair_1",
access_type=AccessType.PUBLIC,
groups=[user_group_1.id],
user_performing_action=curator,
)
# Curators should not be able to create a cc
# pair for a user group they are not a curator of
with pytest.raises(HTTPError):
with pytest.raises(ApiException):
CCPairManager.create(
connector_id=connector_1.id,
credential_id=credential_1.id,
@@ -118,7 +107,7 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
# Curators should not be able to create a cc
# pair without an attached user group
with pytest.raises(HTTPError):
with pytest.raises(ApiException):
CCPairManager.create(
connector_id=connector_1.id,
credential_id=credential_1.id,
@@ -144,7 +133,7 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
# Curators should not be able to create a cc
# pair for a user group that the credential does not belong to
with pytest.raises(HTTPError):
with pytest.raises(ApiException):
CCPairManager.create(
connector_id=connector_1.id,
credential_id=credential_2.id,
@@ -156,6 +145,16 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
"""Tests for things Curators should be able to do"""
# Re-create connector since the credential_2 validation error above
# triggers connector deletion in the exception handler
connector_1 = ConnectorManager.create(
name="admin_owned_connector_2",
source=DocumentSource.CONFLUENCE,
groups=[user_group_1.id],
access_type=AccessType.PRIVATE,
user_performing_action=admin_user,
)
# Curators should be able to create a private
# cc pair for a user group they are a curator of
valid_cc_pair = CCPairManager.create(

View File

@@ -59,17 +59,7 @@ def test_connector_permissions(reset: None) -> None: # noqa: ARG001
"""Tests for things Curators should not be able to do"""
# Curators should not be able to create a public connector
with pytest.raises(HTTPError):
ConnectorManager.create(
name="invalid_connector_1",
source=DocumentSource.CONFLUENCE,
groups=[user_group_1.id],
access_type=AccessType.PUBLIC,
user_performing_action=curator,
)
# Curators should not be able to create a cc pair for a
# Curators should not be able to create a connector for a
# user group they are not a curator of
with pytest.raises(HTTPError):
ConnectorManager.create(
@@ -133,12 +123,12 @@ def test_connector_permissions(reset: None) -> None: # noqa: ARG001
user_performing_action=curator,
)
# Test that curator cannot create a public connector
with pytest.raises(HTTPError):
ConnectorManager.create(
name="invalid_connector_4",
source=DocumentSource.CONFLUENCE,
groups=[user_group_1.id],
access_type=AccessType.PUBLIC,
user_performing_action=curator,
)
# Curators should be able to create a public connector
public_connector = ConnectorManager.create(
name="curator_public_connector",
source=DocumentSource.CONFLUENCE,
groups=[user_group_1.id],
access_type=AccessType.PUBLIC,
user_performing_action=curator,
)
assert public_connector.id is not None

View File

@@ -58,16 +58,6 @@ def test_credential_permissions(reset: None) -> None: # noqa: ARG001
"""Tests for things Curators should not be able to do"""
# Curators should not be able to create a public credential
with pytest.raises(HTTPError):
CredentialManager.create(
name="invalid_credential_1",
source=DocumentSource.CONFLUENCE,
groups=[user_group_1.id],
curator_public=True,
user_performing_action=curator,
)
# Curators should not be able to create a credential for a user group they are not a curator of
with pytest.raises(HTTPError):
CredentialManager.create(
@@ -113,3 +103,16 @@ def test_credential_permissions(reset: None) -> None: # noqa: ARG001
verify_deleted=True,
user_performing_action=curator,
)
# Curators should be able to create a public credential
public_credential = CredentialManager.create(
name="curator_public_credential",
source=DocumentSource.CONFLUENCE,
groups=[user_group_1.id],
curator_public=True,
user_performing_action=curator,
)
CredentialManager.verify(
credential=public_credential,
user_performing_action=curator,
)

View File

@@ -70,10 +70,11 @@ def test_doc_set_permissions_setup(reset: None) -> None: # noqa: ARG001
"""Tests for things Curators/Admins should not be able to do"""
# Test that curator cannot create a document set for the group they don't curate
# Test that curator cannot create a non-public document set for the group they don't curate
with pytest.raises(HTTPError):
DocumentSetManager.create(
name="Invalid Document Set 1",
is_public=False,
groups=[user_group_2.id],
cc_pair_ids=[public_cc_pair.id],
user_performing_action=curator,

View File

@@ -6,12 +6,14 @@ from datetime import timedelta
from datetime import timezone
from io import BytesIO
from io import StringIO
from uuid import UUID
from zipfile import ZipFile
import pytest
import requests
from ee.onyx.db.usage_export import UsageReportMetadata
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.seeding.chat_history_seeding import seed_chat_history
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestUser
@@ -26,7 +28,13 @@ class TestUsageExportAPI:
self, reset: None, admin_user: DATestUser # noqa: ARG002
) -> None:
# Seed some chat history data for the report
seed_chat_history(num_sessions=10, num_messages=4, days=30)
seed_chat_history(
num_sessions=10,
num_messages=4,
days=30,
user_id=UUID(admin_user.id),
persona_id=DEFAULT_PERSONA_ID,
)
# Get initial list of reports
initial_response = requests.get(
@@ -76,7 +84,13 @@ class TestUsageExportAPI:
self, reset: None, admin_user: DATestUser # noqa: ARG002
) -> None:
# Seed some chat history data
seed_chat_history(num_sessions=20, num_messages=4, days=60)
seed_chat_history(
num_sessions=20,
num_messages=4,
days=60,
user_id=UUID(admin_user.id),
persona_id=DEFAULT_PERSONA_ID,
)
# Get initial list of reports
initial_response = requests.get(
@@ -148,7 +162,13 @@ class TestUsageExportAPI:
self, reset: None, admin_user: DATestUser # noqa: ARG002
) -> None:
# First generate a report to ensure we have at least one
seed_chat_history(num_sessions=5, num_messages=4, days=30)
seed_chat_history(
num_sessions=5,
num_messages=4,
days=30,
user_id=UUID(admin_user.id),
persona_id=DEFAULT_PERSONA_ID,
)
# Get initial count
initial_response = requests.get(
@@ -204,7 +224,13 @@ class TestUsageExportAPI:
self, reset: None, admin_user: DATestUser # noqa: ARG002
) -> None:
# First generate a report
seed_chat_history(num_sessions=5, num_messages=4, days=30)
seed_chat_history(
num_sessions=5,
num_messages=4,
days=30,
user_id=UUID(admin_user.id),
persona_id=DEFAULT_PERSONA_ID,
)
# Get initial reports count
initial_response = requests.get(
@@ -352,7 +378,13 @@ class TestUsageExportAPI:
self, reset: None, admin_user: DATestUser # noqa: ARG002
) -> None:
# Seed some data
seed_chat_history(num_sessions=10, num_messages=4, days=30)
seed_chat_history(
num_sessions=10,
num_messages=4,
days=30,
user_id=UUID(admin_user.id),
persona_id=DEFAULT_PERSONA_ID,
)
# Get initial count of reports
initial_response = requests.get(

View File

@@ -25,6 +25,11 @@ def test_add_users_to_group(reset: None) -> None: # noqa: ARG001
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_performing_action=admin_user,
user_groups_to_check=[user_group],
)
updated_user_group = UserGroupManager.add_users(
user_group=user_group,
user_ids=[user_to_add.id],

View File

@@ -3,6 +3,8 @@ from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from ee.onyx.external_permissions.sharepoint.permission_utils import (
_enumerate_ad_groups_paginated,
)
@@ -15,6 +17,9 @@ from ee.onyx.external_permissions.sharepoint.permission_utils import (
from ee.onyx.external_permissions.sharepoint.permission_utils import (
AD_GROUP_ENUMERATION_THRESHOLD,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_external_access_from_sharepoint,
)
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
@@ -266,3 +271,65 @@ def test_enumerate_all_without_token_skips(
assert results == []
mock_enum.assert_not_called()
# ---------------------------------------------------------------------------
# get_external_access_from_sharepoint site page URL handling
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
    "site_base_url, web_url, expected_relative_url",
    [
        (
            "https://tenant.sharepoint.com/sites/Evan%27sSite",
            "https://tenant.sharepoint.com/sites/Evan%27sSite/SitePages/Home.aspx",
            "/sites/Evan%27sSite/SitePages/Home.aspx",
        ),
        (
            "https://tenant.sharepoint.com/sites/NormalSite",
            "https://tenant.sharepoint.com/sites/NormalSite/SitePages/Page.aspx",
            "/sites/NormalSite/SitePages/Page.aspx",
        ),
        (
            "https://tenant.sharepoint.com/sites/Site%20With%20Spaces",
            "https://tenant.sharepoint.com/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
            "/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
        ),
    ],
    ids=["apostrophe-encoded", "no-special-chars", "space-encoded"],
)
@patch(f"{MODULE}._get_groups_and_members_recursively")
@patch(f"{MODULE}.sleep_and_retry")
def test_site_page_url_not_duplicated(
    mock_sleep: MagicMock,  # noqa: ARG001
    mock_recursive: MagicMock,
    site_base_url: str,
    web_url: str,
    expected_relative_url: str,
) -> None:
    """Regression test: the server-relative URL handed to
    get_file_by_server_relative_url must keep its percent-encoding so that
    the Office365 library's SPResPath.create_relative() recognises the site
    prefix and does not duplicate it.
    """
    mock_recursive.return_value = GroupsResult(
        groups_to_emails={},
        found_public_group=False,
    )
    mock_ctx = MagicMock()
    mock_ctx.base_url = site_base_url
    get_external_access_from_sharepoint(
        client_context=mock_ctx,
        graph_client=MagicMock(),
        drive_name=None,
        drive_item=None,
        site_page={"webUrl": web_url},
    )
    mock_ctx.web.get_file_by_server_relative_url.assert_called_once_with(
        expected_relative_url
    )

View File

@@ -0,0 +1,65 @@
import json
import httplib2 # type: ignore[import-untyped]
from googleapiclient.errors import HttpError # type: ignore[import-untyped]
from onyx.connectors.google_utils.google_utils import _is_rate_limit_error
def _make_http_error(
    status: int,
    reason: str = "unknown",
    error_reason: str = "",
) -> HttpError:
    """Construct a synthetic googleapiclient ``HttpError``.

    The JSON body mirrors the structure Google APIs return: an ``error``
    object with a ``message`` and, when *error_reason* is non-empty, an
    ``errors`` list carrying the machine-readable reason code.
    """
    payload: dict = {"error": {"message": reason}}
    if error_reason:
        payload["error"]["errors"] = [{"reason": error_reason, "message": reason}]
    response = httplib2.Response({"status": status})
    return HttpError(response, json.dumps(payload).encode())
def test_429_is_rate_limit() -> None:
    """A bare HTTP 429 is always classified as rate limiting."""
    assert _is_rate_limit_error(_make_http_error(429))


def test_403_user_rate_limit_exceeded() -> None:
    """403 carrying reason code ``userRateLimitExceeded`` is rate limiting."""
    rate_limited = _make_http_error(
        403,
        reason="User rate limit exceeded.",
        error_reason="userRateLimitExceeded",
    )
    assert _is_rate_limit_error(rate_limited)


def test_403_rate_limit_exceeded() -> None:
    """403 carrying reason code ``rateLimitExceeded`` is rate limiting."""
    rate_limited = _make_http_error(
        403,
        reason="Rate limit exceeded.",
        error_reason="rateLimitExceeded",
    )
    assert _is_rate_limit_error(rate_limited)


def test_403_permission_denied_is_not_rate_limit() -> None:
    """A plain permission-denied 403 must not be mistaken for rate limiting."""
    denied = _make_http_error(
        403,
        reason="The caller does not have permission",
        error_reason="forbidden",
    )
    assert not _is_rate_limit_error(denied)


def test_404_is_not_rate_limit() -> None:
    """404 responses are unrelated to rate limiting."""
    assert not _is_rate_limit_error(_make_http_error(404))


def test_500_is_not_rate_limit() -> None:
    """Server errors are not rate limiting."""
    assert not _is_rate_limit_error(_make_http_error(500))

View File

@@ -0,0 +1,34 @@
from unittest.mock import patch
import pytest
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.slab.connector import SlabConnector
def _build_connector(base_url: str = "https://myteam.slab.com") -> SlabConnector:
    """Create a ``SlabConnector`` for *base_url* with a dummy bot token loaded."""
    slab = SlabConnector(base_url=base_url)
    slab.load_credentials({"slab_bot_token": "fake-token"})
    return slab
def test_validate_rejects_missing_scheme() -> None:
    """A base URL without an explicit https:// scheme is rejected up front."""
    connector = _build_connector(base_url="myteam.slab.com")
    with pytest.raises(ConnectorValidationError, match="https://"):
        connector.validate_connector_settings()


@patch("onyx.connectors.slab.connector.get_all_post_ids", return_value=["id1"])
def test_validate_success(mock_get_posts: object) -> None:  # noqa: ARG001
    """Validation succeeds when the Slab API returns post IDs."""
    _build_connector().validate_connector_settings()


@patch(
    "onyx.connectors.slab.connector.get_all_post_ids",
    side_effect=Exception("401 Unauthorized"),
)
def test_validate_bad_token_raises(mock_get_posts: object) -> None:  # noqa: ARG001
    """Any API failure during validation surfaces as ConnectorValidationError."""
    with pytest.raises(ConnectorValidationError, match="Failed to fetch posts"):
        _build_connector().validate_connector_settings()

View File

@@ -98,6 +98,11 @@ class TestScimDALUserMappings:
"external_id": "ext-1",
"user_id": user_id,
"scim_username": None,
"department": None,
"manager": None,
"given_name": None,
"family_name": None,
"scim_emails_json": None,
}
def test_delete_user_mapping(

View File

@@ -0,0 +1,199 @@
"""Tests that persona IDs are correctly propagated through the indexing pipeline.
Covers Phase 1 (schema plumbing) and Phase 2 (write at index time) of the
unify-assistant-project-files plan.
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.access.models import DocumentAccess
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
def _make_index_chunk(
    doc_id: str = "test-file-id",
    content: str = "test content",
) -> IndexChunk:
    """Build a minimal IndexChunk over a single-section USER_FILE document."""
    source_doc = Document(
        id=doc_id,
        semantic_identifier="test_file.txt",
        sections=[TextSection(text=content, link=None)],
        source=DocumentSource.USER_FILE,
        metadata={},
    )
    return IndexChunk(
        chunk_id=0,
        blurb=content[:50],
        content=content,
        source_links=None,
        image_file_id=None,
        section_continuation=False,
        source_document=source_doc,
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        contextual_rag_reserved_tokens=0,
        doc_summary="",
        chunk_context="",
        mini_chunk_texts=None,
        large_chunk_id=None,
        embeddings=ChunkEmbedding(
            full_embedding=[0.1] * 10,
            mini_chunk_embeddings=[],
        ),
        title_embedding=None,
    )
def _make_access() -> DocumentAccess:
    """Return a non-public DocumentAccess scoped to a single internal user."""
    access = DocumentAccess.build(
        user_emails=["user@example.com"],
        user_groups=[],
        external_user_emails=[],
        external_user_group_ids=[],
        is_public=False,
    )
    return access
def test_from_index_chunk_propagates_personas() -> None:
    """The personas list handed to from_index_chunk appears on the result."""
    expected_personas = [10, 20, 30]
    result = DocMetadataAwareIndexChunk.from_index_chunk(
        index_chunk=_make_index_chunk(),
        access=_make_access(),
        document_sets=set(),
        user_project=[1],
        personas=expected_personas,
        boost=0,
        aggregated_chunk_boost_factor=1.0,
        tenant_id="test_tenant",
    )
    assert result.personas == expected_personas
    assert result.user_project == [1]


def test_from_index_chunk_empty_personas() -> None:
    """An empty personas list survives as-is (not None, not omitted)."""
    result = DocMetadataAwareIndexChunk.from_index_chunk(
        index_chunk=_make_index_chunk(),
        access=_make_access(),
        document_sets=set(),
        user_project=[],
        personas=[],
        boost=0,
        aggregated_chunk_boost_factor=1.0,
        tenant_id="test_tenant",
    )
    assert result.personas == []
def _make_document(doc_id: str) -> Document:
    """A single-section USER_FILE Document with the given ID."""
    return Document(
        id=doc_id,
        semantic_identifier="test_file.txt",
        sections=[TextSection(text="test content", link=None)],
        source=DocumentSource.USER_FILE,
        metadata={},
    )
def _run_adapter_build(
    file_id: str,
    project_ids_map: dict[str, list[int]],
    persona_ids_map: dict[str, list[int]],
) -> list[DocMetadataAwareIndexChunk]:
    """Run UserFileIndexingAdapter.build_metadata_aware_chunks for a single
    chunk with every external dependency mocked out; returns the chunks."""
    from onyx.indexing.adapters.user_file_indexing_adapter import (
        UserFileIndexingAdapter,
    )
    from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext

    prepare_context = DocumentBatchPrepareContext(
        updatable_docs=[_make_document(doc_id=file_id)],
        id_to_boost_map={},
    )
    adapter = UserFileIndexingAdapter(tenant_id="test_tenant", db_session=MagicMock())
    with (
        patch(
            "onyx.indexing.adapters.user_file_indexing_adapter.fetch_user_project_ids_for_user_files",
            return_value=project_ids_map,
        ),
        patch(
            "onyx.indexing.adapters.user_file_indexing_adapter.fetch_persona_ids_for_user_files",
            return_value=persona_ids_map,
        ),
        patch(
            "onyx.indexing.adapters.user_file_indexing_adapter.get_access_for_user_files",
            return_value={file_id: _make_access()},
        ),
        patch(
            "onyx.indexing.adapters.user_file_indexing_adapter.fetch_chunk_counts_for_user_files",
            return_value=[(file_id, 0)],
        ),
        patch(
            "onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
            side_effect=Exception("no LLM in tests"),
        ),
    ):
        build_result = adapter.build_metadata_aware_chunks(
            chunks_with_embeddings=[_make_index_chunk(doc_id=file_id)],
            chunk_content_scores=[1.0],
            tenant_id="test_tenant",
            context=prepare_context,
        )
    return build_result.chunks
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
    """Persona IDs fetched from the DB are written into each chunk's metadata."""
    file_id = str(uuid4())
    chunks = _run_adapter_build(
        file_id=file_id,
        project_ids_map={file_id: [3]},
        persona_ids_map={file_id: [5, 12]},
    )
    assert len(chunks) == 1
    assert chunks[0].personas == [5, 12]
    assert chunks[0].user_project == [3]


def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
    """Files with no persona/project rows in the DB default to empty lists
    rather than raising KeyError or yielding None."""
    chunks = _run_adapter_build(
        file_id=str(uuid4()),
        project_ids_map={},
        persona_ids_map={},
    )
    assert len(chunks) == 1
    assert chunks[0].personas == []
    assert chunks[0].user_project == []

View File

@@ -0,0 +1,106 @@
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
from onyx.onyxbot.slack.formatting import _sanitize_html
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.utils.text_processing import decode_escapes
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
    """Citation URLs containing parentheses get wrapped in angle brackets."""
    message = (
        "See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
    )
    expected = (
        "See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
    )
    assert _normalize_link_destinations(message) == expected


def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
    """Destinations already in angle brackets are left untouched."""
    message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
    assert _normalize_link_destinations(message) == message


def test_normalize_citation_link_handles_multiple_links() -> None:
    """Every citation link in the message is normalized independently."""
    normalized = _normalize_link_destinations(
        "[[1]](https://example.com/(USA)%20Guide.pdf) "
        "[[2]](https://example.com/Plan(s)%20Overview.pdf)"
    )
    assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
    assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized


def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
    """Parenthesized citation URLs survive full Slack formatting intact."""
    formatted = format_slack_message(
        "Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
    )
    rendered = decode_escapes(remove_slack_text_interactions(formatted))
    assert (
        "<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
        in rendered
    )
    assert "|[1]>%20Access%20ID%20Card" not in rendered
def test_slack_style_links_converted_to_clickable_links() -> None:
    """Slack-style <url|label> links stay clickable and are not escaped."""
    formatted = format_slack_message(
        "Visit <https://example.com/page|Example Page> for details."
    )
    assert "<https://example.com/page|Example Page>" in formatted
    assert "&lt;" not in formatted


def test_slack_style_links_preserved_inside_code_blocks() -> None:
    """Link conversion leaves fenced code-block content alone."""
    converted = _convert_slack_links_to_markdown(
        "```\n<https://example.com|click>\n```"
    )
    assert "<https://example.com|click>" in converted


def test_html_tags_stripped_outside_code_blocks() -> None:
    """HTML is sanitized outside fenced code but preserved inside it."""
    sanitized = _transform_outside_code_blocks(
        "Hello<br/>world ```<div>code</div>``` after", _sanitize_html
    )
    assert "<br" not in sanitized
    assert "<div>code</div>" in sanitized


def test_format_slack_message_block_spacing() -> None:
    """Paragraph spacing (blank line between paragraphs) is preserved."""
    assert format_slack_message("Paragraph one.\n\nParagraph two.") == (
        "Paragraph one.\n\nParagraph two."
    )


def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
    """No trailing blank line is appended after a fenced code block."""
    formatted = format_slack_message("```python\nprint('hi')\n```")
    assert formatted.endswith("print('hi')\n```")


def test_format_slack_message_ampersand_not_double_escaped() -> None:
    """Ampersands are escaped exactly once; quotes are not HTML-escaped."""
    formatted = format_slack_message('She said "hello" & goodbye.')
    assert "&amp;" in formatted
    assert "&quot;" not in formatted

View File

@@ -1,10 +1,12 @@
"""Test bulk invite limit for free trial tenants."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from onyx.server.manage.models import EmailInviteStatus
from onyx.server.manage.users import bulk_invite_users
@@ -33,6 +35,7 @@ def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: None) -> None:
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
@patch("onyx.server.manage.users.enforce_seat_limit")
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
@patch(
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
@@ -44,4 +47,69 @@ def test_trial_tenant_can_invite_within_limit(*_mocks: None) -> None:
result = bulk_invite_users(emails=emails)
assert result == 3
assert result.invited_count == 3
assert result.email_invite_status == EmailInviteStatus.DISABLED
# --- email_invite_status tests ---
_COMMON_PATCHES = [
patch("onyx.server.manage.users.MULTI_TENANT", False),
patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant"),
patch("onyx.server.manage.users.get_invited_users", return_value=[]),
patch("onyx.server.manage.users.get_all_users", return_value=[]),
patch("onyx.server.manage.users.write_invited_users", return_value=1),
patch("onyx.server.manage.users.enforce_seat_limit"),
]
def _with_common_patches(fn: object) -> object:
    """Stack every patch in _COMMON_PATCHES onto *fn*, outermost first."""
    decorated = fn
    for patcher in reversed(_COMMON_PATCHES):
        decorated = patcher(decorated)  # type: ignore
    return decorated
@_with_common_patches
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
def test_email_invite_status_disabled(*_mocks: None) -> None:
    """Email invites disabled -> status is DISABLED."""
    result = bulk_invite_users(emails=["user@example.com"])
    assert result.email_invite_status == EmailInviteStatus.DISABLED


@_with_common_patches
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", False)
def test_email_invite_status_not_configured(*_mocks: None) -> None:
    """Invites enabled but no mail server configured -> NOT_CONFIGURED."""
    result = bulk_invite_users(emails=["user@example.com"])
    assert result.email_invite_status == EmailInviteStatus.NOT_CONFIGURED


@_with_common_patches
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", True)
@patch("onyx.server.manage.users.send_user_email_invite")
def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
    """Enabled and configured -> the invite email is sent and status is SENT."""
    result = bulk_invite_users(emails=["user@example.com"])
    mock_send.assert_called_once()
    assert result.email_invite_status == EmailInviteStatus.SENT


@_with_common_patches
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", True)
@patch(
    "onyx.server.manage.users.send_user_email_invite",
    side_effect=Exception("SMTP auth failed"),
)
def test_email_invite_status_send_failed(*_mocks: None) -> None:
    """A send failure yields SEND_FAILED while the invite is still recorded."""
    result = bulk_invite_users(emails=["user@example.com"])
    assert result.email_invite_status == EmailInviteStatus.SEND_FAILED
    assert result.invited_count == 1

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import json
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock
@@ -12,7 +13,9 @@ import pytest
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session
from ee.onyx.server.scim.api import ScimJSONResponse
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import ScimProvider
@@ -115,6 +118,11 @@ def make_user_mapping(**kwargs: Any) -> MagicMock:
mapping.external_id = kwargs.get("external_id", "ext-default")
mapping.user_id = kwargs.get("user_id", uuid4())
mapping.scim_username = kwargs.get("scim_username", None)
mapping.department = kwargs.get("department", None)
mapping.manager = kwargs.get("manager", None)
mapping.given_name = kwargs.get("given_name", None)
mapping.family_name = kwargs.get("family_name", None)
mapping.scim_emails_json = kwargs.get("scim_emails_json", None)
return mapping
@@ -122,3 +130,35 @@ def assert_scim_error(result: object, expected_status: int) -> None:
"""Assert *result* is a JSONResponse with the given status code."""
assert isinstance(result, JSONResponse)
assert result.status_code == expected_status
# ---------------------------------------------------------------------------
# Response parsing helpers
# ---------------------------------------------------------------------------


def parse_scim_user(result: object, *, status: int = 200) -> ScimUserResource:
    """Assert *result* is a ScimJSONResponse with *status*; parse the body."""
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    payload = json.loads(result.body)
    return ScimUserResource.model_validate(payload)


def parse_scim_group(result: object, *, status: int = 200) -> ScimGroupResource:
    """Assert *result* is a ScimJSONResponse with *status*; parse the body."""
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    payload = json.loads(result.body)
    return ScimGroupResource.model_validate(payload)


def parse_scim_list(result: object) -> ScimListResponse:
    """Assert *result* is a 200 ScimJSONResponse; parse as ScimListResponse."""
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == 200
    payload = json.loads(result.body)
    return ScimListResponse.model_validate(payload)

View File

@@ -0,0 +1,983 @@
"""Comprehensive Entra ID (Azure AD) SCIM compatibility tests.
Covers the full Entra provisioning lifecycle: service discovery, user CRUD
with enterprise extension schema, group CRUD with excludedAttributes, and
all Entra-specific behavioral quirks (PascalCase ops, enterprise URN in
PATCH value dicts).
"""
from __future__ import annotations
import json
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from fastapi import Response
from ee.onyx.server.scim.api import create_user
from ee.onyx.server.scim.api import delete_user
from ee.onyx.server.scim.api import get_group
from ee.onyx.server.scim.api import get_resource_types
from ee.onyx.server.scim.api import get_schemas
from ee.onyx.server.scim.api import get_service_provider_config
from ee.onyx.server.scim.api import get_user
from ee.onyx.server.scim.api import list_groups
from ee.onyx.server.scim.api import list_users
from ee.onyx.server.scim.api import patch_group
from ee.onyx.server.scim.api import patch_user
from ee.onyx.server.scim.api import replace_user
from ee.onyx.server.scim.api import ScimJSONResponse
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.entra import EntraProvider
from tests.unit.onyx.server.scim.conftest import make_db_group
from tests.unit.onyx.server.scim.conftest import make_db_user
from tests.unit.onyx.server.scim.conftest import make_scim_user
from tests.unit.onyx.server.scim.conftest import make_user_mapping
from tests.unit.onyx.server.scim.conftest import parse_scim_group
from tests.unit.onyx.server.scim.conftest import parse_scim_list
from tests.unit.onyx.server.scim.conftest import parse_scim_user
@pytest.fixture
def entra_provider() -> ScimProvider:
    """EntraProvider instance used by the Entra-specific endpoint tests."""
    return EntraProvider()
# ---------------------------------------------------------------------------
# Service Discovery
# ---------------------------------------------------------------------------
class TestEntraServiceDiscovery:
    """Entra expects the enterprise extension in all discovery endpoints."""

    def test_service_provider_config_advertises_patch(self) -> None:
        assert get_service_provider_config().patch.supported is True

    def test_resource_types_include_enterprise_extension(self) -> None:
        result = get_resource_types()
        assert isinstance(result, ScimJSONResponse)
        body = json.loads(result.body)
        assert "Resources" in body
        user_type = next(rt for rt in body["Resources"] if rt["id"] == "User")
        extension_schemas = [ext["schema"] for ext in user_type["schemaExtensions"]]
        assert SCIM_ENTERPRISE_USER_SCHEMA in extension_schemas

    def test_schemas_include_enterprise_user(self) -> None:
        result = get_schemas()
        assert isinstance(result, ScimJSONResponse)
        body = json.loads(result.body)
        assert SCIM_ENTERPRISE_USER_SCHEMA in [s["id"] for s in body["Resources"]]

    def test_enterprise_schema_has_expected_attributes(self) -> None:
        result = get_schemas()
        assert isinstance(result, ScimJSONResponse)
        body = json.loads(result.body)
        enterprise = next(
            s for s in body["Resources"] if s["id"] == SCIM_ENTERPRISE_USER_SCHEMA
        )
        attribute_names = {attr["name"] for attr in enterprise["attributes"]}
        assert "department" in attribute_names
        assert "manager" in attribute_names

    def test_service_discovery_content_type(self) -> None:
        """SCIM responses must use the application/scim+json media type."""
        result = get_resource_types()
        assert isinstance(result, ScimJSONResponse)
        assert result.media_type == "application/scim+json"
# ---------------------------------------------------------------------------
# User Lifecycle (Entra-specific)
# ---------------------------------------------------------------------------
class TestEntraUserLifecycle:
"""Test user CRUD through Entra's lens: enterprise schemas, PascalCase ops."""
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_create_user_includes_enterprise_schema(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(userName="alice@contoso.com")
result = create_user(
user_resource=resource,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_user(result, status=201)
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
assert SCIM_USER_SCHEMA in resource.schemas
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_create_user_with_enterprise_extension(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Enterprise extension department/manager should round-trip on create."""
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(
userName="alice@contoso.com",
enterprise_extension=ScimEnterpriseExtension(
department="Engineering",
manager=ScimManagerRef(value="mgr-uuid-123"),
),
)
result = create_user(
user_resource=resource,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_user(result, status=201)
assert resource.enterprise_extension is not None
assert resource.enterprise_extension.department == "Engineering"
assert resource.enterprise_extension.manager is not None
assert resource.enterprise_extension.manager.value == "mgr-uuid-123"
# Verify DAL received the enterprise fields
mock_dal.create_user_mapping.assert_called_once()
call_kwargs = mock_dal.create_user_mapping.call_args[1]
assert call_kwargs["fields"] == ScimMappingFields(
department="Engineering",
manager="mgr-uuid-123",
given_name="Test",
family_name="User",
)
def test_get_user_includes_enterprise_schema(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
user = make_db_user(email="alice@contoso.com")
mock_dal.get_user.return_value = user
result = get_user(
user_id=str(user.id),
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_user(result)
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
def test_get_user_returns_enterprise_extension_data(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""GET should return stored enterprise extension data."""
user = make_db_user(email="alice@contoso.com")
mock_dal.get_user.return_value = user
mapping = make_user_mapping(user_id=user.id)
mapping.department = "Sales"
mapping.manager = "mgr-456"
mock_dal.get_user_mapping_by_user_id.return_value = mapping
result = get_user(
user_id=str(user.id),
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_user(result)
assert resource.enterprise_extension is not None
assert resource.enterprise_extension.department == "Sales"
assert resource.enterprise_extension.manager is not None
assert resource.enterprise_extension.manager.value == "mgr-456"
def test_list_users_includes_enterprise_schema(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
user = make_db_user(email="alice@contoso.com")
mapping = make_user_mapping(external_id="entra-ext-1", user_id=user.id)
mock_dal.list_users.return_value = ([(user, mapping)], 1)
result = list_users(
filter=None,
startIndex=1,
count=100,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parsed = parse_scim_list(result)
resource = parsed.Resources[0]
assert isinstance(resource, ScimUserResource)
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
def test_patch_user_deactivate_with_pascal_case_replace(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends ``"Replace"`` (PascalCase) instead of ``"replace"``."""
user = make_db_user(is_active=True)
mock_dal.get_user.return_value = user
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op="Replace", # type: ignore[arg-type]
path="active",
value=False,
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result)
# Mock doesn't propagate the change, so verify via the DAL call
mock_dal.update_user.assert_called_once()
call_kwargs = mock_dal.update_user.call_args
assert call_kwargs[1]["is_active"] is False
def test_patch_user_add_external_id_with_pascal_case(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends ``"Add"`` (PascalCase) instead of ``"add"``."""
user = make_db_user()
mock_dal.get_user.return_value = user
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op="Add", # type: ignore[arg-type]
path="externalId",
value="entra-ext-999",
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result)
# Verify the patched externalId was synced to the DAL
mock_dal.sync_user_external_id.assert_called_once()
call_args = mock_dal.sync_user_external_id.call_args
assert call_args[0][1] == "entra-ext-999"
def test_patch_user_enterprise_extension_in_value_dict(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends enterprise extension URN as key in path-less PATCH value
dicts — enterprise data should be stored, not ignored."""
user = make_db_user()
mock_dal.get_user.return_value = user
value = ScimPatchResourceValue(active=False)
assert value.__pydantic_extra__ is not None
value.__pydantic_extra__[
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
] = {"department": "Engineering"}
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path=None,
value=value,
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result)
# Verify active=False was applied
mock_dal.update_user.assert_called_once()
call_kwargs = mock_dal.update_user.call_args
assert call_kwargs[1]["is_active"] is False
# Verify enterprise data was passed to DAL
mock_dal.sync_user_external_id.assert_called_once()
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
assert sync_kwargs["fields"] == ScimMappingFields(
department="Engineering",
given_name="Test",
family_name="User",
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
)
def test_patch_user_remove_external_id(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""PATCH remove op should clear the target field."""
user = make_db_user()
mock_dal.get_user.return_value = user
mapping = make_user_mapping(user_id=user.id)
mapping.external_id = "ext-to-remove"
mock_dal.get_user_mapping_by_user_id.return_value = mapping
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REMOVE,
path="externalId",
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result)
# externalId should be cleared (None)
mock_dal.sync_user_external_id.assert_called_once()
call_args = mock_dal.sync_user_external_id.call_args
assert call_args[0][1] is None
def test_patch_user_emails_primary_eq_true_value(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""PATCH with path emails[primary eq true].value should update
the primary email entry, not userName."""
user = make_db_user(email="old@contoso.com")
mock_dal.get_user.return_value = user
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="emails[primary eq true].value",
value="new@contoso.com",
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_user(result)
# userName should remain unchanged — emails and userName are separate
assert resource.userName == "old@contoso.com"
# Primary email should be updated
primary_emails = [e for e in resource.emails if e.primary]
assert len(primary_emails) == 1
assert primary_emails[0].value == "new@contoso.com"
def test_patch_user_enterprise_urn_department_path(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""PATCH with dotted enterprise URN path should store department."""
user = make_db_user()
mock_dal.get_user.return_value = user
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op=ScimPatchOperationType.REPLACE,
path="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
value="Marketing",
)
]
)
result = patch_user(
user_id=str(user.id),
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result)
mock_dal.sync_user_external_id.assert_called_once()
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
assert sync_kwargs["fields"] == ScimMappingFields(
department="Marketing",
given_name="Test",
family_name="User",
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
)
def test_replace_user_includes_enterprise_schema(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
user = make_db_user(email="old@contoso.com")
mock_dal.get_user.return_value = user
resource = make_scim_user(
userName="new@contoso.com",
name=ScimName(givenName="New", familyName="Name"),
)
result = replace_user(
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_user(result)
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
def test_replace_user_with_enterprise_extension(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""PUT with enterprise extension should store the fields."""
user = make_db_user(email="alice@contoso.com")
mock_dal.get_user.return_value = user
resource = make_scim_user(
userName="alice@contoso.com",
enterprise_extension=ScimEnterpriseExtension(
department="HR",
manager=ScimManagerRef(value="boss-id"),
),
)
result = replace_user(
user_id=str(user.id),
user_resource=resource,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result)
mock_dal.sync_user_external_id.assert_called_once()
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
assert sync_kwargs["fields"] == ScimMappingFields(
department="HR",
manager="boss-id",
given_name="Test",
family_name="User",
)
def test_delete_user_returns_204(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
user = make_db_user()
mock_dal.get_user.return_value = user
mock_dal.get_user_mapping_by_user_id.return_value = MagicMock(id=1)
result = delete_user(
user_id=str(user.id),
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, Response)
assert result.status_code == 204
def test_double_delete_returns_404(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
) -> None:
"""Second DELETE should return 404 — the SCIM mapping is gone."""
user = make_db_user()
mock_dal.get_user.return_value = user
# No mapping — user was already deleted from SCIM's perspective
mock_dal.get_user_mapping_by_user_id.return_value = None
result = delete_user(
user_id=str(user.id),
_token=mock_token,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
assert result.status_code == 404
def test_name_formatted_preserved_on_create(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""When name.formatted is provided, it should be used as personal_name."""
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(
userName="alice@contoso.com",
name=ScimName(
givenName="Alice",
familyName="Smith",
formatted="Dr. Alice Smith",
),
)
with patch(
"ee.onyx.server.scim.api._check_seat_availability", return_value=None
):
result = create_user(
user_resource=resource,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_user(result, status=201)
# The User constructor should have received the formatted name
mock_dal.add_user.assert_called_once()
created_user = mock_dal.add_user.call_args[0][0]
assert created_user.personal_name == "Dr. Alice Smith"
# ---------------------------------------------------------------------------
# Group Lifecycle (Entra-specific)
# ---------------------------------------------------------------------------
class TestEntraGroupLifecycle:
"""Test group CRUD with Entra-specific behaviors."""
def test_get_group_standard_response(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
group = make_db_group(id=10, name="Contoso Engineering")
mock_dal.get_group.return_value = group
uid = uuid4()
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
result = get_group(
group_id="10",
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_group(result)
assert resource.displayName == "Contoso Engineering"
assert len(resource.members) == 1
def test_list_groups_with_excluded_attributes_members(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends ?excludedAttributes=members on group list queries."""
group = make_db_group(id=10, name="Engineering")
uid = uuid4()
mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
result = list_groups(
filter=None,
excludedAttributes="members",
startIndex=1,
count=100,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
assert parsed["totalResults"] == 1
resource = parsed["Resources"][0]
assert "members" not in resource
assert resource["displayName"] == "Engineering"
def test_get_group_with_excluded_attributes_members(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends ?excludedAttributes=members on single group GET."""
group = make_db_group(id=10, name="Engineering")
uid = uuid4()
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
result = get_group(
group_id="10",
excludedAttributes="members",
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
assert "members" not in parsed
assert parsed["displayName"] == "Engineering"
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_patch_group_add_members_with_pascal_case(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends ``"Add"`` (PascalCase) for group member additions."""
group = make_db_group(id=10)
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
mock_dal.validate_member_ids.return_value = []
uid = str(uuid4())
patched = ScimGroupResource(
id="10",
displayName="Engineering",
members=[ScimGroupMember(value=uid)],
)
mock_apply.return_value = (patched, [uid], [])
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op="Add", # type: ignore[arg-type]
path="members",
value=[ScimGroupMember(value=uid)],
)
]
)
result = patch_group(
group_id="10",
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_group(result)
mock_dal.upsert_group_members.assert_called_once()
@patch("ee.onyx.server.scim.api.apply_group_patch")
def test_patch_group_remove_member_with_pascal_case(
self,
mock_apply: MagicMock,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra sends ``"Remove"`` (PascalCase) for group member removals."""
group = make_db_group(id=10)
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
uid = str(uuid4())
patched = ScimGroupResource(id="10", displayName="Engineering", members=[])
mock_apply.return_value = (patched, [], [uid])
patch_req = ScimPatchRequest(
Operations=[
ScimPatchOperation(
op="Remove", # type: ignore[arg-type]
path=f'members[value eq "{uid}"]',
)
]
)
result = patch_group(
group_id="10",
patch_request=patch_req,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parse_scim_group(result)
mock_dal.remove_group_members.assert_called_once()
# ---------------------------------------------------------------------------
# excludedAttributes (RFC 7644 §3.4.2.5)
# ---------------------------------------------------------------------------
class TestExcludedAttributes:
"""Test excludedAttributes query parameter on GET endpoints."""
def test_list_groups_excludes_members(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
group = make_db_group(id=1, name="Team")
uid = uuid4()
mock_dal.list_groups.return_value = ([(group, None)], 1)
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
result = list_groups(
filter=None,
excludedAttributes="members",
startIndex=1,
count=100,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
resource = parsed["Resources"][0]
assert "members" not in resource
assert "displayName" in resource
def test_get_group_excludes_members(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
group = make_db_group(id=1, name="Team")
uid = uuid4()
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
result = get_group(
group_id="1",
excludedAttributes="members",
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
assert "members" not in parsed
assert "displayName" in parsed
def test_list_users_excludes_groups(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
user = make_db_user()
mapping = make_user_mapping(user_id=user.id)
mock_dal.list_users.return_value = ([(user, mapping)], 1)
mock_dal.get_users_groups_batch.return_value = {user.id: [(1, "Engineering")]}
result = list_users(
filter=None,
excludedAttributes="groups",
startIndex=1,
count=100,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
resource = parsed["Resources"][0]
assert "groups" not in resource
assert "userName" in resource
def test_get_user_excludes_groups(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
user = make_db_user()
mock_dal.get_user.return_value = user
mock_dal.get_user_groups.return_value = [(1, "Engineering")]
result = get_user(
user_id=str(user.id),
excludedAttributes="groups",
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
assert "groups" not in parsed
assert "userName" in parsed
def test_multiple_excluded_attributes(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
group = make_db_group(id=1, name="Team")
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = []
result = get_group(
group_id="1",
excludedAttributes="members,externalId",
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
assert isinstance(result, ScimJSONResponse)
parsed = json.loads(result.body)
assert "members" not in parsed
assert "externalId" not in parsed
assert "displayName" in parsed
def test_no_excluded_attributes_returns_full_response(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
group = make_db_group(id=1, name="Team")
uid = uuid4()
mock_dal.get_group.return_value = group
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
result = get_group(
group_id="1",
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
resource = parse_scim_group(result)
assert len(resource.members) == 1
# ---------------------------------------------------------------------------
# Entra Connection Probe
# ---------------------------------------------------------------------------
class TestEntraConnectionProbe:
"""Entra sends a probe request during initial SCIM setup."""
def test_filter_for_nonexistent_user_returns_empty_list(
self,
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
entra_provider: ScimProvider,
) -> None:
"""Entra probes with: GET /Users?filter=userName eq "non-existent"&count=1"""
mock_dal.list_users.return_value = ([], 0)
result = list_users(
filter='userName eq "non-existent@contoso.com"',
startIndex=1,
count=1,
_token=mock_token,
provider=entra_provider,
db_session=mock_db_session,
)
parsed = parse_scim_list(result)
assert parsed.totalResults == 0
assert parsed.Resources == []

View File

@@ -16,7 +16,6 @@ from ee.onyx.server.scim.api import patch_group
from ee.onyx.server.scim.api import replace_group
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchRequest
@@ -25,6 +24,8 @@ from ee.onyx.server.scim.providers.base import ScimProvider
from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_group
from tests.unit.onyx.server.scim.conftest import make_scim_group
from tests.unit.onyx.server.scim.conftest import parse_scim_group
from tests.unit.onyx.server.scim.conftest import parse_scim_list
class TestListGroups:
@@ -48,9 +49,9 @@ class TestListGroups:
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 0
assert result.Resources == []
parsed = parse_scim_list(result)
assert parsed.totalResults == 0
assert parsed.Resources == []
def test_unsupported_filter_returns_400(
self,
@@ -95,9 +96,9 @@ class TestListGroups:
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 1
resource = result.Resources[0]
parsed = parse_scim_list(result)
assert parsed.totalResults == 1
resource = parsed.Resources[0]
assert isinstance(resource, ScimGroupResource)
assert resource.displayName == "Engineering"
assert resource.externalId == "ext-g-1"
@@ -126,9 +127,9 @@ class TestGetGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
assert result.displayName == "Engineering"
assert result.id == "5"
resource = parse_scim_group(result)
assert resource.displayName == "Engineering"
assert resource.id == "5"
def test_non_integer_id_returns_404(
self,
@@ -190,8 +191,8 @@ class TestCreateGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
assert result.displayName == "New Group"
resource = parse_scim_group(result, status=201)
assert resource.displayName == "New Group"
mock_dal.add_group.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -283,7 +284,7 @@ class TestCreateGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
parse_scim_group(result, status=201)
mock_dal.create_group_mapping.assert_called_once()
@@ -314,7 +315,7 @@ class TestReplaceGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
parse_scim_group(result)
mock_dal.update_group.assert_called_once_with(group, name="New Name")
mock_dal.replace_group_members.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -427,7 +428,7 @@ class TestPatchGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
parse_scim_group(result)
mock_dal.update_group.assert_called_once_with(group, name="New Name")
def test_not_found_returns_404(
@@ -534,7 +535,7 @@ class TestPatchGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
parse_scim_group(result)
mock_dal.validate_member_ids.assert_called_once()
mock_dal.upsert_group_members.assert_called_once()
@@ -614,7 +615,7 @@ class TestPatchGroup:
db_session=mock_db_session,
)
assert isinstance(result, ScimGroupResource)
parse_scim_group(result)
mock_dal.remove_group_members.assert_called_once()

View File

@@ -1,5 +1,6 @@
import pytest
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimMeta
@@ -12,9 +13,11 @@ from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.entra import EntraProvider
from ee.onyx.server.scim.providers.okta import OktaProvider
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
_ENTRA_IGNORED = EntraProvider().ignored_patch_paths
def _make_user(**kwargs: object) -> ScimUserResource:
@@ -56,36 +59,36 @@ class TestApplyUserPatch:
def test_deactivate_user(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("active", False)], user)
result, _ = apply_user_patch([_replace_op("active", False)], user)
assert result.active is False
assert result.userName == "test@example.com"
def test_activate_user(self) -> None:
user = _make_user(active=False)
result = apply_user_patch([_replace_op("active", True)], user)
result, _ = apply_user_patch([_replace_op("active", True)], user)
assert result.active is True
def test_replace_given_name(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
result, _ = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
assert result.name is not None
assert result.name.givenName == "NewFirst"
assert result.name.familyName == "User"
def test_replace_family_name(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
result, _ = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
assert result.name is not None
assert result.name.familyName == "NewLast"
def test_replace_username(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
result, _ = apply_user_patch([_replace_op("userName", "new@example.com")], user)
assert result.userName == "new@example.com"
def test_replace_without_path_uses_dict(self) -> None:
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[
_replace_op(
None,
@@ -99,7 +102,7 @@ class TestApplyUserPatch:
def test_multiple_operations(self) -> None:
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[
_replace_op("active", False),
_replace_op("name.givenName", "Updated"),
@@ -112,7 +115,7 @@ class TestApplyUserPatch:
def test_case_insensitive_path(self) -> None:
user = _make_user()
result = apply_user_patch([_replace_op("Active", False)], user)
result, _ = apply_user_patch([_replace_op("Active", False)], user)
assert result.active is False
def test_original_not_mutated(self) -> None:
@@ -125,15 +128,22 @@ class TestApplyUserPatch:
with pytest.raises(ScimPatchError, match="Unsupported path"):
apply_user_patch([_replace_op("unknownField", "value")], user)
def test_remove_op_on_user_raises(self) -> None:
def test_remove_op_clears_field(self) -> None:
"""Remove op should clear the target field (not raise)."""
user = _make_user(externalId="ext-123")
result, _ = apply_user_patch([_remove_op("externalId")], user)
assert result.externalId is None
def test_remove_unsupported_path_raises(self) -> None:
"""Remove op on unsupported path (e.g. 'active') should raise."""
user = _make_user()
with pytest.raises(ScimPatchError, match="Unsupported operation"):
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
apply_user_patch([_remove_op("active")], user)
def test_replace_without_path_ignores_id(self) -> None:
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
user,
ignored_paths=_OKTA_IGNORED,
@@ -143,7 +153,7 @@ class TestApplyUserPatch:
def test_replace_without_path_ignores_schemas(self) -> None:
"""The 'schemas' key in a value dict should be silently ignored."""
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[
_replace_op(
None,
@@ -161,7 +171,7 @@ class TestApplyUserPatch:
def test_okta_deactivation_payload(self) -> None:
"""Exact Okta deactivation payload: path-less replace with id + active."""
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[
_replace_op(
None,
@@ -176,7 +186,7 @@ class TestApplyUserPatch:
def test_replace_displayname(self) -> None:
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[_replace_op("displayName", "New Display Name")], user
)
assert result.displayName == "New Display Name"
@@ -187,7 +197,7 @@ class TestApplyUserPatch:
"""Okta sends id/schemas/meta alongside actual changes — complex types
(lists, nested dicts) must not cause Pydantic validation errors."""
user = _make_user()
result = apply_user_patch(
result, _ = apply_user_patch(
[
_replace_op(
None,
@@ -207,9 +217,101 @@ class TestApplyUserPatch:
def test_add_operation_works_like_replace(self) -> None:
user = _make_user()
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
result, _ = apply_user_patch([_add_op("externalId", "ext-456")], user)
assert result.externalId == "ext-456"
def test_entra_capitalized_replace_op(self) -> None:
"""Entra ID sends ``"Replace"`` instead of ``"replace"``."""
user = _make_user()
op = ScimPatchOperation(op="Replace", path="active", value=False) # type: ignore[arg-type]
result, _ = apply_user_patch([op], user)
assert result.active is False
def test_entra_capitalized_add_op(self) -> None:
"""Entra ID sends ``"Add"`` instead of ``"add"``."""
user = _make_user()
op = ScimPatchOperation(op="Add", path="externalId", value="ext-999") # type: ignore[arg-type]
result, _ = apply_user_patch([op], user)
assert result.externalId == "ext-999"
def test_entra_enterprise_extension_handled(self) -> None:
"""Entra sends the enterprise extension URN as a key in path-less
PATCH value dicts — enterprise data should be captured in ent_data."""
user = _make_user()
value = ScimPatchResourceValue(active=False)
# Simulate Entra including the enterprise extension URN as extra data
assert value.__pydantic_extra__ is not None
value.__pydantic_extra__[
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
] = {"department": "Engineering"}
result, ent_data = apply_user_patch(
[_replace_op(None, value)],
user,
ignored_paths=_ENTRA_IGNORED,
)
assert result.active is False
assert result.userName == "test@example.com"
assert ent_data["department"] == "Engineering"
def test_okta_handles_enterprise_extension_urn(self) -> None:
"""Enterprise extension URN paths are handled universally, even
for Okta — the data is captured in the enterprise data dict."""
user = _make_user()
value = ScimPatchResourceValue(active=False)
assert value.__pydantic_extra__ is not None
value.__pydantic_extra__[
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
] = {"department": "Engineering"}
result, ent_data = apply_user_patch(
[_replace_op(None, value)],
user,
ignored_paths=_OKTA_IGNORED,
)
assert result.active is False
assert ent_data["department"] == "Engineering"
def test_emails_primary_eq_true_value(self) -> None:
"""emails[primary eq true].value should update the primary email entry."""
user = _make_user(
emails=[ScimEmail(value="old@example.com", type="work", primary=True)]
)
result, _ = apply_user_patch(
[_replace_op("emails[primary eq true].value", "new@example.com")], user
)
# userName should remain unchanged — emails and userName are separate
assert result.userName == "test@example.com"
assert len(result.emails) == 1
assert result.emails[0].value == "new@example.com"
assert result.emails[0].primary is True
def test_enterprise_urn_department_path(self) -> None:
"""Dotted enterprise URN path should set department in ent_data."""
user = _make_user()
_, ent_data = apply_user_patch(
[
_replace_op(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
"Marketing",
)
],
user,
)
assert ent_data["department"] == "Marketing"
def test_enterprise_urn_manager_path(self) -> None:
"""Dotted enterprise URN path for manager should set manager."""
user = _make_user()
_, ent_data = apply_user_patch(
[
_replace_op(
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager",
ScimPatchResourceValue.model_validate({"value": "boss-id"}),
)
],
user,
)
assert ent_data["manager"] == "boss-id"
class TestApplyGroupPatch:
"""Tests for SCIM group PATCH operations."""

View File

@@ -2,6 +2,8 @@ from unittest.mock import MagicMock
from uuid import UUID
from uuid import uuid4
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
@@ -9,7 +11,10 @@ from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.entra import _ENTRA_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.entra import EntraProvider
from ee.onyx.server.scim.providers.okta import OktaProvider
@@ -39,9 +44,7 @@ class TestOktaProvider:
assert OktaProvider().name == "okta"
def test_ignored_patch_paths(self) -> None:
assert OktaProvider().ignored_patch_paths == frozenset(
{"id", "schemas", "meta"}
)
assert OktaProvider().ignored_patch_paths == COMMON_IGNORED_PATCH_PATHS
def test_build_user_resource_basic(self) -> None:
provider = OktaProvider()
@@ -60,6 +63,12 @@ class TestOktaProvider:
meta=ScimMeta(resourceType="User"),
)
def test_build_user_resource_has_core_schema_only(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-123")
assert result.schemas == [SCIM_USER_SCHEMA]
def test_build_user_resource_with_groups(self) -> None:
provider = OktaProvider()
user = _make_mock_user()
@@ -161,6 +170,42 @@ class TestOktaProvider:
assert result.members == []
class TestEntraProvider:
def test_name(self) -> None:
assert EntraProvider().name == "entra"
def test_ignored_patch_paths(self) -> None:
paths = EntraProvider().ignored_patch_paths
assert paths == _ENTRA_IGNORED_PATCH_PATHS
# Enterprise extension URN is now handled (not ignored)
assert paths >= COMMON_IGNORED_PATCH_PATHS
def test_build_user_resource_includes_enterprise_schema(self) -> None:
provider = EntraProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-entra-1")
assert result.schemas == [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
def test_build_user_resource_basic(self) -> None:
provider = EntraProvider()
user = _make_mock_user()
result = provider.build_user_resource(user, "ext-entra-1")
assert result == ScimUserResource(
schemas=[SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA],
id=str(user.id),
externalId="ext-entra-1",
userName="test@example.com",
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
displayName="Test User",
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
active=True,
groups=[],
meta=ScimMeta(resourceType="User"),
)
class TestGetDefaultProvider:
def test_returns_okta(self) -> None:
provider = get_default_provider()

View File

@@ -16,7 +16,7 @@ from ee.onyx.server.scim.api import get_user
from ee.onyx.server.scim.api import list_users
from ee.onyx.server.scim.api import patch_user
from ee.onyx.server.scim.api import replace_user
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
@@ -28,6 +28,8 @@ from tests.unit.onyx.server.scim.conftest import assert_scim_error
from tests.unit.onyx.server.scim.conftest import make_db_user
from tests.unit.onyx.server.scim.conftest import make_scim_user
from tests.unit.onyx.server.scim.conftest import make_user_mapping
from tests.unit.onyx.server.scim.conftest import parse_scim_list
from tests.unit.onyx.server.scim.conftest import parse_scim_user
class TestListUsers:
@@ -51,9 +53,9 @@ class TestListUsers:
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 0
assert result.Resources == []
parsed = parse_scim_list(result)
assert parsed.totalResults == 0
assert parsed.Resources == []
def test_returns_users_with_scim_shape(
self,
@@ -77,10 +79,10 @@ class TestListUsers:
db_session=mock_db_session,
)
assert isinstance(result, ScimListResponse)
assert result.totalResults == 1
assert len(result.Resources) == 1
resource = result.Resources[0]
parsed = parse_scim_list(result)
assert parsed.totalResults == 1
assert len(parsed.Resources) == 1
resource = parsed.Resources[0]
assert isinstance(resource, ScimUserResource)
assert resource.userName == "Alice@example.com"
assert resource.externalId == "ext-abc"
@@ -146,9 +148,9 @@ class TestGetUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "alice@example.com"
assert result.id == str(user.id)
resource = parse_scim_user(result)
assert resource.userName == "alice@example.com"
assert resource.id == str(user.id)
def test_invalid_uuid_returns_404(
self,
@@ -207,8 +209,8 @@ class TestCreateUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "new@example.com"
resource = parse_scim_user(result, status=201)
assert resource.userName == "new@example.com"
mock_dal.add_user.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -314,8 +316,8 @@ class TestCreateUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.externalId == "ext-123"
resource = parse_scim_user(result, status=201)
assert resource.externalId == "ext-123"
mock_dal.create_user_mapping.assert_called_once()
@@ -344,7 +346,7 @@ class TestReplaceUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
parse_scim_user(result)
mock_dal.update_user.assert_called_once()
mock_dal.commit.assert_called_once()
@@ -412,9 +414,15 @@ class TestReplaceUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
parse_scim_user(result)
mock_dal.sync_user_external_id.assert_called_once_with(
user.id, None, scim_username="test@example.com"
user.id,
None,
scim_username="test@example.com",
fields=ScimMappingFields(
given_name="Test",
family_name="User",
),
)
@@ -448,7 +456,7 @@ class TestPatchUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
parse_scim_user(result)
mock_dal.update_user.assert_called_once()
def test_not_found_returns_404(
@@ -507,7 +515,7 @@ class TestPatchUser:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
parse_scim_user(result)
# Verify the update_user call received the new display name
call_kwargs = mock_dal.update_user.call_args
assert call_kwargs[1]["personal_name"] == "New Display Name"
@@ -605,10 +613,12 @@ class TestDeleteUser:
class TestScimNameToStr:
"""Tests for _scim_name_to_str helper."""
def test_prefers_given_family_over_formatted(self) -> None:
"""Okta may send stale formatted while updating givenName/familyName."""
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
assert _scim_name_to_str(name) == "Jane Smith"
def test_prefers_formatted_over_components(self) -> None:
"""When client provides formatted, use it — the client knows what it wants."""
name = ScimName(
givenName="Jane", familyName="Smith", formatted="Dr. Jane Smith"
)
assert _scim_name_to_str(name) == "Dr. Jane Smith"
def test_given_name_only(self) -> None:
name = ScimName(givenName="Jane")
@@ -653,9 +663,9 @@ class TestEmailCasePreservation:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "Alice@Example.COM"
assert result.emails[0].value == "Alice@Example.COM"
resource = parse_scim_user(result, status=201)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"
def test_get_preserves_username_case(
self,
@@ -681,6 +691,6 @@ class TestEmailCasePreservation:
db_session=mock_db_session,
)
assert isinstance(result, ScimUserResource)
assert result.userName == "Alice@Example.COM"
assert result.emails[0].value == "Alice@Example.COM"
resource = parse_scim_user(result)
assert resource.userName == "Alice@Example.COM"
assert resource.emails[0].value == "Alice@Example.COM"

Some files were not shown because too many files have changed in this diff Show More