mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 04:05:48 +00:00
Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d04128b8b1 | ||
|
|
bbebdf8f78 | ||
|
|
161279a2d5 | ||
|
|
e5ebb45a20 | ||
|
|
320ba9cb1b | ||
|
|
f2e8cb3114 | ||
|
|
43054a28ec | ||
|
|
dc74aa7b1f | ||
|
|
bd773191c2 | ||
|
|
66dbff41e6 | ||
|
|
1dcffe38bc | ||
|
|
c35e883564 | ||
|
|
fefcd58481 | ||
|
|
bdc89d9e3f | ||
|
|
f4d777b80d | ||
|
|
da4d57b5e3 | ||
|
|
dcdcd067bd | ||
|
|
8b15a29723 | ||
|
|
763853674f | ||
|
|
429b6f3465 | ||
|
|
37d5be1b40 | ||
|
|
8ab99dbb06 | ||
|
|
52799e9c7a | ||
|
|
aef009cc97 | ||
|
|
18d1ea1770 | ||
|
|
f336ad00f4 | ||
|
|
0558e687d9 | ||
|
|
784a99e24a | ||
|
|
da1f5a11f4 | ||
|
|
5633805890 | ||
|
|
0817b45ae1 | ||
|
|
af0e4bdebc | ||
|
|
4cd2320732 | ||
|
|
90a361f0e1 | ||
|
|
194efde97b | ||
|
|
d922a42262 | ||
|
|
f00c3a486e | ||
|
|
192080c9e4 |
73
.github/actions/build-backend-image/action.yml
vendored
Normal file
73
.github/actions/build-backend-image/action.yml
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
name: "Build Backend Image"
|
||||
description: "Builds and pushes the backend Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
docker-no-cache:
|
||||
description: "Set to 'true' to disable docker build cache"
|
||||
required: false
|
||||
default: "false"
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
|
||||
no-cache: ${{ inputs.docker-no-cache == 'true' }}
|
||||
75
.github/actions/build-integration-image/action.yml
vendored
Normal file
75
.github/actions/build-integration-image/action.yml
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
name: "Build Integration Image"
|
||||
description: "Builds and pushes the integration test image with docker bake"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
68
.github/actions/build-model-server-image/action.yml
vendored
Normal file
68
.github/actions/build-model-server-image/action.yml
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
name: "Build Model Server Image"
|
||||
description: "Builds and pushes the model server Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max
|
||||
120
.github/actions/run-nightly-provider-chat-test/action.yml
vendored
Normal file
120
.github/actions/run-nightly-provider-chat-test/action.yml
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
name: "Run Nightly Provider Chat Test"
|
||||
description: "Starts required compose services and runs nightly provider integration test"
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
|
||||
required: true
|
||||
models:
|
||||
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
api-base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
default: ""
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in image tags"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
shell: bash
|
||||
env:
|
||||
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
run: |
|
||||
cat <<EOF2 > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
|
||||
- name: Start Docker containers
|
||||
shell: bash
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server
|
||||
|
||||
- name: Run nightly provider integration test
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
env:
|
||||
MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py
|
||||
44
.github/workflows/nightly-llm-provider-chat-openai.yml
vendored
Normal file
44
.github/workflows/nightly-llm-provider-chat-openai.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -11,6 +11,11 @@ permissions:
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
@@ -75,10 +80,82 @@ jobs:
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
id: run_cherry_pick
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
continue-on-error: true
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
set -o pipefail
|
||||
output_file="$(mktemp)"
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
|
||||
exit_code="${PIPESTATUS[0]}"
|
||||
|
||||
if [ "${exit_code}" -eq 0 ]; then
|
||||
echo "status=success" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "status=failure" >> "$GITHUB_OUTPUT"
|
||||
|
||||
reason="command-failed"
|
||||
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
|
||||
reason="merge-conflict"
|
||||
fi
|
||||
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
{
|
||||
echo "details<<EOF"
|
||||
tail -n 40 "$output_file"
|
||||
echo "EOF"
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Mark workflow as failed if cherry-pick failed
|
||||
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
|
||||
run: |
|
||||
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
|
||||
exit 1
|
||||
|
||||
notify-slack-on-cherry-pick-failure:
|
||||
needs:
|
||||
- cherry-pick-to-latest-release
|
||||
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build cherry-pick failure summary
|
||||
id: failure-summary
|
||||
env:
|
||||
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
|
||||
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
|
||||
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
|
||||
reason_text="cherry-pick command failed"
|
||||
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
|
||||
reason_text="merge conflict during cherry-pick"
|
||||
fi
|
||||
|
||||
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
|
||||
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
|
||||
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick failure
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
@@ -116,7 +116,6 @@ jobs:
|
||||
run: |
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
|
||||
CODE_INTERPRETER_BETA_ENABLED=true
|
||||
DISABLE_TELEMETRY=true
|
||||
OPENSEARCH_FOR_ONYX_ENABLED=true
|
||||
EOF
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -20,6 +20,7 @@ env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
@@ -423,6 +424,7 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
@@ -443,6 +445,7 @@ jobs:
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
@@ -701,6 +704,7 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
|
||||
206
.github/workflows/reusable-nightly-llm-provider-chat.yml
vendored
Normal file
206
.github/workflows/reusable-nightly-llm-provider-chat.yml
vendored
Normal file
@@ -0,0 +1,206 @@
|
||||
name: Reusable Nightly LLM Provider Chat Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
type: string
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
type: string
|
||||
strict:
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose down -v
|
||||
@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
@@ -616,3 +616,9 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
to the codebase can be found at `contributing_guides/best_practices.md`.
|
||||
Understand its contents and follow them.
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add needs_persona_sync to user_file
|
||||
|
||||
Revision ID: 8ffcc2bcfc11
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-23 10:48:48.343826
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ffcc2bcfc11"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_persona_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_file", "needs_persona_sync")
|
||||
@@ -34,6 +34,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
@@ -128,12 +129,19 @@ class ScimDAL(DAL):
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
department=f.department,
|
||||
manager=f.manager,
|
||||
given_name=f.given_name,
|
||||
family_name=f.family_name,
|
||||
scim_emails_json=f.scim_emails_json,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
@@ -311,8 +319,14 @@ class ScimDAL(DAL):
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
@@ -320,11 +334,18 @@ class ScimDAL(DAL):
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
@@ -598,8 +597,12 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
# Prefer server-relative URL to avoid OData filters that break on apostrophes
|
||||
server_relative_url = unquote(urlparse(site_url).path)
|
||||
# Keep percent-encoding intact so the path matches the encoding
|
||||
# used by the Office365 library's SPResPath.create_relative(),
|
||||
# which compares against urlparse(context.base_url).path.
|
||||
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
|
||||
# the site prefix in the constructed URL.
|
||||
server_relative_url = urlparse(site_url).path
|
||||
file_obj = client_context.web.get_file_by_server_relative_url(
|
||||
server_relative_url
|
||||
)
|
||||
|
||||
@@ -26,14 +26,14 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
@@ -41,6 +41,8 @@ from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.base import serialize_emails
|
||||
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -48,15 +50,28 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
|
||||
media_type = "application/scim+json"
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
|
||||
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
@@ -86,15 +101,39 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
def get_resource_types() -> ScimJSONResponse:
|
||||
"""List available SCIM resource types (RFC 7643 §6).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(resources),
|
||||
"Resources": [
|
||||
r.model_dump(exclude_none=True, by_alias=True) for r in resources
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
def get_schemas() -> ScimJSONResponse:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(schemas),
|
||||
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -102,15 +141,45 @@ def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return JSONResponse(
|
||||
return ScimJSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _parse_excluded_attributes(raw: str | None) -> set[str]:
|
||||
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
|
||||
|
||||
Returns a set of lowercased attribute names to omit from responses.
|
||||
"""
|
||||
if not raw:
|
||||
return set()
|
||||
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
|
||||
|
||||
|
||||
def _apply_exclusions(
|
||||
resource: ScimUserResource | ScimGroupResource,
|
||||
excluded: set[str],
|
||||
) -> dict:
|
||||
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
|
||||
|
||||
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
|
||||
to reduce response payload size. We strip those fields after serialization
|
||||
so the rest of the pipeline doesn't need to know about them.
|
||||
"""
|
||||
data = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
for attr in excluded:
|
||||
# Match case-insensitively against the camelCase field names
|
||||
keys_to_remove = [k for k in data if k.lower() == attr]
|
||||
for k in keys_to_remove:
|
||||
del data[k]
|
||||
return data
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -124,7 +193,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
@@ -144,10 +213,95 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
# If the client explicitly provides ``formatted``, prefer it — the client
|
||||
# knows what display string it wants. Otherwise build from components.
|
||||
if name.formatted:
|
||||
return name.formatted
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or name.formatted
|
||||
return parts or None
|
||||
|
||||
|
||||
def _scim_resource_response(
|
||||
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
|
||||
status_code: int = 200,
|
||||
) -> ScimJSONResponse:
|
||||
"""Serialize a SCIM resource as ``application/scim+json``."""
|
||||
content = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
return ScimJSONResponse(
|
||||
status_code=status_code,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_list_response(
|
||||
resources: list[ScimUserResource | ScimGroupResource],
|
||||
total: int,
|
||||
start_index: int,
|
||||
count: int,
|
||||
excluded: set[str] | None = None,
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""Build a SCIM list response, optionally applying attribute exclusions.
|
||||
|
||||
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
|
||||
the ``excludedAttributes`` query parameter.
|
||||
"""
|
||||
if excluded:
|
||||
envelope = ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
)
|
||||
data = envelope.model_dump(exclude_none=True)
|
||||
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
|
||||
return ScimJSONResponse(content=data)
|
||||
|
||||
return _scim_resource_response(
|
||||
ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _extract_enterprise_fields(
|
||||
resource: ScimUserResource,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract department and manager from enterprise extension."""
|
||||
ext = resource.enterprise_extension
|
||||
if not ext:
|
||||
return None, None
|
||||
department = ext.department
|
||||
manager = ext.manager.value if ext.manager else None
|
||||
return department, manager
|
||||
|
||||
|
||||
def _mapping_to_fields(
|
||||
mapping: ScimUserMapping | None,
|
||||
) -> ScimMappingFields | None:
|
||||
"""Extract round-trip fields from a SCIM user mapping."""
|
||||
if not mapping:
|
||||
return None
|
||||
return ScimMappingFields(
|
||||
department=mapping.department,
|
||||
manager=mapping.manager,
|
||||
given_name=mapping.given_name,
|
||||
family_name=mapping.family_name,
|
||||
scim_emails_json=mapping.scim_emails_json,
|
||||
)
|
||||
|
||||
|
||||
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
"""Build mapping fields from an incoming SCIM user resource."""
|
||||
department, manager = _extract_enterprise_fields(resource)
|
||||
return ScimMappingFields(
|
||||
department=department,
|
||||
manager=manager,
|
||||
given_name=resource.name.givenName if resource.name else None,
|
||||
family_name=resource.name.familyName if resource.name else None,
|
||||
scim_emails_json=serialize_emails(resource.emails),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -158,12 +312,13 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -185,42 +340,54 @@ def list_users(
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return provider.build_user_resource(
|
||||
|
||||
resource = provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
@@ -228,7 +395,7 @@ def create_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -270,13 +437,25 @@ def create_user(
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -286,13 +465,13 @@ def replace_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -313,15 +492,24 @@ def replace_user(
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
new_external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -332,7 +520,7 @@ def patch_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
@@ -342,23 +530,25 @@ def patch_user(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
current_fields = _mapping_to_fields(mapping)
|
||||
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
fields=current_fields,
|
||||
)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(
|
||||
patched, ent_data = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
@@ -393,17 +583,37 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
department=ent_data.get("department", cf.department),
|
||||
manager=ent_data.get("manager", cf.manager),
|
||||
given_name=patched.name.givenName if patched.name else cf.given_name,
|
||||
family_name=patched.name.familyName if patched.name else cf.family_name,
|
||||
scim_emails_json=(
|
||||
serialize_emails(patched.emails)
|
||||
if patched.emails is not None
|
||||
else cf.scim_emails_json
|
||||
),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
user.id,
|
||||
patched.externalId,
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -412,25 +622,29 @@ def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
) -> Response | ScimJSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
A second DELETE returns 404 per RFC 7644 §3.6.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
dal.deactivate_user(user)
|
||||
|
||||
# If no SCIM mapping exists, the user was already deleted from
|
||||
# SCIM's perspective — return 404 per RFC 7644 §3.6.
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
if not mapping:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
|
||||
dal.deactivate_user(user)
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
@@ -442,7 +656,7 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
@@ -497,12 +711,13 @@ def _validate_and_parse_members(
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -522,37 +737,46 @@ def list_groups(
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return provider.build_group_resource(
|
||||
resource = provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
@@ -560,7 +784,7 @@ def create_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -596,7 +820,10 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(db_group, members, external_id),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -606,13 +833,13 @@ def replace_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -627,7 +854,9 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, group_resource.externalId)
|
||||
)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -637,7 +866,7 @@ def patch_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
@@ -646,7 +875,7 @@ def patch_group(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -685,7 +914,9 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, patched.externalId)
|
||||
)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
@@ -693,13 +924,13 @@ def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | JSONResponse:
|
||||
) -> Response | ScimJSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
|
||||
@@ -7,12 +7,14 @@ SCIM protocol schemas follow the wire format defined in:
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -31,6 +33,9 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -70,6 +75,36 @@ class ScimUserGroupRef(BaseModel):
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimManagerRef(BaseModel):
|
||||
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
|
||||
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class ScimEnterpriseExtension(BaseModel):
|
||||
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
|
||||
|
||||
department: str | None = None
|
||||
manager: ScimManagerRef | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScimMappingFields:
|
||||
"""Stored SCIM mapping fields that need to round-trip through the IdP.
|
||||
|
||||
Entra ID sends structured name components, email metadata, and enterprise
|
||||
extension attributes that must be returned verbatim in subsequent GET
|
||||
responses. These fields are persisted on ScimUserMapping and threaded
|
||||
through the DAL, provider, and endpoint layers.
|
||||
"""
|
||||
|
||||
department: str | None = None
|
||||
manager: str | None = None
|
||||
given_name: str | None = None
|
||||
family_name: str | None = None
|
||||
scim_emails_json: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -78,6 +113,8 @@ class ScimUserResource(BaseModel):
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
@@ -88,6 +125,10 @@ class ScimUserResource(BaseModel):
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
enterprise_extension: ScimEnterpriseExtension | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
@@ -165,6 +206,19 @@ class ScimPatchOperation(BaseModel):
|
||||
path: str | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
@field_validator("op", mode="before")
|
||||
@classmethod
|
||||
def normalize_operation(cls, v: object) -> object:
|
||||
"""Normalize op to lowercase for case-insensitive matching.
|
||||
|
||||
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
|
||||
instead of ``"replace"``. This is safe for all providers since the
|
||||
enum values are lowercase. If a future provider requires other
|
||||
pre-processing quirks, move patch deserialization into the provider
|
||||
subclass instead of adding more special cases here.
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
@@ -14,8 +14,13 @@ responsible for persisting changes.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
@@ -24,6 +29,55 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lowercased enterprise extension URN for case-insensitive matching
|
||||
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
|
||||
|
||||
# Pattern for email filter paths, e.g.:
|
||||
# emails[primary eq true].value (Okta)
|
||||
# emails[type eq "work"].value (Azure AD / Entra ID)
|
||||
_EMAIL_FILTER_RE = re.compile(
|
||||
r"^emails\[.+\]\.value$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch tables for user PATCH paths
|
||||
#
|
||||
# Maps lowercased SCIM path → (camelCase key, target dict name).
|
||||
# "data" writes to the top-level resource dict, "name" writes to the
|
||||
# name sub-object dict. This replaces the elif chains for simple fields.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"active": ("active", "data"),
|
||||
"username": ("userName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
}
|
||||
|
||||
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
"displayname": ("displayName", "data"),
|
||||
}
|
||||
|
||||
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"displayname": ("displayName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
}
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
@@ -34,18 +88,25 @@ class ScimPatchError(Exception):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
@dataclass
|
||||
class _UserPatchCtx:
|
||||
"""Bundles the mutable state for user PATCH operations."""
|
||||
|
||||
data: dict[str, Any]
|
||||
name_data: dict[str, Any]
|
||||
ent_data: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> ScimUserResource:
|
||||
) -> tuple[ScimUserResource, dict[str, str | None]]:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
@@ -53,79 +114,185 @@ def apply_user_patch(
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
Returns:
|
||||
A tuple of (modified user resource, enterprise extension data dict).
|
||||
The enterprise dict has keys ``"department"`` and ``"manager"``
|
||||
with values set only when a PATCH operation touched them.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
name_data = data.get("name") or {}
|
||||
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
|
||||
_apply_user_replace(op, ctx, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_user_remove(op, ctx, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
ctx.data["name"] = ctx.name_data
|
||||
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
# No path — value is a resource dict of top-level attributes to set.
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
_set_user_field(path, op.value, ctx, ignored_paths)
|
||||
|
||||
|
||||
def _apply_user_remove(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a remove operation to user data — clears the target field."""
|
||||
path = (op.path or "").lower()
|
||||
if not path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _USER_REMOVE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = None
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
"""Set a single field on user data by SCIM path.
|
||||
|
||||
Args:
|
||||
strict: When ``False`` (path-less replace), unknown attributes are
|
||||
silently skipped. When ``True`` (explicit path), they raise.
|
||||
"""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
|
||||
# Simple field writes handled by the dispatch table
|
||||
entry = _USER_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = value
|
||||
return
|
||||
|
||||
# displayName sets both the top-level field and the name.formatted sub-field
|
||||
if path == "displayname":
|
||||
ctx.data["displayName"] = value
|
||||
ctx.name_data["formatted"] = value
|
||||
elif path == "name":
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
ctx.name_data[k] = v
|
||||
elif path == "emails":
|
||||
if isinstance(value, list):
|
||||
ctx.data["emails"] = value
|
||||
elif _EMAIL_FILTER_RE.match(path):
|
||||
_update_primary_email(ctx.data, value)
|
||||
elif path.startswith(_ENTERPRISE_URN_LOWER):
|
||||
_set_enterprise_field(path, value, ctx.ent_data)
|
||||
elif not strict:
|
||||
return
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
|
||||
"""Update the primary email entry via an email filter path."""
|
||||
emails: list[dict] = data.get("emails") or []
|
||||
for email_entry in emails:
|
||||
if email_entry.get("primary"):
|
||||
email_entry["value"] = value
|
||||
break
|
||||
else:
|
||||
emails.append({"value": value, "type": "work", "primary": True})
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
def _to_dict(value: ScimPatchValue) -> dict | None:
|
||||
"""Coerce a SCIM patch value to a plain dict if possible.
|
||||
|
||||
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
|
||||
``extra="allow"``), so we also dump those back to a dict.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, ScimPatchResourceValue):
|
||||
return value.model_dump(exclude_unset=True)
|
||||
return None
|
||||
|
||||
|
||||
def _set_enterprise_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ent_data: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Handle enterprise extension URN paths or value dicts."""
|
||||
# Full URN as key with dict value (path-less PATCH)
|
||||
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
|
||||
if path == _ENTERPRISE_URN_LOWER:
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
if "department" in d:
|
||||
ent_data["department"] = d["department"]
|
||||
if "manager" in d:
|
||||
mgr = d["manager"]
|
||||
if isinstance(mgr, dict):
|
||||
ent_data["manager"] = mgr.get("value")
|
||||
return
|
||||
|
||||
# Dotted URN path, e.g. "urn:...:user:department"
|
||||
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
|
||||
if suffix == "department":
|
||||
ent_data["department"] = str(value) if value is not None else None
|
||||
elif suffix == "manager":
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
ent_data["manager"] = d.get("value")
|
||||
elif isinstance(value, str):
|
||||
ent_data["manager"] = value
|
||||
else:
|
||||
# Unknown enterprise attributes are silently ignored rather than
|
||||
# rejected — IdPs may send attributes we don't model yet.
|
||||
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
@@ -235,12 +402,14 @@ def _set_group_field(
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
entry = _GROUP_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, _ = entry
|
||||
data[key] = value
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
|
||||
@@ -2,13 +2,22 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
@@ -17,6 +26,17 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
"schemas",
|
||||
"meta",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
@@ -41,12 +61,22 @@ class ScimProvider(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
"""Schema URIs to include in User resource responses.
|
||||
|
||||
Override in subclasses to advertise additional schemas (e.g. the
|
||||
enterprise extension for Entra ID).
|
||||
"""
|
||||
return [SCIM_USER_SCHEMA]
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
@@ -58,27 +88,48 @@ class ScimProvider(ABC):
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
fields: Stored mapping fields that the IdP expects round-tripped.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
return ScimUserResource(
|
||||
# Build enterprise extension when at least one value is present.
|
||||
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
|
||||
enterprise_ext: ScimEnterpriseExtension | None = None
|
||||
schemas = list(self.user_schemas)
|
||||
if f.department is not None or f.manager is not None:
|
||||
manager_ref = (
|
||||
ScimManagerRef(value=f.manager) if f.manager is not None else None
|
||||
)
|
||||
enterprise_ext = ScimEnterpriseExtension(
|
||||
department=f.department,
|
||||
manager=manager_ref,
|
||||
)
|
||||
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
|
||||
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
|
||||
name = self.build_scim_name(user, f)
|
||||
emails = _deserialize_emails(f.scim_emails_json, username)
|
||||
|
||||
resource = ScimUserResource(
|
||||
schemas=schemas,
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=self._build_scim_name(user),
|
||||
name=name,
|
||||
displayName=user.personal_name,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
emails=emails,
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
resource.enterprise_extension = enterprise_ext
|
||||
return resource
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
@@ -98,9 +149,24 @@ class ScimProvider(ABC):
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
def build_scim_name(
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
@@ -111,6 +177,27 @@ class ScimProvider(ABC):
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
|
||||
"""Deserialize stored email entries or build a default work email."""
|
||||
if stored_json:
|
||||
try:
|
||||
entries = json.loads(stored_json)
|
||||
if isinstance(entries, list) and entries:
|
||||
return [ScimEmail(**e) for e in entries]
|
||||
except (json.JSONDecodeError, TypeError, ValidationError):
|
||||
logger.warning(
|
||||
"Corrupt scim_emails_json, falling back to default: %s", stored_json
|
||||
)
|
||||
return [ScimEmail(value=username, type="work", primary=True)]
|
||||
|
||||
|
||||
def serialize_emails(emails: list[ScimEmail]) -> str | None:
|
||||
"""Serialize SCIM email entries to JSON for storage."""
|
||||
if not emails:
|
||||
return None
|
||||
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
|
||||
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
36
backend/ee/onyx/server/scim/providers/entra.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Entra ID (Azure AD) SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
|
||||
class EntraProvider(ScimProvider):
|
||||
"""Entra ID (Azure AD) SCIM provider.
|
||||
|
||||
Entra behavioral notes:
|
||||
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
|
||||
— handled by ``ScimPatchOperation.normalize_op`` validator.
|
||||
- Sends the enterprise extension URN as a key in path-less PATCH value
|
||||
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
|
||||
store department/manager values.
|
||||
- Expects the enterprise extension schema in ``schemas`` arrays and
|
||||
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "entra"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return _ENTRA_IGNORED_PATCH_PATHS
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
@@ -22,4 +23,4 @@ class OktaProvider(ScimProvider):
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
return COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
@@ -4,6 +4,7 @@ Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -20,6 +21,9 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
"schemaExtensions": [
|
||||
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -104,6 +108,31 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
],
|
||||
)
|
||||
|
||||
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_ENTERPRISE_USER_SCHEMA,
|
||||
name="EnterpriseUser",
|
||||
description="Enterprise User extension (RFC 7643 §4.3)",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="department",
|
||||
type="string",
|
||||
description="Department.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="manager",
|
||||
type="complex",
|
||||
description="The user's manager.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Manager user ID.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
|
||||
@@ -58,6 +59,7 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
|
||||
METADATA_SUFFIX,
|
||||
DOCUMENT_SETS,
|
||||
USER_PROJECT,
|
||||
PERSONAS,
|
||||
PRIMARY_OWNERS,
|
||||
SECONDARY_OWNERS,
|
||||
ACCESS_CONTROL_LIST,
|
||||
@@ -276,6 +278,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
personas: list[int] | None = vespa_chunk.get(PERSONAS)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
@@ -325,6 +328,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
personas=personas,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
|
||||
@@ -12,6 +12,7 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
@@ -712,7 +713,10 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
sa.and_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
@@ -772,7 +776,11 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
@@ -800,13 +808,17 @@ def process_single_user_file_project_sync(
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
@@ -814,6 +826,7 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.needs_persona_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
|
||||
@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
SMTP_USER = os.environ.get("SMTP_USER") or ""
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
|
||||
@@ -367,7 +367,7 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# this assumes that other redis settings remain the same as the primary
|
||||
# this assumes that other redis settings remain the same as the primary
|
||||
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
@@ -16,6 +16,22 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
|
||||
|
||||
|
||||
def _is_rate_limit_error(error: HttpError) -> bool:
|
||||
"""Google sometimes returns rate-limit errors as 403 with reason
|
||||
'userRateLimitExceeded' instead of 429. This helper detects both."""
|
||||
if error.resp.status == 429:
|
||||
return True
|
||||
if error.resp.status != 403:
|
||||
return False
|
||||
error_details = getattr(error, "error_details", None) or []
|
||||
for detail in error_details:
|
||||
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
|
||||
return True
|
||||
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. This is now addressed by checkpointing.
|
||||
@@ -57,7 +73,7 @@ def _execute_with_retry(request: Any) -> Any:
|
||||
except HttpError as error:
|
||||
attempt += 1
|
||||
|
||||
if error.resp.status == 429:
|
||||
if _is_rate_limit_error(error):
|
||||
# Attempt to get 'Retry-After' from headers
|
||||
retry_after = error.resp.get("Retry-After")
|
||||
if retry_after:
|
||||
@@ -140,16 +156,16 @@ def _execute_single_retrieval(
|
||||
)
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
else:
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
@@ -147,7 +147,9 @@ class DriveItemData(BaseModel):
|
||||
self.id,
|
||||
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
|
||||
)
|
||||
return DriveItem(graph_client, path)
|
||||
item = DriveItem(graph_client, path)
|
||||
item.set_property("id", self.id)
|
||||
return item
|
||||
|
||||
|
||||
# The office365 library's ClientContext caches the access token from its
|
||||
|
||||
@@ -11,6 +11,7 @@ from dateutil import parser
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -258,3 +259,21 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Very basic validation, we could do more here
|
||||
"""
|
||||
if not self.base_url.startswith("https://") and not self.base_url.startswith(
|
||||
"http://"
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
"Base URL must start with https:// or http://"
|
||||
)
|
||||
|
||||
try:
|
||||
get_all_post_ids(self.slab_bot_token)
|
||||
except ConnectorMissingCredentialError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")
|
||||
|
||||
@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -40,6 +40,7 @@ def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
@@ -118,6 +119,7 @@ def _build_index_filters(
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -265,6 +267,8 @@ def search_pipeline(
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
@@ -299,6 +303,7 @@ def search_pipeline(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
|
||||
21
backend/onyx/db/code_interpreter.py
Normal file
21
backend/onyx/db/code_interpreter.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import CodeInterpreterServer
|
||||
|
||||
|
||||
def fetch_code_interpreter_server(
|
||||
db_session: Session,
|
||||
) -> CodeInterpreterServer:
|
||||
server = db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
return server
|
||||
|
||||
|
||||
def update_code_interpreter_server_enabled(
|
||||
db_session: Session,
|
||||
enabled: bool,
|
||||
) -> CodeInterpreterServer:
|
||||
server = db_session.scalars(select(CodeInterpreterServer)).one()
|
||||
server.server_enabled = enabled
|
||||
db_session.commit()
|
||||
return server
|
||||
@@ -4270,6 +4270,9 @@ class UserFile(Base):
|
||||
needs_project_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
needs_persona_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -765,6 +765,9 @@ def mark_persona_as_deleted(
|
||||
) -> None:
|
||||
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
|
||||
persona.deleted = True
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -776,11 +779,13 @@ def mark_persona_as_not_deleted(
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
|
||||
)
|
||||
if persona.deleted:
|
||||
persona.deleted = False
|
||||
db_session.commit()
|
||||
else:
|
||||
if not persona.deleted:
|
||||
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
|
||||
persona.deleted = False
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_delete_persona_by_name(
|
||||
@@ -846,6 +851,20 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
) -> None:
|
||||
"""Flag the given UserFile rows so the background sync task picks them up
|
||||
and updates their persona metadata in the vector DB."""
|
||||
if not user_file_ids:
|
||||
return
|
||||
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
|
||||
{UserFile.needs_persona_sync: True},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -1034,8 +1053,13 @@ def upsert_persona(
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
if user_file_ids is not None:
|
||||
old_file_ids = {uf.id for uf in existing_persona.user_files}
|
||||
new_file_ids = {uf.id for uf in (user_files or [])}
|
||||
affected_file_ids = old_file_ids | new_file_ids
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
|
||||
|
||||
if hierarchy_node_ids is not None:
|
||||
existing_persona.hierarchy_nodes.clear()
|
||||
@@ -1089,6 +1113,8 @@ def upsert_persona(
|
||||
attached_documents=attached_documents or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
if user_files:
|
||||
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
|
||||
persona = new_persona
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
@@ -2,6 +2,7 @@ import random
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from logging import getLogger
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import create_chat_session
|
||||
@@ -13,18 +14,26 @@ from onyx.db.models import ChatSession
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||
def seed_chat_history(
|
||||
num_sessions: int,
|
||||
num_messages: int,
|
||||
days: int,
|
||||
user_id: UUID | None = None,
|
||||
persona_id: int | None = None,
|
||||
) -> None:
|
||||
"""Utility function to seed chat history for testing.
|
||||
|
||||
num_sessions: the number of sessions to seed
|
||||
num_messages: the number of messages to seed per sessions
|
||||
days: the number of days looking backwards from the current time over which to randomize
|
||||
the times.
|
||||
user_id: optional user to associate with sessions
|
||||
persona_id: optional persona/assistant to associate with sessions
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.info(f"Seeding {num_sessions} sessions.")
|
||||
for y in range(0, num_sessions):
|
||||
create_chat_session(db_session, f"pytest_session_{y}", None, None)
|
||||
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
|
||||
|
||||
# randomize all session times
|
||||
logger.info(f"Seeding {num_messages} messages per session.")
|
||||
|
||||
@@ -3,6 +3,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import UserFile
|
||||
@@ -64,6 +65,23 @@ def fetch_user_project_ids_for_user_files(
|
||||
}
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch persona (assistant) ids for specified user files."""
|
||||
stmt = (
|
||||
select(UserFile)
|
||||
.where(UserFile.id.in_(user_file_ids))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
)
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [persona.id for persona in user_file.assistants]
|
||||
for user_file in results
|
||||
}
|
||||
|
||||
|
||||
def update_last_accessed_at_for_user_files(
|
||||
user_file_ids: list[UUID],
|
||||
db_session: Session,
|
||||
|
||||
@@ -121,6 +121,7 @@ class VespaDocumentUserFields:
|
||||
"""
|
||||
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -148,6 +148,7 @@ class MetadataUpdateRequest(BaseModel):
|
||||
hidden: bool | None = None
|
||||
secondary_index_updated: bool | None = None
|
||||
project_ids: set[int] | None = None
|
||||
persona_ids: set[int] | None = None
|
||||
|
||||
|
||||
class IndexRetrievalFilters(BaseModel):
|
||||
|
||||
@@ -50,6 +50,7 @@ from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.opensearch.search import (
|
||||
@@ -215,6 +216,7 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
# OpenSearch and it will not store any data at all for this field, which
|
||||
# is different from supplying an empty list.
|
||||
user_projects=chunk.user_project or None,
|
||||
personas=chunk.personas or None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
@@ -362,6 +364,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
if user_fields and user_fields.user_projects
|
||||
else None
|
||||
),
|
||||
persona_ids=(
|
||||
set(user_fields.personas)
|
||||
if user_fields and user_fields.personas
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -709,6 +716,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
|
||||
update_request.project_ids
|
||||
)
|
||||
if update_request.persona_ids is not None:
|
||||
properties_to_update[PERSONAS_FIELD_NAME] = list(
|
||||
update_request.persona_ids
|
||||
)
|
||||
|
||||
if not properties_to_update:
|
||||
if len(update_request.document_ids) > 1:
|
||||
|
||||
@@ -41,6 +41,7 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
USER_PROJECTS_FIELD_NAME = "user_projects"
|
||||
PERSONAS_FIELD_NAME = "personas"
|
||||
DOCUMENT_ID_FIELD_NAME = "document_id"
|
||||
CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
@@ -156,6 +157,7 @@ class DocumentChunk(BaseModel):
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
@@ -485,6 +487,7 @@ class DocumentSchema:
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
|
||||
PERSONAS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
@@ -144,6 +145,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -202,6 +204,7 @@ class DocumentQuery:
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -267,6 +270,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -334,6 +338,7 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -496,6 +501,7 @@ class DocumentQuery:
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -530,6 +536,8 @@ class DocumentQuery:
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -627,6 +635,9 @@ class DocumentQuery:
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
@@ -780,6 +791,9 @@ class DocumentQuery:
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
|
||||
@@ -181,6 +181,11 @@ schema {{ schema_name }} {
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field personas type array<int> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
|
||||
# If using different tokenization settings, the fieldset has to be removed, and the field must
|
||||
|
||||
@@ -689,6 +689,9 @@ class VespaIndex(DocumentIndex):
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
persona_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.personas is not None:
|
||||
persona_ids = set(user_fields.personas)
|
||||
update_request = MetadataUpdateRequest(
|
||||
document_ids=[doc_id],
|
||||
doc_id_to_chunk_cnt={
|
||||
@@ -699,6 +702,7 @@ class VespaIndex(DocumentIndex):
|
||||
boost=fields.boost if fields is not None else None,
|
||||
hidden=fields.hidden if fields is not None else None,
|
||||
project_ids=project_ids,
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.document_index.vespa_constants import METADATA
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
|
||||
@@ -218,6 +219,7 @@ def _index_vespa_chunk(
|
||||
# still called `image_file_name` in Vespa for backwards compatibility
|
||||
IMAGE_FILE_NAME: chunk.image_file_id,
|
||||
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
|
||||
PERSONAS: chunk.personas if chunk.personas is not None else [],
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
@@ -149,6 +150,18 @@ def build_vespa_filters(
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
) -> str:
|
||||
if persona_id is None:
|
||||
return ""
|
||||
try:
|
||||
pid = int(persona_id)
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -192,6 +205,9 @@ def build_vespa_filters(
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
|
||||
@@ -183,6 +183,10 @@ def _update_single_chunk(
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _Personas(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _VespaPutFields(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
# The names of these fields are based the Vespa schema. Changes to the
|
||||
@@ -193,6 +197,7 @@ def _update_single_chunk(
|
||||
access_control_list: _AccessControl | None = None
|
||||
hidden: _Hidden | None = None
|
||||
user_project: _UserProjects | None = None
|
||||
personas: _Personas | None = None
|
||||
|
||||
class _VespaPutRequest(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
@@ -227,6 +232,11 @@ def _update_single_chunk(
|
||||
if update_request.project_ids is not None
|
||||
else None
|
||||
)
|
||||
personas_update: _Personas | None = (
|
||||
_Personas(assign=list(update_request.persona_ids))
|
||||
if update_request.persona_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
vespa_put_fields = _VespaPutFields(
|
||||
boost=boost_update,
|
||||
@@ -234,6 +244,7 @@ def _update_single_chunk(
|
||||
access_control_list=access_update,
|
||||
hidden=hidden_update,
|
||||
user_project=user_projects_update,
|
||||
personas=personas_update,
|
||||
)
|
||||
|
||||
vespa_put_request = _VespaPutRequest(
|
||||
|
||||
@@ -58,6 +58,7 @@ DOCUMENT_SETS = "document_sets"
|
||||
USER_FILE = "user_file"
|
||||
USER_FOLDER = "user_folder"
|
||||
USER_PROJECT = "user_project"
|
||||
PERSONAS = "personas"
|
||||
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
|
||||
@@ -146,6 +146,7 @@ class DocumentIndexingBatchAdapter:
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
|
||||
@@ -20,6 +20,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
@@ -119,6 +120,10 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -182,7 +187,7 @@ class UserFileIndexingAdapter:
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
# we are going to index userfiles only once, so we just set the boost to the default
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
|
||||
@@ -112,6 +112,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
user_project: list[int]
|
||||
personas: list[int]
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
# Full ancestor path from root hierarchy node to document's parent.
|
||||
@@ -126,6 +127,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
user_project: list[int],
|
||||
personas: list[int],
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
@@ -137,6 +139,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -97,6 +97,9 @@ from onyx.server.features.web_search.api import router as web_search_router
|
||||
from onyx.server.federated.api import router as federated_router
|
||||
from onyx.server.kg.api import admin_router as kg_admin_router
|
||||
from onyx.server.manage.administrative import router as admin_router
|
||||
from onyx.server.manage.code_interpreter.api import (
|
||||
admin_router as code_interpreter_admin_router,
|
||||
)
|
||||
from onyx.server.manage.discord_bot.api import router as discord_bot_router
|
||||
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from onyx.server.manage.embedding.api import basic_router as embedding_router
|
||||
@@ -421,6 +424,9 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, llm_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, kg_admin_router)
|
||||
include_router_with_global_prefix_prepended(application, llm_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, code_interpreter_admin_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, image_generation_admin_router
|
||||
)
|
||||
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,20 +1,149 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(message)
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -23,7 +152,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
return text
|
||||
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n"
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
@@ -42,7 +171,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
@@ -64,7 +193,30 @@ class SlackRenderer(HTMLRenderer):
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code}\n```\n"
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
return f"{text}\n\n"
|
||||
|
||||
Binary file not shown.
@@ -1,15 +1,19 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate AGENTS.md by scanning the files directory and populating the template.
|
||||
|
||||
This script runs at container startup, AFTER the init container has synced files
|
||||
from S3. It scans the /workspace/files directory to discover what knowledge sources
|
||||
are available and generates appropriate documentation.
|
||||
This script runs during session setup, AFTER files have been synced from S3
|
||||
and the files symlink has been created. It reads an existing AGENTS.md (which
|
||||
contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the
|
||||
placeholder by scanning the knowledge source directory, and writes it back.
|
||||
|
||||
Environment variables:
|
||||
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
|
||||
Arguments:
|
||||
agents_md_path: Path to the AGENTS.md file to update in place
|
||||
files_path: Path to the files directory to scan for knowledge sources
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -189,49 +193,39 @@ def build_knowledge_sources_section(files_path: Path) -> str:
|
||||
def main() -> None:
|
||||
"""Main entry point for container startup script.
|
||||
|
||||
Is called by the container startup script to scan /workspace/files and populate
|
||||
the knowledge sources section.
|
||||
Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
|
||||
placeholder by scanning the files directory, and writes it back.
|
||||
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
"""
|
||||
# Read template from environment variable
|
||||
template = os.environ.get("AGENT_INSTRUCTIONS", "")
|
||||
if not template:
|
||||
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
|
||||
template = "# Agent Instructions\n\nNo instructions provided."
|
||||
if len(sys.argv) != 3:
|
||||
print(
|
||||
f"Usage: {sys.argv[0]} <agents_md_path> <files_path>",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Scan files directory - check /workspace/files first, then /workspace/demo_data
|
||||
files_path = Path("/workspace/files")
|
||||
demo_data_path = Path("/workspace/demo_data")
|
||||
agents_md_path = Path(sys.argv[1])
|
||||
files_path = Path(sys.argv[2])
|
||||
|
||||
# Use demo_data if files doesn't exist or is empty
|
||||
if not files_path.exists() or not any(files_path.iterdir()):
|
||||
if demo_data_path.exists():
|
||||
files_path = demo_data_path
|
||||
if not agents_md_path.exists():
|
||||
print(f"Error: {agents_md_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
knowledge_sources_section = build_knowledge_sources_section(files_path)
|
||||
template = agents_md_path.read_text()
|
||||
|
||||
# Replace placeholders
|
||||
content = template
|
||||
content = content.replace(
|
||||
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
|
||||
resolved_files_path = files_path.resolve()
|
||||
|
||||
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
|
||||
|
||||
# Replace placeholder and write back
|
||||
content = template.replace(
|
||||
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
|
||||
)
|
||||
|
||||
# Write AGENTS.md
|
||||
output_path = Path("/workspace/AGENTS.md")
|
||||
output_path.write_text(content)
|
||||
|
||||
# Log result
|
||||
source_count = 0
|
||||
if files_path.exists():
|
||||
source_count = len(
|
||||
[
|
||||
d
|
||||
for d in files_path.iterdir()
|
||||
if d.is_dir() and not d.name.startswith(".")
|
||||
]
|
||||
)
|
||||
print(
|
||||
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
|
||||
)
|
||||
agents_md_path.write_text(content)
|
||||
print(f"Populated knowledge sources in {agents_md_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1352,6 +1352,9 @@ fi
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
@@ -1780,6 +1783,9 @@ ln -sf {symlink_target} {session_path}/files
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
|
||||
47
backend/onyx/server/manage/code_interpreter/api.py
Normal file
47
backend/onyx/server/manage/code_interpreter/api.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServer
|
||||
from onyx.server.manage.code_interpreter.models import CodeInterpreterServerHealth
|
||||
from onyx.tools.tool_implementations.python.code_interpreter_client import (
|
||||
CodeInterpreterClient,
|
||||
)
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/code-interpreter")
|
||||
|
||||
|
||||
@admin_router.get("/health")
|
||||
def get_code_interpreter_health(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> CodeInterpreterServerHealth:
|
||||
try:
|
||||
client = CodeInterpreterClient()
|
||||
return CodeInterpreterServerHealth(healthy=client.health())
|
||||
except ValueError:
|
||||
return CodeInterpreterServerHealth(healthy=False)
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def get_code_interpreter(
|
||||
_: User = Depends(current_admin_user), db_session: Session = Depends(get_session)
|
||||
) -> CodeInterpreterServer:
|
||||
ci_server = fetch_code_interpreter_server(db_session)
|
||||
return CodeInterpreterServer(enabled=ci_server.server_enabled)
|
||||
|
||||
|
||||
@admin_router.put("")
|
||||
def update_code_interpreter(
|
||||
update: CodeInterpreterServer,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_code_interpreter_server_enabled(
|
||||
db_session=db_session,
|
||||
enabled=update.enabled,
|
||||
)
|
||||
9
backend/onyx/server/manage/code_interpreter/models.py
Normal file
9
backend/onyx/server/manage/code_interpreter/models.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterServer(BaseModel):
|
||||
enabled: bool
|
||||
|
||||
|
||||
class CodeInterpreterServerHealth(BaseModel):
|
||||
healthy: bool
|
||||
@@ -35,6 +35,18 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class EmailInviteStatus(str, Enum):
|
||||
SENT = "SENT"
|
||||
NOT_CONFIGURED = "NOT_CONFIGURED"
|
||||
SEND_FAILED = "SEND_FAILED"
|
||||
DISABLED = "DISABLED"
|
||||
|
||||
|
||||
class BulkInviteResponse(BaseModel):
|
||||
invited_count: int
|
||||
email_invite_status: EmailInviteStatus
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
backend_version: str
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
@@ -78,8 +79,10 @@ from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.features.projects.models import UserFileSnapshot
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import BulkInviteResponse
|
||||
from onyx.server.manage.models import ChatBackgroundRequest
|
||||
from onyx.server.manage.models import DefaultAppModeRequest
|
||||
from onyx.server.manage.models import EmailInviteStatus
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import PersonalizationUpdateRequest
|
||||
from onyx.server.manage.models import TenantInfo
|
||||
@@ -368,7 +371,7 @@ def bulk_invite_users(
|
||||
emails: list[str] = Body(..., embed=True),
|
||||
current_user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> int:
|
||||
) -> BulkInviteResponse:
|
||||
"""emails are string validated. If any email fails validation, no emails are
|
||||
invited and an exception is raised."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -427,34 +430,41 @@ def bulk_invite_users(
|
||||
number_of_invited_users = write_invited_users(all_emails)
|
||||
|
||||
# send out email invitations only to new users (not already invited or existing)
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
if not ENABLE_EMAIL_INVITES:
|
||||
email_invite_status = EmailInviteStatus.DISABLED
|
||||
elif not EMAIL_CONFIGURED:
|
||||
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
|
||||
else:
|
||||
try:
|
||||
for email in emails_needing_seats:
|
||||
send_user_email_invite(email, current_user, AUTH_TYPE)
|
||||
email_invite_status = EmailInviteStatus.SENT
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
email_invite_status = EmailInviteStatus.SEND_FAILED
|
||||
|
||||
if not MULTI_TENANT or DEV_MODE:
|
||||
return number_of_invited_users
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register tenant users: {str(e)}")
|
||||
logger.info(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
raise e
|
||||
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
|
||||
return number_of_invited_users
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register tenant users: {str(e)}")
|
||||
logger.info(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
raise e
|
||||
return BulkInviteResponse(
|
||||
invited_count=number_of_invited_users,
|
||||
email_invite_status=email_invite_status,
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -54,6 +54,7 @@ logger = setup_logger()
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -180,6 +181,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -427,6 +429,7 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -98,6 +98,17 @@ class CodeInterpreterClient:
|
||||
payload["files"] = files
|
||||
return payload
|
||||
|
||||
def health(self) -> bool:
|
||||
"""Check if the Code Interpreter service is healthy"""
|
||||
url = f"{self.base_url}/health"
|
||||
try:
|
||||
response = self.session.get(url, timeout=5)
|
||||
response.raise_for_status()
|
||||
return response.json().get("status") == "ok"
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception caught when checking health, e={e}")
|
||||
return False
|
||||
|
||||
def execute(
|
||||
self,
|
||||
code: str,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -103,8 +104,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
is_available = bool(CODE_INTERPRETER_BASE_URL)
|
||||
return is_available
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
return server.server_enabled
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -247,6 +247,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -261,6 +263,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -456,6 +459,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -95,6 +95,7 @@ def generate_dummy_chunk(
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
user_project=[],
|
||||
personas=[],
|
||||
access=DocumentAccess.build(
|
||||
user_emails=user_emails,
|
||||
user_groups=user_groups,
|
||||
|
||||
@@ -3,8 +3,8 @@ set -e
|
||||
|
||||
cleanup() {
|
||||
echo "Error occurred. Cleaning up..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Trap errors and output a message, then cleanup
|
||||
@@ -20,8 +20,8 @@ MINIO_VOLUME=${4:-""} # Default is empty if not provided
|
||||
|
||||
# Stop and remove the existing containers
|
||||
echo "Stopping and removing existing containers..."
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio 2>/dev/null || true
|
||||
docker stop onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
docker rm onyx_postgres onyx_vespa onyx_redis onyx_minio onyx_code_interpreter 2>/dev/null || true
|
||||
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
@@ -55,6 +55,10 @@ else
|
||||
docker run --detach --name onyx_minio --publish 9004:9000 --publish 9005:9001 -e MINIO_ROOT_USER=minioadmin -e MINIO_ROOT_PASSWORD=minioadmin minio/minio server /data --console-address ":9001"
|
||||
fi
|
||||
|
||||
# Start the Code Interpreter container
|
||||
echo "Starting Code Interpreter container..."
|
||||
docker run --detach --name onyx_code_interpreter --publish 8000:8000 --user root -v /var/run/docker.sock:/var/run/docker.sock onyxdotapp/code-interpreter:latest bash ./entrypoint.sh code-interpreter-api
|
||||
|
||||
# Ensure alembic runs in the correct directory (backend/)
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
PARENT_DIR="$(dirname "$SCRIPT_DIR")"
|
||||
|
||||
@@ -144,7 +144,8 @@ def use_mock_search_pipeline(
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
|
||||
persona_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
prefetched_federated_retrieval_infos: ( # noqa: ARG001
|
||||
|
||||
@@ -38,6 +38,7 @@ def _get_search_filters(
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Tests that PythonTool.is_available() respects the server_enabled DB flag.
|
||||
|
||||
Uses a real DB session with CODE_INTERPRETER_BASE_URL mocked so the
|
||||
environment-variable check passes and the DB flag is the deciding factor.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""With a valid base URL, the tool should be unavailable when
|
||||
server_enabled is False in the DB."""
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
initial_enabled = server.server_enabled
|
||||
|
||||
try:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=False)
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://fake:8888",
|
||||
):
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
finally:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
|
||||
|
||||
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""With a valid base URL, the tool should be available when
|
||||
server_enabled is True in the DB."""
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
initial_enabled = server.server_enabled
|
||||
|
||||
try:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=True)
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://fake:8888",
|
||||
):
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
finally:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
|
||||
@@ -38,5 +38,5 @@ COPY --from=openapi-client /local/onyx_openapi_client /app/generated/onyx_openap
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
ENTRYPOINT ["pytest", "-s"]
|
||||
ENTRYPOINT ["pytest", "-s", "-rs"]
|
||||
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
@@ -8,8 +9,10 @@ from requests.models import CaseInsensitiveDict
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -69,9 +72,42 @@ class QueryHistoryManager:
|
||||
if end_time:
|
||||
query_params["end"] = end_time.isoformat()
|
||||
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
|
||||
start_response = requests.post(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/start-export?{urlencode(query_params, doseq=True)}",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.headers, response.content.decode()
|
||||
start_response.raise_for_status()
|
||||
request_id = start_response.json()["request_id"]
|
||||
|
||||
deadline = time.time() + MAX_DELAY
|
||||
while time.time() < deadline:
|
||||
status_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/export-status",
|
||||
params={"request_id": request_id},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
status_response.raise_for_status()
|
||||
status = status_response.json()["status"]
|
||||
if status == TaskStatus.SUCCESS:
|
||||
break
|
||||
if status == TaskStatus.FAILURE:
|
||||
raise RuntimeError("Query history export task failed")
|
||||
time.sleep(2)
|
||||
else:
|
||||
raise TimeoutError(
|
||||
f"Query history export not completed within {MAX_DELAY} seconds"
|
||||
)
|
||||
|
||||
download_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/download",
|
||||
params={"request_id": request_id},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
download_response.raise_for_status()
|
||||
|
||||
if not download_response.content:
|
||||
raise RuntimeError(
|
||||
"Query history CSV download returned zero-length content"
|
||||
)
|
||||
|
||||
return download_response.headers, download_response.content.decode()
|
||||
|
||||
@@ -6,16 +6,26 @@ import pytest
|
||||
from onyx.connectors.slack.models import ChannelType
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
# from tests.load_env_vars import load_env_vars
|
||||
|
||||
# load_env_vars()
|
||||
SLACK_ADMIN_EMAIL = os.environ.get("SLACK_ADMIN_EMAIL", "evan@onyx.app")
|
||||
SLACK_TEST_USER_1_EMAIL = os.environ.get("SLACK_TEST_USER_1_EMAIL", "evan+1@onyx.app")
|
||||
SLACK_TEST_USER_2_EMAIL = os.environ.get("SLACK_TEST_USER_2_EMAIL", "justin@onyx.app")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
def _provision_slack_channels(
|
||||
bot_token: str,
|
||||
) -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
|
||||
auth_info = slack_client.auth_test()
|
||||
print(f"\nSlack workspace: {auth_info.get('team')} ({auth_info.get('url')})")
|
||||
|
||||
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = user_map["admin@example.com"]
|
||||
if SLACK_ADMIN_EMAIL not in user_map:
|
||||
raise KeyError(
|
||||
f"'{SLACK_ADMIN_EMAIL}' not found in Slack workspace. "
|
||||
f"Available emails: {sorted(user_map.keys())}"
|
||||
)
|
||||
admin_user_id = user_map[SLACK_ADMIN_EMAIL]
|
||||
|
||||
(
|
||||
public_channel,
|
||||
@@ -27,5 +37,16 @@ def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]
|
||||
|
||||
yield public_channel, private_channel
|
||||
|
||||
# This part will always run after the test, even if it fails
|
||||
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN"])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_perm_sync_test_setup() -> (
|
||||
Generator[tuple[ChannelType, ChannelType], None, None]
|
||||
):
|
||||
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN_TEST_SPACE"])
|
||||
|
||||
@@ -16,7 +16,6 @@ from uuid import uuid4
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.connectors.slack.connector import default_msg_filter
|
||||
from onyx.connectors.slack.connector import get_channel_messages
|
||||
from onyx.connectors.slack.models import ChannelType
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call
|
||||
@@ -113,9 +112,6 @@ def _delete_slack_conversation_messages(
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
for message_batch in get_channel_messages(slack_client, channel):
|
||||
for message in message_batch:
|
||||
if default_msg_filter(message):
|
||||
continue
|
||||
|
||||
if message_to_delete and message.get("text") != message_to_delete:
|
||||
continue
|
||||
print(" removing message: ", message.get("text"))
|
||||
|
||||
@@ -22,6 +22,9 @@ from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.slack.conftest import SLACK_ADMIN_EMAIL
|
||||
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_1_EMAIL
|
||||
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_2_EMAIL
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
@@ -34,26 +37,24 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackMan
|
||||
def test_slack_permission_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
slack_test_setup: tuple[ChannelType, ChannelType],
|
||||
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
|
||||
) -> None:
|
||||
public_channel, private_channel = slack_test_setup
|
||||
public_channel, private_channel = slack_perm_sync_test_setup
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email="admin@example.com",
|
||||
email=SLACK_ADMIN_EMAIL,
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(
|
||||
email="test_user_1@example.com",
|
||||
email=SLACK_TEST_USER_1_EMAIL,
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_2: DATestUser = UserManager.create(
|
||||
email="test_user_2@example.com",
|
||||
email=SLACK_TEST_USER_2_EMAIL,
|
||||
)
|
||||
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = email_id_map[admin_user.email]
|
||||
|
||||
@@ -63,7 +64,7 @@ def test_slack_permission_sync(
|
||||
credential: DATestCredential = CredentialManager.create(
|
||||
source=DocumentSource.SLACK,
|
||||
credential_json={
|
||||
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||
"slack_bot_token": bot_token,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
@@ -73,6 +74,7 @@ def test_slack_permission_sync(
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
"include_bot_messages": True,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
groups=[],
|
||||
@@ -102,14 +104,11 @@ def test_slack_permission_sync(
|
||||
public_message = "Steve's favorite number is 809752"
|
||||
private_message = "Sara's favorite number is 346794"
|
||||
|
||||
# Add messages to channels
|
||||
print(f"\n Adding public message to channel: {public_message}")
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=public_channel,
|
||||
message=public_message,
|
||||
)
|
||||
print(f"\n Adding private message to channel: {private_message}")
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
@@ -127,7 +126,9 @@ def test_slack_permission_sync(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
# Run permission sync. Since initial_index_should_sync=True for Slack,
|
||||
# permissions were already set during indexing above — the explicit sync
|
||||
# should find no changes to apply.
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
@@ -135,59 +136,38 @@ def test_slack_permission_sync(
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=2,
|
||||
number_of_updated_docs=0,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
# Search as admin with access to both channels
|
||||
print("\nSearching as admin user")
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
# Verify admin can see messages from both channels
|
||||
admin_docs = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by admin user: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
assert public_message in admin_docs
|
||||
assert private_message in admin_docs
|
||||
|
||||
# Ensure admin user can see messages from both channels
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message in onyx_doc_message_strings
|
||||
|
||||
# Search as test_user_2 with access to only the public channel
|
||||
print("\n Searching as test_user_2")
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
# Verify test_user_2 can only see public channel messages
|
||||
user_2_docs = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_2,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by test_user_2: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
assert public_message in user_2_docs
|
||||
assert private_message not in user_2_docs
|
||||
|
||||
# Ensure test_user_2 can only see messages from the public channel
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message not in onyx_doc_message_strings
|
||||
|
||||
# Search as test_user_1 with access to both channels
|
||||
print("\n Searching as test_user_1")
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
# Verify test_user_1 can see both channels (member of private channel)
|
||||
user_1_docs = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_1,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by test_user_1 before being removed from private channel: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
assert public_message in user_1_docs
|
||||
assert private_message in user_1_docs
|
||||
|
||||
# Ensure test_user_1 can see messages from both channels
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message in onyx_doc_message_strings
|
||||
|
||||
# ----------------------MAKE THE CHANGES--------------------------
|
||||
print("\n Removing test_user_1 from the private channel")
|
||||
before = datetime.now(timezone.utc)
|
||||
# Remove test_user_1 from the private channel
|
||||
before = datetime.now(timezone.utc)
|
||||
desired_channel_members = [admin_user]
|
||||
SlackManager.set_channel_members(
|
||||
slack_client=slack_client,
|
||||
@@ -206,24 +186,16 @@ def test_slack_permission_sync(
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
)
|
||||
|
||||
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||
# Ensure test_user_1 can no longer see messages from the private channel
|
||||
# Search as test_user_1 with access to only the public channel
|
||||
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
# Verify test_user_1 can no longer see private channel after removal
|
||||
user_1_docs = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_1,
|
||||
)
|
||||
print(
|
||||
"\n documents retrieved by test_user_1 after being removed from private channel: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_1 can only see messages from the public channel
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message not in onyx_doc_message_strings
|
||||
assert public_message in user_1_docs
|
||||
assert private_message not in user_1_docs
|
||||
|
||||
|
||||
# NOTE(rkuo): it isn't yet clear if the reason these were previously xfail'd
|
||||
@@ -235,21 +207,19 @@ def test_slack_permission_sync(
|
||||
def test_slack_group_permission_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
slack_test_setup: tuple[ChannelType, ChannelType],
|
||||
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
|
||||
) -> None:
|
||||
"""
|
||||
This test ensures that permission sync overrides onyx group access.
|
||||
"""
|
||||
public_channel, private_channel = slack_test_setup
|
||||
public_channel, private_channel = slack_perm_sync_test_setup
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email="admin@example.com",
|
||||
email=SLACK_ADMIN_EMAIL,
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(
|
||||
email="test_user_1@example.com",
|
||||
email=SLACK_TEST_USER_1_EMAIL,
|
||||
)
|
||||
|
||||
# Create a user group and adding the non-admin user to it
|
||||
@@ -264,7 +234,8 @@ def test_slack_group_permission_sync(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = email_id_map[admin_user.email]
|
||||
|
||||
@@ -282,7 +253,7 @@ def test_slack_group_permission_sync(
|
||||
credential = CredentialManager.create(
|
||||
source=DocumentSource.SLACK,
|
||||
credential_json={
|
||||
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||
"slack_bot_token": bot_token,
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
@@ -294,6 +265,7 @@ def test_slack_group_permission_sync(
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"channels": [private_channel["name"]],
|
||||
"include_bot_messages": True,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
groups=[user_group.id],
|
||||
@@ -326,7 +298,8 @@ def test_slack_group_permission_sync(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
# Run permission sync. Since initial_index_should_sync=True for Slack,
|
||||
# permissions were already set during indexing — no changes expected.
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
@@ -334,8 +307,10 @@ def test_slack_group_permission_sync(
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
number_of_updated_docs=0,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
# Verify admin can see the message
|
||||
|
||||
@@ -4,75 +4,84 @@ import time
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.settings import SettingsManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestSettings
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
RETENTION_SECONDS = 10
|
||||
|
||||
|
||||
def _run_ttl_cleanup(retention_days: int) -> None:
|
||||
"""Directly execute TTL cleanup logic, bypassing Celery task infrastructure."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
old_chat_sessions = get_chat_sessions_older_than(retention_days, db_session)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Chat retention tests are enterprise only",
|
||||
)
|
||||
def test_chat_retention(reset: None, admin_user: DATestUser) -> None: # noqa: ARG001
|
||||
def test_chat_retention(
|
||||
reset: None, admin_user: DATestUser, llm_provider: DATestLLMProvider # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
"""Test that chat sessions are deleted after the retention period expires."""
|
||||
|
||||
# Set chat retention period to 10 seconds
|
||||
retention_days = 10 / 86400 # 10 seconds in days (10 / 24 / 60 / 60)
|
||||
retention_days = RETENTION_SECONDS // 86400
|
||||
settings = DATestSettings(maximum_chat_retention_days=retention_days)
|
||||
SettingsManager.update_settings(settings, user_performing_action=admin_user)
|
||||
|
||||
# Create a chat session
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description="Test chat retention",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Send a message
|
||||
ChatSessionManager.send_message(
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="This message should be deleted soon",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert (
|
||||
response.error is None
|
||||
), f"Chat response should not have an error: {response.error}"
|
||||
|
||||
# Verify the chat session exists
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(chat_history) > 0, "Chat session should have messages"
|
||||
|
||||
# Wait for TTL task to run (give it ~60 seconds)
|
||||
print("Waiting for chat retention TTL task to run...")
|
||||
max_wait_time = 60 # maximum time to wait in seconds
|
||||
start_time = time.time()
|
||||
# Wait for the retention period to elapse, then directly run TTL cleanup
|
||||
time.sleep(RETENTION_SECONDS + 2)
|
||||
_run_ttl_cleanup(retention_days)
|
||||
|
||||
# Verify the chat session was deleted
|
||||
session_deleted = False
|
||||
try:
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
session_deleted = len(chat_history) == 0
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code in (404, 400):
|
||||
session_deleted = True
|
||||
else:
|
||||
raise
|
||||
|
||||
while not session_deleted and (time.time() - start_time < max_wait_time):
|
||||
# Check if chat session is deleted
|
||||
try:
|
||||
# Attempt to get chat history - this should 404
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# If we got no messages or an empty response, session might be deleted
|
||||
if not chat_history:
|
||||
session_deleted = True
|
||||
break
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
# If we get a 404 or other error, the session is gone
|
||||
if e.response.status_code in (404, 400):
|
||||
session_deleted = True
|
||||
break
|
||||
raise # Re-raise other errors
|
||||
|
||||
# Wait a bit before checking again
|
||||
time.sleep(5)
|
||||
print(f"Waited {time.time() - start_time:.1f} seconds for chat deletion...")
|
||||
|
||||
# Assert that the chat session was deleted
|
||||
assert session_deleted, "Chat session was not deleted within the expected time"
|
||||
assert session_deleted, "Chat session was not deleted after retention period"
|
||||
|
||||
32
backend/tests/integration/tests/code_interpreter/conftest.py
Normal file
32
backend/tests/integration/tests/code_interpreter/conftest.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def preserve_code_interpreter_state(
|
||||
admin_user: DATestUser,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Capture the code interpreter enabled state before a test and restore it
|
||||
afterwards, so that tests that toggle the setting cannot leak state."""
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
initial_enabled = response.json()["enabled"]
|
||||
|
||||
yield
|
||||
|
||||
restore = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": initial_enabled},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
restore.raise_for_status()
|
||||
@@ -0,0 +1,97 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
|
||||
CODE_INTERPRETER_HEALTH_URL = f"{CODE_INTERPRETER_URL}/health"
|
||||
|
||||
|
||||
def test_get_code_interpreter_health_as_admin(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Health endpoint should return a JSON object with a 'healthy' boolean."""
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_HEALTH_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "healthy" in data
|
||||
assert isinstance(data["healthy"], bool)
|
||||
|
||||
|
||||
def test_get_code_interpreter_status_as_admin(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""GET endpoint should return a JSON object with an 'enabled' boolean."""
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "enabled" in data
|
||||
assert isinstance(data["enabled"], bool)
|
||||
|
||||
|
||||
def test_update_code_interpreter_disable_and_enable(
|
||||
admin_user: DATestUser,
|
||||
preserve_code_interpreter_state: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""PUT endpoint should update the enabled flag and persist across reads."""
|
||||
# Disable
|
||||
response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": False},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify disabled
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["enabled"] is False
|
||||
|
||||
# Re-enable
|
||||
response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": True},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify enabled
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["enabled"] is True
|
||||
|
||||
|
||||
def test_code_interpreter_endpoints_require_admin(
|
||||
basic_user: DATestUser,
|
||||
) -> None:
|
||||
"""All code interpreter endpoints should reject non-admin users."""
|
||||
health_response = requests.get(
|
||||
CODE_INTERPRETER_HEALTH_URL,
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert health_response.status_code == 403
|
||||
|
||||
get_response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert get_response.status_code == 403
|
||||
|
||||
put_response = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": True},
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert put_response.status_code == 403
|
||||
@@ -1,195 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history is enterprise only",
|
||||
)
|
||||
def test_all_stream_chat_message_objects_outputs(reset: None) -> None: # noqa: ARG001
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connector
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
# SEEDING DOCUMENTS
|
||||
cc_pair_1.documents = []
|
||||
cc_pair_1.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Pablo's favorite color is blue",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
cc_pair_1.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Chris's favorite color is red",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
cc_pair_1.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Pika's favorite color is green",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
|
||||
# TESTING RESPONSE FOR QUESTION 1
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# check that the answer is correct
|
||||
answer_1 = response_json["answer"]
|
||||
assert "blue" in answer_1.lower()
|
||||
|
||||
# FLAKY - check that the llm selected a document
|
||||
# assert 0 in response_json["llm_selected_doc_indices"]
|
||||
|
||||
# check that the final context documents are correct
|
||||
# (it should contain all documents because there arent enough to exclude any)
|
||||
assert 0 in response_json["final_context_doc_indices"]
|
||||
assert 1 in response_json["final_context_doc_indices"]
|
||||
assert 2 in response_json["final_context_doc_indices"]
|
||||
|
||||
# FLAKY - check that the cited documents are correct
|
||||
# assert cc_pair_1.documents[0].id in response_json["cited_documents"].values()
|
||||
|
||||
# flakiness likely due to non-deterministic rephrasing
|
||||
# FLAKY - check that the top documents are correct
|
||||
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
|
||||
print("response 1/3 passed")
|
||||
|
||||
# TESTING RESPONSE FOR QUESTION 2
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
{
|
||||
"message": answer_1,
|
||||
"role": MessageType.ASSISTANT.value,
|
||||
},
|
||||
{
|
||||
"message": "What is Chris's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# check that the answer is correct
|
||||
answer_2 = response_json["answer"]
|
||||
assert "red" in answer_2.lower()
|
||||
|
||||
# FLAKY - check that the llm selected a document
|
||||
# assert 0 in response_json["llm_selected_doc_indices"]
|
||||
|
||||
# check that the final context documents are correct
|
||||
# (it should contain all documents because there arent enough to exclude any)
|
||||
assert 0 in response_json["final_context_doc_indices"]
|
||||
assert 1 in response_json["final_context_doc_indices"]
|
||||
assert 2 in response_json["final_context_doc_indices"]
|
||||
|
||||
# FLAKY - check that the cited documents are correct
|
||||
# assert cc_pair_1.documents[1].id in response_json["cited_documents"].values()
|
||||
|
||||
# flakiness likely due to non-deterministic rephrasing
|
||||
# FLAKY - check that the top documents are correct
|
||||
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[1].id
|
||||
print("response 2/3 passed")
|
||||
|
||||
# TESTING RESPONSE FOR QUESTION 3
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
{
|
||||
"message": answer_1,
|
||||
"role": MessageType.ASSISTANT.value,
|
||||
},
|
||||
{
|
||||
"message": "What is Chris's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
{
|
||||
"message": answer_2,
|
||||
"role": MessageType.ASSISTANT.value,
|
||||
},
|
||||
{
|
||||
"message": "What is Pika's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# check that the answer is correct
|
||||
answer_3 = response_json["answer"]
|
||||
assert "green" in answer_3.lower()
|
||||
|
||||
# FLAKY - check that the llm selected a document
|
||||
# assert 0 in response_json["llm_selected_doc_indices"]
|
||||
|
||||
# check that the final context documents are correct
|
||||
# (it should contain all documents because there arent enough to exclude any)
|
||||
assert 0 in response_json["final_context_doc_indices"]
|
||||
assert 1 in response_json["final_context_doc_indices"]
|
||||
assert 2 in response_json["final_context_doc_indices"]
|
||||
|
||||
# FLAKY - check that the cited documents are correct
|
||||
# assert cc_pair_1.documents[2].id in response_json["cited_documents"].values()
|
||||
|
||||
# flakiness likely due to non-deterministic rephrasing
|
||||
# FLAKY - check that the top documents are correct
|
||||
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
|
||||
print("response 3/3 passed")
|
||||
@@ -1,250 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.conftest import DocumentBuilderType
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_send_message_simple_with_history(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
document_builder: DocumentBuilderType,
|
||||
) -> None:
|
||||
# create documents using the document builder
|
||||
# Create NUM_DOCS number of documents with dummy content
|
||||
content_list = [f"Document {i} content" for i in range(NUM_DOCS)]
|
||||
docs = document_builder(content_list)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": docs[0].content,
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Check that the top document is the correct document
|
||||
assert response_json["top_documents"][0]["document_id"] == docs[0].id
|
||||
|
||||
# assert that the metadata is correct
|
||||
for doc in docs:
|
||||
found_doc = next(
|
||||
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
|
||||
None,
|
||||
)
|
||||
assert found_doc
|
||||
assert found_doc["metadata"]["document_id"] == doc.id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_using_reference_docs_with_simple_with_history_api_flow(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
document_builder: DocumentBuilderType,
|
||||
) -> None:
|
||||
# SEEDING DOCUMENTS
|
||||
docs = document_builder(
|
||||
[
|
||||
"Chris's favorite color is blue",
|
||||
"Hagen's favorite color is red",
|
||||
"Pablo's favorite color is green",
|
||||
]
|
||||
)
|
||||
|
||||
# SEINDING MESSAGE 1
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# get the db_doc_id of the top document to use as a search doc id for second message
|
||||
first_db_doc_id = response_json["top_documents"][0]["db_doc_id"]
|
||||
|
||||
# SEINDING MESSAGE 2
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
"search_doc_ids": [first_db_doc_id],
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# make sure there is an answer
|
||||
assert response_json["answer"]
|
||||
|
||||
# This ensures the the document we think we are referencing when we send the search_doc_ids in the second
|
||||
# message is the document that we expect it to be
|
||||
assert response_json["top_documents"][0]["document_id"] == docs[2].id
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="We don't support this anymore with the DR flow :(")
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_send_message_simple_with_history_strict_json(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
# intentionally not relevant prompt to ensure that the
|
||||
# structured response format is actually used
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is green?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
"structured_response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "presidents",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"presidents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of the first three US presidents",
|
||||
}
|
||||
},
|
||||
"required": ["presidents"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Check that the answer is present
|
||||
assert "answer" in response_json
|
||||
assert response_json["answer"] is not None
|
||||
|
||||
# helper
|
||||
def clean_json_string(json_string: str) -> str:
|
||||
return json_string.strip().removeprefix("```json").removesuffix("```").strip()
|
||||
|
||||
# Attempt to parse the answer as JSON
|
||||
try:
|
||||
clean_answer = clean_json_string(response_json["answer"])
|
||||
parsed_answer = json.loads(clean_answer)
|
||||
|
||||
# NOTE: do not check content, just the structure
|
||||
assert isinstance(parsed_answer, dict)
|
||||
assert "presidents" in parsed_answer
|
||||
assert isinstance(parsed_answer["presidents"], list)
|
||||
for president in parsed_answer["presidents"]:
|
||||
assert isinstance(president, str)
|
||||
except json.JSONDecodeError:
|
||||
assert (
|
||||
False
|
||||
), f"The answer is not a valid JSON object - '{response_json['answer']}'"
|
||||
|
||||
# Check that the answer_citationless is also valid JSON
|
||||
assert "answer_citationless" in response_json
|
||||
assert response_json["answer_citationless"] is not None
|
||||
try:
|
||||
clean_answer_citationless = clean_json_string(
|
||||
response_json["answer_citationless"]
|
||||
)
|
||||
parsed_answer_citationless = json.loads(clean_answer_citationless)
|
||||
assert isinstance(parsed_answer_citationless, dict)
|
||||
except json.JSONDecodeError:
|
||||
assert False, "The answer_citationless is not a valid JSON object"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/query/answer-with-citation tests are enterprise only",
|
||||
)
|
||||
def test_answer_with_citation_api(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
document_builder: DocumentBuilderType,
|
||||
) -> None:
|
||||
|
||||
# create docs
|
||||
docs = document_builder(["Chris' favorite color is green"])
|
||||
|
||||
# send a message
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/query/answer-with-citation",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Chris' favorite color? Make sure to cite the document.",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
assert response_json["answer"]
|
||||
|
||||
has_correct_citation = False
|
||||
for citation in response_json["citations"]:
|
||||
if citation["document_id"] == docs[0].id:
|
||||
has_correct_citation = True
|
||||
break
|
||||
|
||||
assert has_correct_citation
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -12,6 +11,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_EMAILS
|
||||
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_GROUP_IDS
|
||||
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -25,128 +25,16 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.test_document_utils import create_test_document
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync is enterprise only",
|
||||
)
|
||||
def test_mock_connector_initial_permission_sync(
|
||||
def _setup_mock_connector(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the MockConnector fetches and sets permissions during initial indexing when AccessType.SYNC is used"""
|
||||
|
||||
# Set up mock server behavior
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create CC Pair with SYNC access type to enable permissions during indexing
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-permissions-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
access_type=AccessType.SYNC, # This enables permissions during indexing
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for index attempt to start
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for index attempt to finish
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.SUCCESS
|
||||
|
||||
# Verify document was indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == test_doc.id
|
||||
|
||||
# Verify no errors occurred
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 0
|
||||
|
||||
# Verify permissions were set during indexing by checking the document in the database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_docs = get_documents_by_ids(
|
||||
db_session=db_session,
|
||||
document_ids=[test_doc.id],
|
||||
)
|
||||
assert len(db_docs) == 1
|
||||
db_doc = db_docs[0]
|
||||
|
||||
assert db_doc.external_user_emails is not None
|
||||
assert db_doc.external_user_group_ids is not None
|
||||
|
||||
# Check the specific permissions that MockConnector sets
|
||||
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
|
||||
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
|
||||
|
||||
# Verify the document is not public (as set by MockConnector)
|
||||
assert db_doc.is_public is False
|
||||
|
||||
# Verify that the cc_pair was marked as permissions synced
|
||||
updated_cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair.id, user_performing_action=admin_user
|
||||
)
|
||||
assert updated_cc_pair_info is not None
|
||||
assert updated_cc_pair_info.last_full_permission_sync is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
)
|
||||
def test_permission_sync_attempt_tracking_integration(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are properly tracked during real sync workflows."""
|
||||
|
||||
) -> tuple[DATestCCPair, Document]:
|
||||
"""Common setup: create a test doc, configure mock server, create cc_pair, wait for indexing."""
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
@@ -165,7 +53,7 @@ def test_permission_sync_attempt_tracking_integration(
|
||||
assert response.status_code == 200
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-attempt-tracking-{uuid.uuid4()}",
|
||||
name=f"mock-connector-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
@@ -187,6 +75,95 @@ def test_permission_sync_attempt_tracking_integration(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
finished = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished.status == IndexingStatus.SUCCESS
|
||||
return cc_pair, test_doc
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync is enterprise only",
|
||||
)
|
||||
def test_mock_connector_initial_permission_sync(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the MockConnector fetches and sets permissions during initial indexing
|
||||
when AccessType.SYNC is used."""
|
||||
|
||||
cc_pair, test_doc = _setup_mock_connector(mock_server_client, admin_user)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == test_doc.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 0
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_docs = get_documents_by_ids(
|
||||
db_session=db_session,
|
||||
document_ids=[test_doc.id],
|
||||
)
|
||||
assert len(db_docs) == 1
|
||||
db_doc = db_docs[0]
|
||||
|
||||
assert db_doc.external_user_emails is not None
|
||||
assert db_doc.external_user_group_ids is not None
|
||||
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
|
||||
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
|
||||
assert db_doc.is_public is False
|
||||
|
||||
# After initial indexing, the beat task detects last_time_perm_sync is None
|
||||
# and triggers a doc permission sync. Explicitly trigger it to avoid
|
||||
# waiting for the 30s beat interval.
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
updated_cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair.id, user_performing_action=admin_user
|
||||
)
|
||||
assert updated_cc_pair_info is not None
|
||||
assert updated_cc_pair_info.last_full_permission_sync is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
)
|
||||
def test_permission_sync_attempt_tracking_integration(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are properly tracked during real sync workflows."""
|
||||
|
||||
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
@@ -198,6 +175,8 @@ def test_permission_sync_attempt_tracking_integration(
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -219,88 +198,6 @@ def test_permission_sync_attempt_tracking_integration(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
)
|
||||
def test_permission_sync_attempt_tracking_with_mocked_failure(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are properly tracked when sync fails."""
|
||||
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-attempt-failure-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Mock the permission sync to force a failure and verify attempt tracking
|
||||
with patch(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing.tasks.validate_ccpair_for_user"
|
||||
) as mock_validate:
|
||||
mock_validate.side_effect = Exception("Validation failed for testing")
|
||||
|
||||
try:
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=0,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = db_session.execute(
|
||||
select(DocPermissionSyncAttempt).where(
|
||||
DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair.id
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
assert attempt.status == PermissionSyncStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
@@ -311,45 +208,8 @@ def test_permission_sync_attempt_status_success(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are marked as SUCCESS when sync completes without errors."""
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-success-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
@@ -362,6 +222,8 @@ def test_permission_sync_attempt_status_success(
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
@@ -6,11 +6,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
@@ -267,6 +270,24 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
provider_name=restricted_provider.name,
|
||||
)
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(default_model_config)
|
||||
db_session.flush()
|
||||
db_session.add(
|
||||
LLMModelFlow(
|
||||
model_configuration_id=default_model_config.id,
|
||||
llm_model_flow_type=LLMModelFlowType.CHAT,
|
||||
is_default=True,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
access_group = UserGroup(name="persona-group")
|
||||
db_session.add(access_group)
|
||||
db_session.flush()
|
||||
|
||||
@@ -0,0 +1,322 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from onyx.configs import app_configs
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.tools.constants import SEARCH_TOOL_ID
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.tool import ToolManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import ToolName
|
||||
|
||||
|
||||
_ENV_PROVIDER = "NIGHTLY_LLM_PROVIDER"
|
||||
_ENV_MODELS = "NIGHTLY_LLM_MODELS"
|
||||
_ENV_API_KEY = "NIGHTLY_LLM_API_KEY"
|
||||
_ENV_API_BASE = "NIGHTLY_LLM_API_BASE"
|
||||
_ENV_CUSTOM_CONFIG_JSON = "NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
_ENV_STRICT = "NIGHTLY_LLM_STRICT"
|
||||
|
||||
|
||||
class NightlyProviderConfig(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
provider: str
|
||||
model_names: list[str]
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
custom_config: dict[str, str] | None
|
||||
strict: bool
|
||||
|
||||
|
||||
def _env_true(env_var: str, default: bool = False) -> bool:
|
||||
value = os.environ.get(env_var)
|
||||
if value is None:
|
||||
return default
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _split_csv_env(env_var: str) -> list[str]:
|
||||
return [
|
||||
part.strip() for part in os.environ.get(env_var, "").split(",") if part.strip()
|
||||
]
|
||||
|
||||
|
||||
def _load_provider_config() -> NightlyProviderConfig:
|
||||
provider = os.environ.get(_ENV_PROVIDER, "").strip().lower()
|
||||
model_names = _split_csv_env(_ENV_MODELS)
|
||||
api_key = os.environ.get(_ENV_API_KEY) or None
|
||||
api_base = os.environ.get(_ENV_API_BASE) or None
|
||||
strict = _env_true(_ENV_STRICT, default=False)
|
||||
|
||||
custom_config: dict[str, str] | None = None
|
||||
custom_config_json = os.environ.get(_ENV_CUSTOM_CONFIG_JSON, "").strip()
|
||||
if custom_config_json:
|
||||
parsed = json.loads(custom_config_json)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError(f"{_ENV_CUSTOM_CONFIG_JSON} must be a JSON object")
|
||||
custom_config = {str(key): str(value) for key, value in parsed.items()}
|
||||
|
||||
if provider == "ollama_chat" and api_key and not custom_config:
|
||||
custom_config = {"OLLAMA_API_KEY": api_key}
|
||||
|
||||
return NightlyProviderConfig(
|
||||
provider=provider,
|
||||
model_names=model_names,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
custom_config=custom_config,
|
||||
strict=strict,
|
||||
)
|
||||
|
||||
|
||||
def _skip_or_fail(strict: bool, message: str) -> None:
|
||||
if strict:
|
||||
pytest.fail(message)
|
||||
pytest.skip(message)
|
||||
|
||||
|
||||
def _validate_provider_config(config: NightlyProviderConfig) -> None:
|
||||
if not config.provider:
|
||||
_skip_or_fail(strict=config.strict, message=f"{_ENV_PROVIDER} must be set")
|
||||
|
||||
if not config.model_names:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=f"{_ENV_MODELS} must include at least one model",
|
||||
)
|
||||
|
||||
if config.provider != "ollama_chat" and not config.api_key:
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_KEY} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
if config.provider == "ollama_chat" and not (
|
||||
config.api_base or _default_api_base_for_provider(config.provider)
|
||||
):
|
||||
_skip_or_fail(
|
||||
strict=config.strict,
|
||||
message=(f"{_ENV_API_BASE} is required for provider '{config.provider}'"),
|
||||
)
|
||||
|
||||
|
||||
def _assert_integration_mode_enabled() -> None:
|
||||
assert (
|
||||
app_configs.INTEGRATION_TESTS_MODE is True
|
||||
), "Integration tests require INTEGRATION_TESTS_MODE=true."
|
||||
|
||||
|
||||
def _seed_connector_for_search_tool(admin_user: DATestUser) -> None:
|
||||
# SearchTool is only exposed when at least one non-default connector exists.
|
||||
CCPairManager.create_from_scratch(
|
||||
source=DocumentSource.INGESTION_API,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
|
||||
def _get_internal_search_tool_id(admin_user: DATestUser) -> int:
|
||||
tools = ToolManager.list_tools(user_performing_action=admin_user)
|
||||
for tool in tools:
|
||||
if tool.in_code_tool_id == SEARCH_TOOL_ID:
|
||||
return tool.id
|
||||
raise AssertionError("SearchTool must exist for this test")
|
||||
|
||||
|
||||
def _default_api_base_for_provider(provider: str) -> str | None:
|
||||
if provider == "openrouter":
|
||||
return "https://openrouter.ai/api/v1"
|
||||
if provider == "ollama_chat":
|
||||
# host.docker.internal works when tests are running inside the integration test container.
|
||||
return "http://host.docker.internal:11434"
|
||||
return None
|
||||
|
||||
|
||||
def _create_provider_payload(
|
||||
provider: str,
|
||||
provider_name: str,
|
||||
model_name: str,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
custom_config: dict[str, str] | None,
|
||||
) -> dict:
|
||||
return {
|
||||
"name": provider_name,
|
||||
"provider": provider,
|
||||
"api_key": api_key,
|
||||
"api_base": api_base,
|
||||
"custom_config": custom_config,
|
||||
"default_model_name": model_name,
|
||||
"is_public": True,
|
||||
"groups": [],
|
||||
"personas": [],
|
||||
"model_configurations": [{"name": model_name, "is_visible": True}],
|
||||
"api_key_changed": bool(api_key),
|
||||
"custom_config_changed": bool(custom_config),
|
||||
}
|
||||
|
||||
|
||||
def _ensure_provider_is_default(provider_id: int, admin_user: DATestUser) -> None:
|
||||
list_response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
list_response.raise_for_status()
|
||||
providers = list_response.json()
|
||||
|
||||
current_default = next(
|
||||
(provider for provider in providers if provider.get("is_default_provider")),
|
||||
None,
|
||||
)
|
||||
assert (
|
||||
current_default is not None
|
||||
), "Expected a default provider after setting provider as default"
|
||||
assert (
|
||||
current_default["id"] == provider_id
|
||||
), f"Expected provider {provider_id} to be default, found {current_default['id']}"
|
||||
|
||||
|
||||
def _run_chat_assertions(
|
||||
admin_user: DATestUser,
|
||||
search_tool_id: int,
|
||||
provider: str,
|
||||
model_name: str,
|
||||
) -> None:
|
||||
last_error: str | None = None
|
||||
# Retry once to reduce transient nightly flakes due provider-side blips.
|
||||
for attempt in range(1, 3):
|
||||
chat_session = ChatSessionManager.create(user_performing_action=admin_user)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message=(
|
||||
"Use internal_search to search for 'nightly-provider-regression-sentinel', "
|
||||
"then summarize the result in one short sentence."
|
||||
),
|
||||
user_performing_action=admin_user,
|
||||
forced_tool_ids=[search_tool_id],
|
||||
)
|
||||
|
||||
if response.error is None:
|
||||
used_internal_search = any(
|
||||
used_tool.tool_name == ToolName.INTERNAL_SEARCH
|
||||
for used_tool in response.used_tools
|
||||
)
|
||||
debug_has_internal_search = any(
|
||||
debug_tool_call.tool_name == "internal_search"
|
||||
for debug_tool_call in response.tool_call_debug
|
||||
)
|
||||
has_answer = bool(response.full_message.strip())
|
||||
|
||||
if used_internal_search and debug_has_internal_search and has_answer:
|
||||
return
|
||||
|
||||
last_error = (
|
||||
f"attempt={attempt} provider={provider} model={model_name} "
|
||||
f"used_internal_search={used_internal_search} "
|
||||
f"debug_internal_search={debug_has_internal_search} "
|
||||
f"has_answer={has_answer} "
|
||||
f"tool_call_debug={response.tool_call_debug}"
|
||||
)
|
||||
else:
|
||||
last_error = (
|
||||
f"attempt={attempt} provider={provider} model={model_name} "
|
||||
f"stream_error={response.error.error}"
|
||||
)
|
||||
|
||||
time.sleep(attempt)
|
||||
|
||||
pytest.fail(f"Chat/tool-call assertions failed: {last_error}")
|
||||
|
||||
|
||||
def _create_and_test_provider_for_model(
|
||||
admin_user: DATestUser,
|
||||
config: NightlyProviderConfig,
|
||||
model_name: str,
|
||||
search_tool_id: int,
|
||||
) -> None:
|
||||
provider_name = f"nightly-{config.provider}-{uuid4().hex[:12]}"
|
||||
resolved_api_base = config.api_base or _default_api_base_for_provider(
|
||||
config.provider
|
||||
)
|
||||
|
||||
provider_payload = _create_provider_payload(
|
||||
provider=config.provider,
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
api_key=config.api_key,
|
||||
api_base=resolved_api_base,
|
||||
custom_config=config.custom_config,
|
||||
)
|
||||
|
||||
test_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/test",
|
||||
headers=admin_user.headers,
|
||||
json=provider_payload,
|
||||
)
|
||||
assert test_response.status_code == 200, (
|
||||
f"Provider test endpoint failed for provider={config.provider} "
|
||||
f"model={model_name}: {test_response.status_code} {test_response.text}"
|
||||
)
|
||||
|
||||
create_response = requests.put(
|
||||
f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
|
||||
headers=admin_user.headers,
|
||||
json=provider_payload,
|
||||
)
|
||||
assert create_response.status_code == 200, (
|
||||
f"Provider creation failed for provider={config.provider} "
|
||||
f"model={model_name}: {create_response.status_code} {create_response.text}"
|
||||
)
|
||||
provider_id = create_response.json()["id"]
|
||||
|
||||
try:
|
||||
set_default_response = requests.post(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}/default",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert set_default_response.status_code == 200, (
|
||||
f"Setting default provider failed for provider={config.provider} "
|
||||
f"model={model_name}: {set_default_response.status_code} "
|
||||
f"{set_default_response.text}"
|
||||
)
|
||||
|
||||
_ensure_provider_is_default(provider_id=provider_id, admin_user=admin_user)
|
||||
_run_chat_assertions(
|
||||
admin_user=admin_user,
|
||||
search_tool_id=search_tool_id,
|
||||
provider=config.provider,
|
||||
model_name=model_name,
|
||||
)
|
||||
finally:
|
||||
requests.delete(
|
||||
f"{API_SERVER_URL}/admin/llm/provider/{provider_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
|
||||
|
||||
def test_nightly_provider_chat_workflow(admin_user: DATestUser) -> None:
|
||||
"""Nightly regression test for provider setup + default selection + chat tool calls."""
|
||||
_assert_integration_mode_enabled()
|
||||
config = _load_provider_config()
|
||||
_validate_provider_config(config)
|
||||
|
||||
_seed_connector_for_search_tool(admin_user)
|
||||
search_tool_id = _get_internal_search_tool_id(admin_user)
|
||||
|
||||
for model_name in config.model_names:
|
||||
_create_and_test_provider_for_model(
|
||||
admin_user=admin_user,
|
||||
config=config,
|
||||
model_name=model_name,
|
||||
search_tool_id=search_tool_id,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ the permissions of the curator manipulating connector-credential pairs.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
from onyx_openapi_client.exceptions import ApiException # type: ignore[import-untyped,unused-ignore,import-not-found]
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
@@ -93,20 +93,9 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should not be able to do"""
|
||||
|
||||
# Curators should not be able to create a public cc pair
|
||||
with pytest.raises(HTTPError):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_1.id,
|
||||
name="invalid_cc_pair_1",
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[user_group_1.id],
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should not be able to create a cc
|
||||
# pair for a user group they are not a curator of
|
||||
with pytest.raises(HTTPError):
|
||||
with pytest.raises(ApiException):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_1.id,
|
||||
@@ -118,7 +107,7 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
# Curators should not be able to create a cc
|
||||
# pair without an attached user group
|
||||
with pytest.raises(HTTPError):
|
||||
with pytest.raises(ApiException):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_1.id,
|
||||
@@ -144,7 +133,7 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
# Curators should not be able to create a cc
|
||||
# pair for a user group that the credential does not belong to
|
||||
with pytest.raises(HTTPError):
|
||||
with pytest.raises(ApiException):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_2.id,
|
||||
@@ -156,6 +145,16 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should be able to do"""
|
||||
|
||||
# Re-create connector since the credential_2 validation error above
|
||||
# triggers connector deletion in the exception handler
|
||||
connector_1 = ConnectorManager.create(
|
||||
name="admin_owned_connector_2",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PRIVATE,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Curators should be able to create a private
|
||||
# cc pair for a user group they are a curator of
|
||||
valid_cc_pair = CCPairManager.create(
|
||||
|
||||
@@ -59,17 +59,7 @@ def test_connector_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should not be able to do"""
|
||||
|
||||
# Curators should not be able to create a public connector
|
||||
with pytest.raises(HTTPError):
|
||||
ConnectorManager.create(
|
||||
name="invalid_connector_1",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should not be able to create a cc pair for a
|
||||
# Curators should not be able to create a connector for a
|
||||
# user group they are not a curator of
|
||||
with pytest.raises(HTTPError):
|
||||
ConnectorManager.create(
|
||||
@@ -133,12 +123,12 @@ def test_connector_permissions(reset: None) -> None: # noqa: ARG001
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Test that curator cannot create a public connector
|
||||
with pytest.raises(HTTPError):
|
||||
ConnectorManager.create(
|
||||
name="invalid_connector_4",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
# Curators should be able to create a public connector
|
||||
public_connector = ConnectorManager.create(
|
||||
name="curator_public_connector",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
assert public_connector.id is not None
|
||||
|
||||
@@ -58,16 +58,6 @@ def test_credential_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should not be able to do"""
|
||||
|
||||
# Curators should not be able to create a public credential
|
||||
with pytest.raises(HTTPError):
|
||||
CredentialManager.create(
|
||||
name="invalid_credential_1",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
curator_public=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should not be able to create a credential for a user group they are not a curator of
|
||||
with pytest.raises(HTTPError):
|
||||
CredentialManager.create(
|
||||
@@ -113,3 +103,16 @@ def test_credential_permissions(reset: None) -> None: # noqa: ARG001
|
||||
verify_deleted=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should be able to create a public credential
|
||||
public_credential = CredentialManager.create(
|
||||
name="curator_public_credential",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
curator_public=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
CredentialManager.verify(
|
||||
credential=public_credential,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -70,10 +70,11 @@ def test_doc_set_permissions_setup(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators/Admins should not be able to do"""
|
||||
|
||||
# Test that curator cannot create a document set for the group they don't curate
|
||||
# Test that curator cannot create a non-public document set for the group they don't curate
|
||||
with pytest.raises(HTTPError):
|
||||
DocumentSetManager.create(
|
||||
name="Invalid Document Set 1",
|
||||
is_public=False,
|
||||
groups=[user_group_2.id],
|
||||
cc_pair_ids=[public_cc_pair.id],
|
||||
user_performing_action=curator,
|
||||
|
||||
@@ -6,12 +6,14 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
from uuid import UUID
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ee.onyx.db.usage_export import UsageReportMetadata
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.seeding.chat_history_seeding import seed_chat_history
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -26,7 +28,13 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# Seed some chat history data for the report
|
||||
seed_chat_history(num_sessions=10, num_messages=4, days=30)
|
||||
seed_chat_history(
|
||||
num_sessions=10,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
|
||||
# Get initial list of reports
|
||||
initial_response = requests.get(
|
||||
@@ -76,7 +84,13 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# Seed some chat history data
|
||||
seed_chat_history(num_sessions=20, num_messages=4, days=60)
|
||||
seed_chat_history(
|
||||
num_sessions=20,
|
||||
num_messages=4,
|
||||
days=60,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
|
||||
# Get initial list of reports
|
||||
initial_response = requests.get(
|
||||
@@ -148,7 +162,13 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# First generate a report to ensure we have at least one
|
||||
seed_chat_history(num_sessions=5, num_messages=4, days=30)
|
||||
seed_chat_history(
|
||||
num_sessions=5,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
|
||||
# Get initial count
|
||||
initial_response = requests.get(
|
||||
@@ -204,7 +224,13 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# First generate a report
|
||||
seed_chat_history(num_sessions=5, num_messages=4, days=30)
|
||||
seed_chat_history(
|
||||
num_sessions=5,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
|
||||
# Get initial reports count
|
||||
initial_response = requests.get(
|
||||
@@ -352,7 +378,13 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# Seed some data
|
||||
seed_chat_history(num_sessions=10, num_messages=4, days=30)
|
||||
seed_chat_history(
|
||||
num_sessions=10,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
|
||||
# Get initial count of reports
|
||||
initial_response = requests.get(
|
||||
|
||||
@@ -25,6 +25,11 @@ def test_add_users_to_group(reset: None) -> None: # noqa: ARG001
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_performing_action=admin_user,
|
||||
user_groups_to_check=[user_group],
|
||||
)
|
||||
|
||||
updated_user_group = UserGroupManager.add_users(
|
||||
user_group=user_group,
|
||||
user_ids=[user_to_add.id],
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_enumerate_ad_groups_paginated,
|
||||
)
|
||||
@@ -15,6 +17,9 @@ from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
AD_GROUP_ENUMERATION_THRESHOLD,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_external_access_from_sharepoint,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
@@ -266,3 +271,65 @@ def test_enumerate_all_without_token_skips(
|
||||
|
||||
assert results == []
|
||||
mock_enum.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_external_access_from_sharepoint – site page URL handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"site_base_url, web_url, expected_relative_url",
|
||||
[
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/Evan%27sSite",
|
||||
"https://tenant.sharepoint.com/sites/Evan%27sSite/SitePages/Home.aspx",
|
||||
"/sites/Evan%27sSite/SitePages/Home.aspx",
|
||||
),
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/NormalSite",
|
||||
"https://tenant.sharepoint.com/sites/NormalSite/SitePages/Page.aspx",
|
||||
"/sites/NormalSite/SitePages/Page.aspx",
|
||||
),
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/Site%20With%20Spaces",
|
||||
"https://tenant.sharepoint.com/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
|
||||
"/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
|
||||
),
|
||||
],
|
||||
ids=["apostrophe-encoded", "no-special-chars", "space-encoded"],
|
||||
)
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_site_page_url_not_duplicated(
|
||||
mock_sleep: MagicMock, # noqa: ARG001
|
||||
mock_recursive: MagicMock,
|
||||
site_base_url: str,
|
||||
web_url: str,
|
||||
expected_relative_url: str,
|
||||
) -> None:
|
||||
"""Regression: the server-relative URL passed to
|
||||
get_file_by_server_relative_url must preserve percent-encoding so the
|
||||
Office365 library's SPResPath.create_relative() recognises the site prefix
|
||||
and doesn't duplicate it."""
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={},
|
||||
found_public_group=False,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.base_url = site_base_url
|
||||
|
||||
site_page = {"webUrl": web_url}
|
||||
|
||||
get_external_access_from_sharepoint(
|
||||
client_context=ctx,
|
||||
graph_client=MagicMock(),
|
||||
drive_name=None,
|
||||
drive_item=None,
|
||||
site_page=site_page,
|
||||
)
|
||||
|
||||
ctx.web.get_file_by_server_relative_url.assert_called_once_with(
|
||||
expected_relative_url
|
||||
)
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import json
|
||||
|
||||
import httplib2 # type: ignore[import-untyped]
|
||||
from googleapiclient.errors import HttpError # type: ignore[import-untyped]
|
||||
|
||||
from onyx.connectors.google_utils.google_utils import _is_rate_limit_error
|
||||
|
||||
|
||||
def _make_http_error(
|
||||
status: int,
|
||||
reason: str = "unknown",
|
||||
error_reason: str = "",
|
||||
) -> HttpError:
|
||||
resp = httplib2.Response({"status": status})
|
||||
if error_reason:
|
||||
body = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": reason,
|
||||
"errors": [{"reason": error_reason, "message": reason}],
|
||||
}
|
||||
}
|
||||
).encode()
|
||||
else:
|
||||
body = json.dumps({"error": {"message": reason}}).encode()
|
||||
return HttpError(resp, body)
|
||||
|
||||
|
||||
def test_429_is_rate_limit() -> None:
|
||||
assert _is_rate_limit_error(_make_http_error(429))
|
||||
|
||||
|
||||
def test_403_user_rate_limit_exceeded() -> None:
|
||||
err = _make_http_error(
|
||||
403,
|
||||
reason="User rate limit exceeded.",
|
||||
error_reason="userRateLimitExceeded",
|
||||
)
|
||||
assert _is_rate_limit_error(err)
|
||||
|
||||
|
||||
def test_403_rate_limit_exceeded() -> None:
|
||||
err = _make_http_error(
|
||||
403,
|
||||
reason="Rate limit exceeded.",
|
||||
error_reason="rateLimitExceeded",
|
||||
)
|
||||
assert _is_rate_limit_error(err)
|
||||
|
||||
|
||||
def test_403_permission_denied_is_not_rate_limit() -> None:
|
||||
err = _make_http_error(
|
||||
403,
|
||||
reason="The caller does not have permission",
|
||||
error_reason="forbidden",
|
||||
)
|
||||
assert not _is_rate_limit_error(err)
|
||||
|
||||
|
||||
def test_404_is_not_rate_limit() -> None:
|
||||
assert not _is_rate_limit_error(_make_http_error(404))
|
||||
|
||||
|
||||
def test_500_is_not_rate_limit() -> None:
|
||||
assert not _is_rate_limit_error(_make_http_error(500))
|
||||
@@ -0,0 +1,34 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
|
||||
|
||||
def _build_connector(base_url: str = "https://myteam.slab.com") -> SlabConnector:
|
||||
connector = SlabConnector(base_url=base_url)
|
||||
connector.load_credentials({"slab_bot_token": "fake-token"})
|
||||
return connector
|
||||
|
||||
|
||||
def test_validate_rejects_missing_scheme() -> None:
|
||||
connector = _build_connector(base_url="myteam.slab.com")
|
||||
with pytest.raises(ConnectorValidationError, match="https://"):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
|
||||
@patch("onyx.connectors.slab.connector.get_all_post_ids", return_value=["id1"])
|
||||
def test_validate_success(mock_get_posts: object) -> None: # noqa: ARG001
|
||||
connector = _build_connector()
|
||||
connector.validate_connector_settings()
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.connectors.slab.connector.get_all_post_ids",
|
||||
side_effect=Exception("401 Unauthorized"),
|
||||
)
|
||||
def test_validate_bad_token_raises(mock_get_posts: object) -> None: # noqa: ARG001
|
||||
connector = _build_connector()
|
||||
with pytest.raises(ConnectorValidationError, match="Failed to fetch posts"):
|
||||
connector.validate_connector_settings()
|
||||
@@ -98,6 +98,11 @@ class TestScimDALUserMappings:
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
"department": None,
|
||||
"manager": None,
|
||||
"given_name": None,
|
||||
"family_name": None,
|
||||
"scim_emails_json": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
199
backend/tests/unit/onyx/indexing/test_personas_in_chunks.py
Normal file
199
backend/tests/unit/onyx/indexing/test_personas_in_chunks.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""Tests that persona IDs are correctly propagated through the indexing pipeline.
|
||||
|
||||
Covers Phase 1 (schema plumbing) and Phase 2 (write at index time) of the
|
||||
unify-assistant-project-files plan.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_index_chunk(
|
||||
doc_id: str = "test-file-id",
|
||||
content: str = "test content",
|
||||
) -> IndexChunk:
|
||||
embedding = [0.1] * 10
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_file.txt",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.USER_FILE,
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
chunk_id=0,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=embedding,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_access() -> DocumentAccess:
|
||||
return DocumentAccess.build(
|
||||
user_emails=["user@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def test_from_index_chunk_propagates_personas() -> None:
|
||||
"""Personas list passed to from_index_chunk appears on the result."""
|
||||
chunk = _make_index_chunk()
|
||||
persona_ids = [10, 20, 30]
|
||||
|
||||
aware_chunk = DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=_make_access(),
|
||||
document_sets=set(),
|
||||
user_project=[1],
|
||||
personas=persona_ids,
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
assert aware_chunk.personas == persona_ids
|
||||
assert aware_chunk.user_project == [1]
|
||||
|
||||
|
||||
def test_from_index_chunk_empty_personas() -> None:
|
||||
"""An empty personas list is preserved (not turned into None or omitted)."""
|
||||
chunk = _make_index_chunk()
|
||||
|
||||
aware_chunk = DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=_make_access(),
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
|
||||
def _make_document(doc_id: str) -> Document:
|
||||
return Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_file.txt",
|
||||
sections=[TextSection(text="test content", link=None)],
|
||||
source=DocumentSource.USER_FILE,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
def _run_adapter_build(
|
||||
file_id: str,
|
||||
project_ids_map: dict[str, list[int]],
|
||||
persona_ids_map: dict[str, list[int]],
|
||||
) -> list[DocMetadataAwareIndexChunk]:
|
||||
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
|
||||
with all external dependencies mocked."""
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import (
|
||||
UserFileIndexingAdapter,
|
||||
)
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
|
||||
chunk = _make_index_chunk(doc_id=file_id)
|
||||
doc = _make_document(doc_id=file_id)
|
||||
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[doc],
|
||||
id_to_boost_map={},
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(tenant_id="test_tenant", db_session=MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.fetch_user_project_ids_for_user_files",
|
||||
return_value=project_ids_map,
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.fetch_persona_ids_for_user_files",
|
||||
return_value=persona_ids_map,
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_access_for_user_files",
|
||||
return_value={file_id: _make_access()},
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.fetch_chunk_counts_for_user_files",
|
||||
return_value=[(file_id, 0)],
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in tests"),
|
||||
),
|
||||
):
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id="test_tenant",
|
||||
context=context,
|
||||
)
|
||||
|
||||
return result.chunks
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
|
||||
fetched from the DB into each chunk's metadata."""
|
||||
file_id = str(uuid4())
|
||||
persona_ids = [5, 12]
|
||||
project_ids = [3]
|
||||
|
||||
chunks = _run_adapter_build(
|
||||
file_id=file_id,
|
||||
project_ids_map={file_id: project_ids},
|
||||
persona_ids_map={file_id: persona_ids},
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].personas == persona_ids
|
||||
assert chunks[0].user_project == project_ids
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
|
||||
"""When a file has no persona or project associations in the DB, the
|
||||
adapter should default to empty lists (not KeyError or None)."""
|
||||
file_id = str(uuid4())
|
||||
|
||||
chunks = _run_adapter_build(
|
||||
file_id=file_id,
|
||||
project_ids_map={},
|
||||
persona_ids_map={},
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].personas == []
|
||||
assert chunks[0].user_project == []
|
||||
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Test bulk invite limit for free trial tenants."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.manage.models import EmailInviteStatus
|
||||
from onyx.server.manage.users import bulk_invite_users
|
||||
|
||||
|
||||
@@ -33,6 +35,7 @@ def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: None) -> None:
|
||||
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.get_all_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
|
||||
@patch("onyx.server.manage.users.enforce_seat_limit")
|
||||
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
|
||||
@patch(
|
||||
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
|
||||
@@ -44,4 +47,69 @@ def test_trial_tenant_can_invite_within_limit(*_mocks: None) -> None:
|
||||
|
||||
result = bulk_invite_users(emails=emails)
|
||||
|
||||
assert result == 3
|
||||
assert result.invited_count == 3
|
||||
assert result.email_invite_status == EmailInviteStatus.DISABLED
|
||||
|
||||
|
||||
# --- email_invite_status tests ---
|
||||
|
||||
_COMMON_PATCHES = [
|
||||
patch("onyx.server.manage.users.MULTI_TENANT", False),
|
||||
patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant"),
|
||||
patch("onyx.server.manage.users.get_invited_users", return_value=[]),
|
||||
patch("onyx.server.manage.users.get_all_users", return_value=[]),
|
||||
patch("onyx.server.manage.users.write_invited_users", return_value=1),
|
||||
patch("onyx.server.manage.users.enforce_seat_limit"),
|
||||
]
|
||||
|
||||
|
||||
def _with_common_patches(fn: object) -> object:
|
||||
for p in reversed(_COMMON_PATCHES):
|
||||
fn = p(fn) # type: ignore
|
||||
return fn
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
|
||||
def test_email_invite_status_disabled(*_mocks: None) -> None:
|
||||
"""When email invites are disabled, status is disabled."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.DISABLED
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", False)
|
||||
def test_email_invite_status_not_configured(*_mocks: None) -> None:
|
||||
"""When email invites are enabled but no server is configured, status is not_configured."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.NOT_CONFIGURED
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", True)
|
||||
@patch("onyx.server.manage.users.send_user_email_invite")
|
||||
def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
|
||||
"""When email invites are enabled and configured, status is sent."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
mock_send.assert_called_once()
|
||||
assert result.email_invite_status == EmailInviteStatus.SENT
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", True)
|
||||
@patch(
|
||||
"onyx.server.manage.users.send_user_email_invite",
|
||||
side_effect=Exception("SMTP auth failed"),
|
||||
)
|
||||
def test_email_invite_status_send_failed(*_mocks: None) -> None:
|
||||
"""When email sending throws, status is send_failed and invite is still saved."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.SEND_FAILED
|
||||
assert result.invited_count == 1
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
@@ -12,7 +13,9 @@ import pytest
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
@@ -115,6 +118,11 @@ def make_user_mapping(**kwargs: Any) -> MagicMock:
|
||||
mapping.external_id = kwargs.get("external_id", "ext-default")
|
||||
mapping.user_id = kwargs.get("user_id", uuid4())
|
||||
mapping.scim_username = kwargs.get("scim_username", None)
|
||||
mapping.department = kwargs.get("department", None)
|
||||
mapping.manager = kwargs.get("manager", None)
|
||||
mapping.given_name = kwargs.get("given_name", None)
|
||||
mapping.family_name = kwargs.get("family_name", None)
|
||||
mapping.scim_emails_json = kwargs.get("scim_emails_json", None)
|
||||
return mapping
|
||||
|
||||
|
||||
@@ -122,3 +130,35 @@ def assert_scim_error(result: object, expected_status: int) -> None:
|
||||
"""Assert *result* is a JSONResponse with the given status code."""
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == expected_status
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_scim_user(result: object, *, status: int = 200) -> ScimUserResource:
    """Verify *result* is a ScimJSONResponse with *status* and decode its
    JSON body into a validated ScimUserResource."""
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    payload = json.loads(result.body)
    return ScimUserResource.model_validate(payload)
|
||||
|
||||
|
||||
def parse_scim_group(result: object, *, status: int = 200) -> ScimGroupResource:
    """Verify *result* is a ScimJSONResponse with *status* and decode its
    JSON body into a validated ScimGroupResource."""
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    payload = json.loads(result.body)
    return ScimGroupResource.model_validate(payload)
|
||||
|
||||
|
||||
def parse_scim_list(result: object, *, status: int = 200) -> ScimListResponse:
    """Assert *result* is a ScimJSONResponse and parse as ScimListResponse.

    Consistency/generalization: the sibling helpers ``parse_scim_user`` and
    ``parse_scim_group`` accept a keyword-only ``status`` (default 200); this
    helper previously hard-coded 200. Existing callers are unaffected.

    Args:
        result: The raw return value of a SCIM list endpoint.
        status: Expected HTTP status code (default 200).

    Returns:
        The response body validated as a ScimListResponse.

    Raises:
        AssertionError: If *result* is not a ScimJSONResponse or the status
            code does not match.
    """
    assert isinstance(
        result, ScimJSONResponse
    ), f"Expected ScimJSONResponse, got {type(result).__name__}"
    assert result.status_code == status
    return ScimListResponse.model_validate(json.loads(result.body))
|
||||
|
||||
983
backend/tests/unit/onyx/server/scim/test_entra.py
Normal file
983
backend/tests/unit/onyx/server/scim/test_entra.py
Normal file
@@ -0,0 +1,983 @@
|
||||
"""Comprehensive Entra ID (Azure AD) SCIM compatibility tests.
|
||||
|
||||
Covers the full Entra provisioning lifecycle: service discovery, user CRUD
|
||||
with enterprise extension schema, group CRUD with excludedAttributes, and
|
||||
all Entra-specific behavioral quirks (PascalCase ops, enterprise URN in
|
||||
PATCH value dicts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_group
|
||||
from ee.onyx.server.scim.api import get_resource_types
|
||||
from ee.onyx.server.scim.api import get_schemas
|
||||
from ee.onyx.server.scim.api import get_service_provider_config
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_groups
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
@pytest.fixture
def entra_provider() -> ScimProvider:
    """An EntraProvider instance for Entra-specific endpoint tests."""
    # Passed as the `provider` argument to the SCIM endpoint functions under
    # test; the tests below exercise Entra-specific behavior with it
    # (enterprise schema inclusion, PascalCase PATCH ops).
    return EntraProvider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraServiceDiscovery:
    """Entra expects enterprise extension in discovery endpoints."""

    def test_service_provider_config_advertises_patch(self) -> None:
        """The service-provider config must advertise PATCH support."""
        config = get_service_provider_config()
        assert config.patch.supported is True

    def test_resource_types_include_enterprise_extension(self) -> None:
        """The User resource type must declare the enterprise schema extension."""
        response = get_resource_types()
        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        assert "Resources" in payload
        user_type = next(rt for rt in payload["Resources"] if rt["id"] == "User")
        declared = {ext["schema"] for ext in user_type["schemaExtensions"]}
        assert SCIM_ENTERPRISE_USER_SCHEMA in declared

    def test_schemas_include_enterprise_user(self) -> None:
        """/Schemas must publish the enterprise user schema."""
        response = get_schemas()
        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        published = {schema["id"] for schema in payload["Resources"]}
        assert SCIM_ENTERPRISE_USER_SCHEMA in published

    def test_enterprise_schema_has_expected_attributes(self) -> None:
        """The enterprise schema must define department and manager attributes."""
        response = get_schemas()
        assert isinstance(response, ScimJSONResponse)
        payload = json.loads(response.body)
        enterprise = next(
            schema
            for schema in payload["Resources"]
            if schema["id"] == SCIM_ENTERPRISE_USER_SCHEMA
        )
        attribute_names = {attr["name"] for attr in enterprise["attributes"]}
        assert "department" in attribute_names
        assert "manager" in attribute_names

    def test_service_discovery_content_type(self) -> None:
        """SCIM responses must use application/scim+json content type."""
        response = get_resource_types()
        assert isinstance(response, ScimJSONResponse)
        assert response.media_type == "application/scim+json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraUserLifecycle:
    """Test user CRUD through Entra's lens: enterprise schemas, PascalCase ops."""

    @patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
    def test_create_user_includes_enterprise_schema(
        self,
        mock_seats: MagicMock,  # noqa: ARG002
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Created users should advertise both core and enterprise schema URNs."""
        mock_dal.get_user_by_email.return_value = None
        resource = make_scim_user(userName="alice@contoso.com")

        result = create_user(
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        resource = parse_scim_user(result, status=201)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
        assert SCIM_USER_SCHEMA in resource.schemas

    @patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
    def test_create_user_with_enterprise_extension(
        self,
        mock_seats: MagicMock,  # noqa: ARG002
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Enterprise extension department/manager should round-trip on create."""
        mock_dal.get_user_by_email.return_value = None
        resource = make_scim_user(
            userName="alice@contoso.com",
            enterprise_extension=ScimEnterpriseExtension(
                department="Engineering",
                manager=ScimManagerRef(value="mgr-uuid-123"),
            ),
        )

        result = create_user(
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        # Response should echo the extension data back.
        resource = parse_scim_user(result, status=201)
        assert resource.enterprise_extension is not None
        assert resource.enterprise_extension.department == "Engineering"
        assert resource.enterprise_extension.manager is not None
        assert resource.enterprise_extension.manager.value == "mgr-uuid-123"

        # Verify DAL received the enterprise fields
        mock_dal.create_user_mapping.assert_called_once()
        call_kwargs = mock_dal.create_user_mapping.call_args[1]
        assert call_kwargs["fields"] == ScimMappingFields(
            department="Engineering",
            manager="mgr-uuid-123",
            given_name="Test",
            family_name="User",
        )

    def test_get_user_includes_enterprise_schema(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET user responses should carry the enterprise schema URN."""
        user = make_db_user(email="alice@contoso.com")
        mock_dal.get_user.return_value = user

        result = get_user(
            user_id=str(user.id),
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        resource = parse_scim_user(result)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas

    def test_get_user_returns_enterprise_extension_data(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET should return stored enterprise extension data."""
        user = make_db_user(email="alice@contoso.com")
        mock_dal.get_user.return_value = user
        # Stored SCIM mapping carries the enterprise fields to be surfaced.
        mapping = make_user_mapping(user_id=user.id)
        mapping.department = "Sales"
        mapping.manager = "mgr-456"
        mock_dal.get_user_mapping_by_user_id.return_value = mapping

        result = get_user(
            user_id=str(user.id),
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        resource = parse_scim_user(result)
        assert resource.enterprise_extension is not None
        assert resource.enterprise_extension.department == "Sales"
        assert resource.enterprise_extension.manager is not None
        assert resource.enterprise_extension.manager.value == "mgr-456"

    def test_list_users_includes_enterprise_schema(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Each user in a list response should carry the enterprise schema URN."""
        user = make_db_user(email="alice@contoso.com")
        mapping = make_user_mapping(external_id="entra-ext-1", user_id=user.id)
        mock_dal.list_users.return_value = ([(user, mapping)], 1)

        result = list_users(
            filter=None,
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parsed = parse_scim_list(result)
        resource = parsed.Resources[0]
        assert isinstance(resource, ScimUserResource)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas

    def test_patch_user_deactivate_with_pascal_case_replace(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Replace"`` (PascalCase) instead of ``"replace"``."""
        user = make_db_user(is_active=True)
        mock_dal.get_user.return_value = user
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Replace",  # type: ignore[arg-type]
                    path="active",
                    value=False,
                )
            ]
        )

        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_user(result)
        # Mock doesn't propagate the change, so verify via the DAL call
        mock_dal.update_user.assert_called_once()
        call_kwargs = mock_dal.update_user.call_args
        assert call_kwargs[1]["is_active"] is False

    def test_patch_user_add_external_id_with_pascal_case(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Add"`` (PascalCase) instead of ``"add"``."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Add",  # type: ignore[arg-type]
                    path="externalId",
                    value="entra-ext-999",
                )
            ]
        )

        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_user(result)
        # Verify the patched externalId was synced to the DAL
        mock_dal.sync_user_external_id.assert_called_once()
        call_args = mock_dal.sync_user_external_id.call_args
        assert call_args[0][1] == "entra-ext-999"

    def test_patch_user_enterprise_extension_in_value_dict(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends enterprise extension URN as key in path-less PATCH value
        dicts — enterprise data should be stored, not ignored."""
        user = make_db_user()
        mock_dal.get_user.return_value = user

        # Build a path-less PATCH value whose extra key is the enterprise URN,
        # mirroring Entra's wire format (stored via pydantic's extra fields).
        value = ScimPatchResourceValue(active=False)
        assert value.__pydantic_extra__ is not None
        value.__pydantic_extra__[
            "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
        ] = {"department": "Engineering"}

        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REPLACE,
                    path=None,
                    value=value,
                )
            ]
        )

        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_user(result)
        # Verify active=False was applied
        mock_dal.update_user.assert_called_once()
        call_kwargs = mock_dal.update_user.call_args
        assert call_kwargs[1]["is_active"] is False
        # Verify enterprise data was passed to DAL
        mock_dal.sync_user_external_id.assert_called_once()
        sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
        assert sync_kwargs["fields"] == ScimMappingFields(
            department="Engineering",
            given_name="Test",
            family_name="User",
            scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
        )

    def test_patch_user_remove_external_id(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PATCH remove op should clear the target field."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        mapping = make_user_mapping(user_id=user.id)
        mapping.external_id = "ext-to-remove"
        mock_dal.get_user_mapping_by_user_id.return_value = mapping

        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REMOVE,
                    path="externalId",
                )
            ]
        )

        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_user(result)
        # externalId should be cleared (None)
        mock_dal.sync_user_external_id.assert_called_once()
        call_args = mock_dal.sync_user_external_id.call_args
        assert call_args[0][1] is None

    def test_patch_user_emails_primary_eq_true_value(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PATCH with path emails[primary eq true].value should update
        the primary email entry, not userName."""
        user = make_db_user(email="old@contoso.com")
        mock_dal.get_user.return_value = user

        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REPLACE,
                    path="emails[primary eq true].value",
                    value="new@contoso.com",
                )
            ]
        )

        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        resource = parse_scim_user(result)
        # userName should remain unchanged — emails and userName are separate
        assert resource.userName == "old@contoso.com"
        # Primary email should be updated
        primary_emails = [e for e in resource.emails if e.primary]
        assert len(primary_emails) == 1
        assert primary_emails[0].value == "new@contoso.com"

    def test_patch_user_enterprise_urn_department_path(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PATCH with dotted enterprise URN path should store department."""
        user = make_db_user()
        mock_dal.get_user.return_value = user

        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op=ScimPatchOperationType.REPLACE,
                    path="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
                    value="Marketing",
                )
            ]
        )

        result = patch_user(
            user_id=str(user.id),
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_user(result)
        mock_dal.sync_user_external_id.assert_called_once()
        sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
        assert sync_kwargs["fields"] == ScimMappingFields(
            department="Marketing",
            given_name="Test",
            family_name="User",
            scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
        )

    def test_replace_user_includes_enterprise_schema(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PUT (replace) responses should carry the enterprise schema URN."""
        user = make_db_user(email="old@contoso.com")
        mock_dal.get_user.return_value = user
        resource = make_scim_user(
            userName="new@contoso.com",
            name=ScimName(givenName="New", familyName="Name"),
        )

        result = replace_user(
            user_id=str(user.id),
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        resource = parse_scim_user(result)
        assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas

    def test_replace_user_with_enterprise_extension(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """PUT with enterprise extension should store the fields."""
        user = make_db_user(email="alice@contoso.com")
        mock_dal.get_user.return_value = user
        resource = make_scim_user(
            userName="alice@contoso.com",
            enterprise_extension=ScimEnterpriseExtension(
                department="HR",
                manager=ScimManagerRef(value="boss-id"),
            ),
        )

        result = replace_user(
            user_id=str(user.id),
            user_resource=resource,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_user(result)
        mock_dal.sync_user_external_id.assert_called_once()
        sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
        assert sync_kwargs["fields"] == ScimMappingFields(
            department="HR",
            manager="boss-id",
            given_name="Test",
            family_name="User",
        )

    def test_delete_user_returns_204(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
    ) -> None:
        """DELETE of a SCIM-managed user returns a bare 204 No Content."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        mock_dal.get_user_mapping_by_user_id.return_value = MagicMock(id=1)

        result = delete_user(
            user_id=str(user.id),
            _token=mock_token,
            db_session=mock_db_session,
        )

        assert isinstance(result, Response)
        assert result.status_code == 204

    def test_double_delete_returns_404(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
    ) -> None:
        """Second DELETE should return 404 — the SCIM mapping is gone."""
        user = make_db_user()
        mock_dal.get_user.return_value = user
        # No mapping — user was already deleted from SCIM's perspective
        mock_dal.get_user_mapping_by_user_id.return_value = None

        result = delete_user(
            user_id=str(user.id),
            _token=mock_token,
            db_session=mock_db_session,
        )

        assert isinstance(result, ScimJSONResponse)
        assert result.status_code == 404

    def test_name_formatted_preserved_on_create(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """When name.formatted is provided, it should be used as personal_name."""
        mock_dal.get_user_by_email.return_value = None
        resource = make_scim_user(
            userName="alice@contoso.com",
            name=ScimName(
                givenName="Alice",
                familyName="Smith",
                formatted="Dr. Alice Smith",
            ),
        )

        with patch(
            "ee.onyx.server.scim.api._check_seat_availability", return_value=None
        ):
            result = create_user(
                user_resource=resource,
                _token=mock_token,
                provider=entra_provider,
                db_session=mock_db_session,
            )

        parse_scim_user(result, status=201)
        # The User constructor should have received the formatted name
        mock_dal.add_user.assert_called_once()
        created_user = mock_dal.add_user.call_args[0][0]
        assert created_user.personal_name == "Dr. Alice Smith"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraGroupLifecycle:
    """Test group CRUD with Entra-specific behaviors."""

    def test_get_group_standard_response(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """GET group returns displayName and the member list."""
        group = make_db_group(id=10, name="Contoso Engineering")
        mock_dal.get_group.return_value = group
        uid = uuid4()
        mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]

        result = get_group(
            group_id="10",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        resource = parse_scim_group(result)
        assert resource.displayName == "Contoso Engineering"
        assert len(resource.members) == 1

    def test_list_groups_with_excluded_attributes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ?excludedAttributes=members on group list queries."""
        group = make_db_group(id=10, name="Engineering")
        uid = uuid4()
        mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
        mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]

        result = list_groups(
            filter=None,
            excludedAttributes="members",
            startIndex=1,
            count=100,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert parsed["totalResults"] == 1
        resource = parsed["Resources"][0]
        # "members" must be omitted entirely, not serialized as null/empty.
        assert "members" not in resource
        assert resource["displayName"] == "Engineering"

    def test_get_group_with_excluded_attributes_members(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ?excludedAttributes=members on single group GET."""
        group = make_db_group(id=10, name="Engineering")
        uid = uuid4()
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]

        result = get_group(
            group_id="10",
            excludedAttributes="members",
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        assert isinstance(result, ScimJSONResponse)
        parsed = json.loads(result.body)
        assert "members" not in parsed
        assert parsed["displayName"] == "Engineering"

    @patch("ee.onyx.server.scim.api.apply_group_patch")
    def test_patch_group_add_members_with_pascal_case(
        self,
        mock_apply: MagicMock,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Add"`` (PascalCase) for group member additions."""
        group = make_db_group(id=10)
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = []
        mock_dal.validate_member_ids.return_value = []

        # apply_group_patch is stubbed to report one member to add.
        uid = str(uuid4())
        patched = ScimGroupResource(
            id="10",
            displayName="Engineering",
            members=[ScimGroupMember(value=uid)],
        )
        mock_apply.return_value = (patched, [uid], [])

        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Add",  # type: ignore[arg-type]
                    path="members",
                    value=[ScimGroupMember(value=uid)],
                )
            ]
        )

        result = patch_group(
            group_id="10",
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_group(result)
        mock_dal.upsert_group_members.assert_called_once()

    @patch("ee.onyx.server.scim.api.apply_group_patch")
    def test_patch_group_remove_member_with_pascal_case(
        self,
        mock_apply: MagicMock,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        entra_provider: ScimProvider,
    ) -> None:
        """Entra sends ``"Remove"`` (PascalCase) for group member removals."""
        group = make_db_group(id=10)
        mock_dal.get_group.return_value = group
        mock_dal.get_group_members.return_value = []

        # apply_group_patch is stubbed to report one member to remove.
        uid = str(uuid4())
        patched = ScimGroupResource(id="10", displayName="Engineering", members=[])
        mock_apply.return_value = (patched, [], [uid])

        patch_req = ScimPatchRequest(
            Operations=[
                ScimPatchOperation(
                    op="Remove",  # type: ignore[arg-type]
                    path=f'members[value eq "{uid}"]',
                )
            ]
        )

        result = patch_group(
            group_id="10",
            patch_request=patch_req,
            _token=mock_token,
            provider=entra_provider,
            db_session=mock_db_session,
        )

        parse_scim_group(result)
        mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# excludedAttributes (RFC 7644 §3.4.2.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExcludedAttributes:
|
||||
"""Test excludedAttributes query parameter on GET endpoints."""
|
||||
|
||||
def test_list_groups_excludes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.list_groups.return_value = ([(group, None)], 1)
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
excludedAttributes="members",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
resource = parsed["Resources"][0]
|
||||
assert "members" not in resource
|
||||
assert "displayName" in resource
|
||||
|
||||
def test_get_group_excludes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
excludedAttributes="members",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert "displayName" in parsed
|
||||
|
||||
def test_list_users_excludes_groups(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
mock_dal.get_users_groups_batch.return_value = {user.id: [(1, "Engineering")]}
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
excludedAttributes="groups",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
resource = parsed["Resources"][0]
|
||||
assert "groups" not in resource
|
||||
assert "userName" in resource
|
||||
|
||||
def test_get_user_excludes_groups(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_dal.get_user_groups.return_value = [(1, "Engineering")]
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
excludedAttributes="groups",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "groups" not in parsed
|
||||
assert "userName" in parsed
|
||||
|
||||
def test_multiple_excluded_attributes(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
excludedAttributes="members,externalId",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert "externalId" not in parsed
|
||||
assert "displayName" in parsed
|
||||
|
||||
def test_no_excluded_attributes_returns_full_response(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert len(resource.members) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entra Connection Probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraConnectionProbe:
|
||||
"""Entra sends a probe request during initial SCIM setup."""
|
||||
|
||||
def test_filter_for_nonexistent_user_returns_empty_list(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra probes with: GET /Users?filter=userName eq "non-existent"&count=1"""
|
||||
mock_dal.list_users.return_value = ([], 0)
|
||||
|
||||
result = list_users(
|
||||
filter='userName eq "non-existent@contoso.com"',
|
||||
startIndex=1,
|
||||
count=1,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
@@ -16,7 +16,6 @@ from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import replace_group
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
@@ -25,6 +24,8 @@ from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
|
||||
|
||||
class TestListGroups:
|
||||
@@ -48,9 +49,9 @@ class TestListGroups:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
|
||||
def test_unsupported_filter_returns_400(
|
||||
self,
|
||||
@@ -95,9 +96,9 @@ class TestListGroups:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
resource = result.Resources[0]
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 1
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(resource, ScimGroupResource)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.externalId == "ext-g-1"
|
||||
@@ -126,9 +127,9 @@ class TestGetGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "Engineering"
|
||||
assert result.id == "5"
|
||||
resource = parse_scim_group(result)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.id == "5"
|
||||
|
||||
def test_non_integer_id_returns_404(
|
||||
self,
|
||||
@@ -190,8 +191,8 @@ class TestCreateGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "New Group"
|
||||
resource = parse_scim_group(result, status=201)
|
||||
assert resource.displayName == "New Group"
|
||||
mock_dal.add_group.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -283,7 +284,7 @@ class TestCreateGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
parse_scim_group(result, status=201)
|
||||
mock_dal.create_group_mapping.assert_called_once()
|
||||
|
||||
|
||||
@@ -314,7 +315,7 @@ class TestReplaceGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
parse_scim_group(result)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
mock_dal.replace_group_members.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
@@ -427,7 +428,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
parse_scim_group(result)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
|
||||
def test_not_found_returns_404(
|
||||
@@ -534,7 +535,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
parse_scim_group(result)
|
||||
mock_dal.validate_member_ids.assert_called_once()
|
||||
mock_dal.upsert_group_members.assert_called_once()
|
||||
|
||||
@@ -614,7 +615,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
parse_scim_group(result)
|
||||
mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
@@ -12,9 +13,11 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
|
||||
_ENTRA_IGNORED = EntraProvider().ignored_patch_paths
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
|
||||
@@ -56,36 +59,36 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_deactivate_user(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("active", False)], user)
|
||||
result, _ = apply_user_patch([_replace_op("active", False)], user)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_activate_user(self) -> None:
|
||||
user = _make_user(active=False)
|
||||
result = apply_user_patch([_replace_op("active", True)], user)
|
||||
result, _ = apply_user_patch([_replace_op("active", True)], user)
|
||||
assert result.active is True
|
||||
|
||||
def test_replace_given_name(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
result, _ = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.givenName == "NewFirst"
|
||||
assert result.name.familyName == "User"
|
||||
|
||||
def test_replace_family_name(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
result, _ = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.familyName == "NewLast"
|
||||
|
||||
def test_replace_username(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
result, _ = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
assert result.userName == "new@example.com"
|
||||
|
||||
def test_replace_without_path_uses_dict(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -99,7 +102,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_multiple_operations(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op("active", False),
|
||||
_replace_op("name.givenName", "Updated"),
|
||||
@@ -112,7 +115,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_case_insensitive_path(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_replace_op("Active", False)], user)
|
||||
result, _ = apply_user_patch([_replace_op("Active", False)], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_original_not_mutated(self) -> None:
|
||||
@@ -125,15 +128,22 @@ class TestApplyUserPatch:
|
||||
with pytest.raises(ScimPatchError, match="Unsupported path"):
|
||||
apply_user_patch([_replace_op("unknownField", "value")], user)
|
||||
|
||||
def test_remove_op_on_user_raises(self) -> None:
|
||||
def test_remove_op_clears_field(self) -> None:
|
||||
"""Remove op should clear the target field (not raise)."""
|
||||
user = _make_user(externalId="ext-123")
|
||||
result, _ = apply_user_patch([_remove_op("externalId")], user)
|
||||
assert result.externalId is None
|
||||
|
||||
def test_remove_unsupported_path_raises(self) -> None:
|
||||
"""Remove op on unsupported path (e.g. 'active') should raise."""
|
||||
user = _make_user()
|
||||
with pytest.raises(ScimPatchError, match="Unsupported operation"):
|
||||
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
|
||||
apply_user_patch([_remove_op("active")], user)
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
@@ -143,7 +153,7 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
"""The 'schemas' key in a value dict should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -161,7 +171,7 @@ class TestApplyUserPatch:
|
||||
def test_okta_deactivation_payload(self) -> None:
|
||||
"""Exact Okta deactivation payload: path-less replace with id + active."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -176,7 +186,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_replace_displayname(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op("displayName", "New Display Name")], user
|
||||
)
|
||||
assert result.displayName == "New Display Name"
|
||||
@@ -187,7 +197,7 @@ class TestApplyUserPatch:
|
||||
"""Okta sends id/schemas/meta alongside actual changes — complex types
|
||||
(lists, nested dicts) must not cause Pydantic validation errors."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
result, _ = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -207,9 +217,101 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_add_operation_works_like_replace(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
result, _ = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
assert result.externalId == "ext-456"
|
||||
|
||||
def test_entra_capitalized_replace_op(self) -> None:
|
||||
"""Entra ID sends ``"Replace"`` instead of ``"replace"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Replace", path="active", value=False) # type: ignore[arg-type]
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_entra_capitalized_add_op(self) -> None:
|
||||
"""Entra ID sends ``"Add"`` instead of ``"add"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Add", path="externalId", value="ext-999") # type: ignore[arg-type]
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.externalId == "ext-999"
|
||||
|
||||
def test_entra_enterprise_extension_handled(self) -> None:
|
||||
"""Entra sends the enterprise extension URN as a key in path-less
|
||||
PATCH value dicts — enterprise data should be captured in ent_data."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
# Simulate Entra including the enterprise extension URN as extra data
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_ENTRA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_okta_handles_enterprise_extension_urn(self) -> None:
|
||||
"""Enterprise extension URN paths are handled universally, even
|
||||
for Okta — the data is captured in the enterprise data dict."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_emails_primary_eq_true_value(self) -> None:
|
||||
"""emails[primary eq true].value should update the primary email entry."""
|
||||
user = _make_user(
|
||||
emails=[ScimEmail(value="old@example.com", type="work", primary=True)]
|
||||
)
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op("emails[primary eq true].value", "new@example.com")], user
|
||||
)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
assert result.userName == "test@example.com"
|
||||
assert len(result.emails) == 1
|
||||
assert result.emails[0].value == "new@example.com"
|
||||
assert result.emails[0].primary is True
|
||||
|
||||
def test_enterprise_urn_department_path(self) -> None:
|
||||
"""Dotted enterprise URN path should set department in ent_data."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
"Marketing",
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["department"] == "Marketing"
|
||||
|
||||
def test_enterprise_urn_manager_path(self) -> None:
|
||||
"""Dotted enterprise URN path for manager should set manager."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager",
|
||||
ScimPatchResourceValue.model_validate({"value": "boss-id"}),
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["manager"] == "boss-id"
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
|
||||
@@ -2,6 +2,8 @@ from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
@@ -9,7 +11,10 @@ from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.entra import _ENTRA_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
|
||||
@@ -39,9 +44,7 @@ class TestOktaProvider:
|
||||
assert OktaProvider().name == "okta"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
assert OktaProvider().ignored_patch_paths == frozenset(
|
||||
{"id", "schemas", "meta"}
|
||||
)
|
||||
assert OktaProvider().ignored_patch_paths == COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = OktaProvider()
|
||||
@@ -60,6 +63,12 @@ class TestOktaProvider:
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def test_build_user_resource_has_core_schema_only(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
assert result.schemas == [SCIM_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_with_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
@@ -161,6 +170,42 @@ class TestOktaProvider:
|
||||
assert result.members == []
|
||||
|
||||
|
||||
class TestEntraProvider:
|
||||
def test_name(self) -> None:
|
||||
assert EntraProvider().name == "entra"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
paths = EntraProvider().ignored_patch_paths
|
||||
assert paths == _ENTRA_IGNORED_PATCH_PATHS
|
||||
# Enterprise extension URN is now handled (not ignored)
|
||||
assert paths >= COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
def test_build_user_resource_includes_enterprise_schema(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result.schemas == [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result == ScimUserResource(
|
||||
schemas=[SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA],
|
||||
id=str(user.id),
|
||||
externalId="ext-entra-1",
|
||||
userName="test@example.com",
|
||||
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
|
||||
displayName="Test User",
|
||||
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
|
||||
active=True,
|
||||
groups=[],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultProvider:
|
||||
def test_returns_okta(self) -> None:
|
||||
provider = get_default_provider()
|
||||
|
||||
@@ -16,7 +16,7 @@ from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
@@ -28,6 +28,8 @@ from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
@@ -51,9 +53,9 @@ class TestListUsers:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
|
||||
def test_returns_users_with_scim_shape(
|
||||
self,
|
||||
@@ -77,10 +79,10 @@ class TestListUsers:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
assert len(result.Resources) == 1
|
||||
resource = result.Resources[0]
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 1
|
||||
assert len(parsed.Resources) == 1
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert resource.userName == "Alice@example.com"
|
||||
assert resource.externalId == "ext-abc"
|
||||
@@ -146,9 +148,9 @@ class TestGetUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "alice@example.com"
|
||||
assert result.id == str(user.id)
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "alice@example.com"
|
||||
assert resource.id == str(user.id)
|
||||
|
||||
def test_invalid_uuid_returns_404(
|
||||
self,
|
||||
@@ -207,8 +209,8 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "new@example.com"
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.userName == "new@example.com"
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -314,8 +316,8 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.externalId == "ext-123"
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.externalId == "ext-123"
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
|
||||
|
||||
@@ -344,7 +346,7 @@ class TestReplaceUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
parse_scim_user(result)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -412,9 +414,15 @@ class TestReplaceUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
parse_scim_user(result)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(
|
||||
user.id, None, scim_username="test@example.com"
|
||||
user.id,
|
||||
None,
|
||||
scim_username="test@example.com",
|
||||
fields=ScimMappingFields(
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -448,7 +456,7 @@ class TestPatchUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
parse_scim_user(result)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
@@ -507,7 +515,7 @@ class TestPatchUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
parse_scim_user(result)
|
||||
# Verify the update_user call received the new display name
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["personal_name"] == "New Display Name"
|
||||
@@ -605,10 +613,12 @@ class TestDeleteUser:
|
||||
class TestScimNameToStr:
|
||||
"""Tests for _scim_name_to_str helper."""
|
||||
|
||||
def test_prefers_given_family_over_formatted(self) -> None:
|
||||
"""Okta may send stale formatted while updating givenName/familyName."""
|
||||
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
|
||||
assert _scim_name_to_str(name) == "Jane Smith"
|
||||
def test_prefers_formatted_over_components(self) -> None:
|
||||
"""When client provides formatted, use it — the client knows what it wants."""
|
||||
name = ScimName(
|
||||
givenName="Jane", familyName="Smith", formatted="Dr. Jane Smith"
|
||||
)
|
||||
assert _scim_name_to_str(name) == "Dr. Jane Smith"
|
||||
|
||||
def test_given_name_only(self) -> None:
|
||||
name = ScimName(givenName="Jane")
|
||||
@@ -653,9 +663,9 @@ class TestEmailCasePreservation:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
def test_get_preserves_username_case(
|
||||
self,
|
||||
@@ -681,6 +691,6 @@ class TestEmailCasePreservation:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user