mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 04:05:48 +00:00
Compare commits
2 Commits
main
...
fix-memory
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7394577001 | ||
|
|
974dab79d7 |
73
.github/actions/build-backend-image/action.yml
vendored
73
.github/actions/build-backend-image/action.yml
vendored
@@ -1,73 +0,0 @@
|
||||
name: "Build Backend Image"
|
||||
description: "Builds and pushes the backend Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
docker-no-cache:
|
||||
description: "Set to 'true' to disable docker build cache"
|
||||
required: false
|
||||
default: "false"
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Backend Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
|
||||
type=registry,ref=onyxdotapp/onyx-backend:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
|
||||
no-cache: ${{ inputs.docker-no-cache == 'true' }}
|
||||
@@ -1,75 +0,0 @@
|
||||
name: "Build Integration Image"
|
||||
description: "Builds and pushes the integration test image with docker bake"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build and push integration test image with Docker Bake
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
run: |
|
||||
docker buildx bake --push \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
|
||||
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
|
||||
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
|
||||
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
|
||||
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
|
||||
integration
|
||||
@@ -1,68 +0,0 @@
|
||||
name: "Build Model Server Image"
|
||||
description: "Builds and pushes the model server Docker image with cache reuse"
|
||||
inputs:
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
ref-name:
|
||||
description: "Git ref name used for cache suffix fallback"
|
||||
required: true
|
||||
pr-number:
|
||||
description: "Optional PR number for cache suffix"
|
||||
required: false
|
||||
default: ""
|
||||
github-sha:
|
||||
description: "Commit SHA used for cache keys"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in output image tag"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
shell: bash
|
||||
env:
|
||||
PR_NUMBER: ${{ inputs.pr-number }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Build and push Model Server Docker image
|
||||
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
push: true
|
||||
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
|
||||
type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
cache-to: |
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max
|
||||
@@ -1,120 +0,0 @@
|
||||
name: "Run Nightly Provider Chat Test"
|
||||
description: "Starts required compose services and runs nightly provider integration test"
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
|
||||
required: true
|
||||
models:
|
||||
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
provider-api-key:
|
||||
description: "API key for NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
strict:
|
||||
description: "String true/false for NIGHTLY_LLM_STRICT"
|
||||
required: true
|
||||
api-base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE"
|
||||
required: false
|
||||
default: ""
|
||||
custom-config-json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
|
||||
required: false
|
||||
default: ""
|
||||
runs-on-ecr-cache:
|
||||
description: "ECR cache registry from runs-on/action"
|
||||
required: true
|
||||
run-id:
|
||||
description: "GitHub run ID used in image tags"
|
||||
required: true
|
||||
docker-username:
|
||||
description: "Docker Hub username"
|
||||
required: true
|
||||
docker-token:
|
||||
description: "Docker Hub token"
|
||||
required: true
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ inputs.docker-username }}
|
||||
password: ${{ inputs.docker-token }}
|
||||
|
||||
- name: Create .env file for Docker Compose
|
||||
shell: bash
|
||||
env:
|
||||
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
run: |
|
||||
cat <<EOF2 > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
REQUIRE_EMAIL_VERIFICATION=false
|
||||
DISABLE_TELEMETRY=true
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
|
||||
EOF2
|
||||
|
||||
- name: Start Docker containers
|
||||
shell: bash
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
|
||||
relational_db \
|
||||
index \
|
||||
cache \
|
||||
minio \
|
||||
api_server \
|
||||
inference_model_server
|
||||
|
||||
- name: Run nightly provider integration test
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
env:
|
||||
MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
RUN_ID: ${{ inputs.run-id }}
|
||||
with:
|
||||
timeout_minutes: 20
|
||||
max_attempts: 2
|
||||
retry_wait_seconds: 10
|
||||
command: |
|
||||
if [ -z "${MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
docker run --rm --network onyx_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e DB_READONLY_USER=db_readonly_user \
|
||||
-e DB_READONLY_PASSWORD=password \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
|
||||
-e NIGHTLY_LLM_MODELS="${MODELS}" \
|
||||
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
|
||||
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
|
||||
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
|
||||
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
|
||||
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
|
||||
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py
|
||||
@@ -1,44 +0,0 @@
|
||||
name: Nightly LLM Provider Chat Tests (OpenAI)
|
||||
concurrency:
|
||||
group: Nightly-LLM-Provider-Chat-OpenAI-${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
|
||||
- cron: "30 10 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
openai-provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
with:
|
||||
provider: openai
|
||||
models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
provider_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [openai-provider-chat-test]
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: openai-provider-chat-test
|
||||
title: "🚨 Scheduled OpenAI Provider Chat Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -20,7 +20,6 @@ env:
|
||||
# Test Environment Variables
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
@@ -424,7 +423,6 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
@@ -445,7 +443,6 @@ jobs:
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
|
||||
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
|
||||
/app/tests/integration/${{ matrix.test-dir.path }}
|
||||
|
||||
@@ -704,7 +701,6 @@ jobs:
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e EXA_API_KEY=${EXA_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e AUTH_TYPE=cloud \
|
||||
-e MULTI_TENANT=true \
|
||||
|
||||
@@ -1,206 +0,0 @@
|
||||
name: Reusable Nightly LLM Provider Chat Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
provider:
|
||||
description: "Provider slug passed to NIGHTLY_LLM_PROVIDER (e.g. openai, anthropic)"
|
||||
required: true
|
||||
type: string
|
||||
models:
|
||||
description: "Comma-separated model list passed to NIGHTLY_LLM_MODELS"
|
||||
required: true
|
||||
type: string
|
||||
strict:
|
||||
description: "Pass-through value for NIGHTLY_LLM_STRICT"
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
api_base:
|
||||
description: "Optional NIGHTLY_LLM_API_BASE override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
custom_config_json:
|
||||
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON override"
|
||||
required: false
|
||||
default: ""
|
||||
type: string
|
||||
secrets:
|
||||
provider_api_key:
|
||||
description: "Provider API key passed to NIGHTLY_LLM_API_KEY"
|
||||
required: true
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
|
||||
NIGHTLY_LLM_MODELS: ${{ inputs.models }}
|
||||
NIGHTLY_LLM_API_KEY: ${{ secrets.provider_api_key }}
|
||||
NIGHTLY_LLM_API_BASE: ${{ inputs.api_base }}
|
||||
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom_config_json }}
|
||||
NIGHTLY_LLM_STRICT: ${{ inputs.strict && 'true' || 'false' }}
|
||||
|
||||
jobs:
|
||||
validate-inputs:
|
||||
# NOTE: Keep this cheap and fail before image builds if required inputs are missing.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- name: Validate required nightly provider inputs
|
||||
run: |
|
||||
if [ -z "${NIGHTLY_LLM_MODELS}" ]; then
|
||||
echo "Input 'models' must be non-empty for provider '${NIGHTLY_LLM_PROVIDER}'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
build-backend-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-backend-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=1cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-model-server-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
needs: [validate-inputs]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-build-integration-image",
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
ref-name: ${{ github.ref_name }}
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
[build-backend-image, build-model-server-image, build-integration-image]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ inputs.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ env.NIGHTLY_LLM_PROVIDER }}
|
||||
models: ${{ env.NIGHTLY_LLM_MODELS }}
|
||||
provider-api-key: ${{ secrets.provider_api_key }}
|
||||
strict: ${{ env.NIGHTLY_LLM_STRICT }}
|
||||
api-base: ${{ env.NIGHTLY_LLM_API_BASE }}
|
||||
custom-config-json: ${{ env.NIGHTLY_LLM_CUSTOM_CONFIG_JSON }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ inputs.provider }}-llm-provider
|
||||
path: |
|
||||
${{ github.workspace }}/api_server.log
|
||||
${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose down -v
|
||||
@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
|
||||
calling the utilities directly (e.g. do NOT create admin users with
|
||||
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
|
||||
|
||||
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
|
||||
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
|
||||
|
||||
To run them:
|
||||
|
||||
@@ -616,9 +616,3 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
|
||||
Keep it high level. You can reference certain files or functions though.
|
||||
|
||||
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
|
||||
|
||||
## Best Practices
|
||||
|
||||
In addition to the other content in this file, best practices for contributing
|
||||
to the codebase can be found at `contributing_guides/best_practices.md`.
|
||||
Understand its contents and follow them.
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add needs_persona_sync to user_file
|
||||
|
||||
Revision ID: 8ffcc2bcfc11
|
||||
Revises: 7616121f6e97
|
||||
Create Date: 2026-02-23 10:48:48.343826
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8ffcc2bcfc11"
|
||||
down_revision = "7616121f6e97"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"needs_persona_sync",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user_file", "needs_persona_sync")
|
||||
@@ -34,7 +34,6 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
@@ -129,19 +128,12 @@ class ScimDAL(DAL):
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
f = fields or ScimMappingFields()
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
department=f.department,
|
||||
manager=f.manager,
|
||||
given_name=f.given_name,
|
||||
family_name=f.family_name,
|
||||
scim_emails_json=f.scim_emails_json,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
@@ -319,14 +311,8 @@ class ScimDAL(DAL):
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user.
|
||||
|
||||
When *fields* is provided, all mapping fields are written
|
||||
unconditionally — including ``None`` values — so that a caller can
|
||||
clear a previously-set field (e.g. removing a department).
|
||||
"""
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
@@ -334,18 +320,11 @@ class ScimDAL(DAL):
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
if fields is not None:
|
||||
mapping.department = fields.department
|
||||
mapping.manager = fields.manager
|
||||
mapping.given_name = fields.given_name
|
||||
mapping.family_name = fields.family_name
|
||||
mapping.scim_emails_json = fields.scim_emails_json
|
||||
else:
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
@@ -597,12 +598,8 @@ def get_external_access_from_sharepoint(
|
||||
)
|
||||
elif site_page:
|
||||
site_url = site_page.get("webUrl")
|
||||
# Keep percent-encoding intact so the path matches the encoding
|
||||
# used by the Office365 library's SPResPath.create_relative(),
|
||||
# which compares against urlparse(context.base_url).path.
|
||||
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
|
||||
# the site prefix in the constructed URL.
|
||||
server_relative_url = urlparse(site_url).path
|
||||
# Prefer server-relative URL to avoid OData filters that break on apostrophes
|
||||
server_relative_url = unquote(urlparse(site_url).path)
|
||||
file_obj = client_context.web.get_file_by_server_relative_url(
|
||||
server_relative_url
|
||||
)
|
||||
|
||||
@@ -26,14 +26,14 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
@@ -41,8 +41,6 @@ from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.base import serialize_emails
|
||||
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -50,28 +48,15 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
|
||||
media_type = "application/scim+json"
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
# changed to kebab-case.
|
||||
|
||||
|
||||
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
|
||||
_pw_helper = PasswordHelper()
|
||||
@@ -101,39 +86,15 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
|
||||
|
||||
|
||||
@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> ScimJSONResponse:
|
||||
"""List available SCIM resource types (RFC 7643 §6).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(resources),
|
||||
"Resources": [
|
||||
r.model_dump(exclude_none=True, by_alias=True) for r in resources
|
||||
],
|
||||
}
|
||||
)
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> ScimJSONResponse:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(schemas),
|
||||
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
|
||||
}
|
||||
)
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -141,45 +102,15 @@ def get_schemas() -> ScimJSONResponse:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
|
||||
def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
|
||||
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
|
||||
body = ScimError(status=str(status), detail=detail)
|
||||
return ScimJSONResponse(
|
||||
return JSONResponse(
|
||||
status_code=status,
|
||||
content=body.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
def _parse_excluded_attributes(raw: str | None) -> set[str]:
|
||||
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
|
||||
|
||||
Returns a set of lowercased attribute names to omit from responses.
|
||||
"""
|
||||
if not raw:
|
||||
return set()
|
||||
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
|
||||
|
||||
|
||||
def _apply_exclusions(
|
||||
resource: ScimUserResource | ScimGroupResource,
|
||||
excluded: set[str],
|
||||
) -> dict:
|
||||
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
|
||||
|
||||
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
|
||||
to reduce response payload size. We strip those fields after serialization
|
||||
so the rest of the pipeline doesn't need to know about them.
|
||||
"""
|
||||
data = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
for attr in excluded:
|
||||
# Match case-insensitively against the camelCase field names
|
||||
keys_to_remove = [k for k in data if k.lower() == attr]
|
||||
for k in keys_to_remove:
|
||||
del data[k]
|
||||
return data
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -193,7 +124,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
|
||||
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
|
||||
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
|
||||
try:
|
||||
uid = UUID(user_id)
|
||||
@@ -213,95 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
# If the client explicitly provides ``formatted``, prefer it — the client
|
||||
# knows what display string it wants. Otherwise build from components.
|
||||
if name.formatted:
|
||||
return name.formatted
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or None
|
||||
|
||||
|
||||
def _scim_resource_response(
|
||||
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
|
||||
status_code: int = 200,
|
||||
) -> ScimJSONResponse:
|
||||
"""Serialize a SCIM resource as ``application/scim+json``."""
|
||||
content = resource.model_dump(exclude_none=True, by_alias=True)
|
||||
return ScimJSONResponse(
|
||||
status_code=status_code,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_list_response(
|
||||
resources: list[ScimUserResource | ScimGroupResource],
|
||||
total: int,
|
||||
start_index: int,
|
||||
count: int,
|
||||
excluded: set[str] | None = None,
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
"""Build a SCIM list response, optionally applying attribute exclusions.
|
||||
|
||||
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
|
||||
the ``excludedAttributes`` query parameter.
|
||||
"""
|
||||
if excluded:
|
||||
envelope = ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
)
|
||||
data = envelope.model_dump(exclude_none=True)
|
||||
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
|
||||
return ScimJSONResponse(content=data)
|
||||
|
||||
return _scim_resource_response(
|
||||
ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=start_index,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _extract_enterprise_fields(
|
||||
resource: ScimUserResource,
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Extract department and manager from enterprise extension."""
|
||||
ext = resource.enterprise_extension
|
||||
if not ext:
|
||||
return None, None
|
||||
department = ext.department
|
||||
manager = ext.manager.value if ext.manager else None
|
||||
return department, manager
|
||||
|
||||
|
||||
def _mapping_to_fields(
|
||||
mapping: ScimUserMapping | None,
|
||||
) -> ScimMappingFields | None:
|
||||
"""Extract round-trip fields from a SCIM user mapping."""
|
||||
if not mapping:
|
||||
return None
|
||||
return ScimMappingFields(
|
||||
department=mapping.department,
|
||||
manager=mapping.manager,
|
||||
given_name=mapping.given_name,
|
||||
family_name=mapping.family_name,
|
||||
scim_emails_json=mapping.scim_emails_json,
|
||||
)
|
||||
|
||||
|
||||
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
"""Build mapping fields from an incoming SCIM user resource."""
|
||||
department, manager = _extract_enterprise_fields(resource)
|
||||
return ScimMappingFields(
|
||||
department=department,
|
||||
manager=manager,
|
||||
given_name=resource.name.givenName if resource.name else None,
|
||||
family_name=resource.name.familyName if resource.name else None,
|
||||
scim_emails_json=serialize_emails(resource.emails),
|
||||
)
|
||||
return parts or name.formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -312,13 +158,12 @@ def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
|
||||
@scim_router.get("/Users", response_model=None)
|
||||
def list_users(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -340,54 +185,42 @@ def list_users(
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
|
||||
resource = provider.build_user_resource(
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
@@ -395,7 +228,7 @@ def create_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -437,25 +270,13 @@ def create_user(
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -465,13 +286,13 @@ def replace_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -492,24 +313,15 @@ def replace_user(
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
new_external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
|
||||
|
||||
@@ -520,7 +332,7 @@ def patch_user(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
@@ -530,25 +342,23 @@ def patch_user(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
current_fields = _mapping_to_fields(mapping)
|
||||
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
fields=current_fields,
|
||||
)
|
||||
|
||||
try:
|
||||
patched, ent_data = apply_user_patch(
|
||||
patched = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
@@ -583,37 +393,17 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
department=ent_data.get("department", cf.department),
|
||||
manager=ent_data.get("manager", cf.manager),
|
||||
given_name=patched.name.givenName if patched.name else cf.given_name,
|
||||
family_name=patched.name.familyName if patched.name else cf.family_name,
|
||||
scim_emails_json=(
|
||||
serialize_emails(patched.emails)
|
||||
if patched.emails is not None
|
||||
else cf.scim_emails_json
|
||||
),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
patched.externalId,
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
|
||||
|
||||
@@ -622,29 +412,25 @@ def delete_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | ScimJSONResponse:
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a user (RFC 7644 §3.6).
|
||||
|
||||
Deactivates the user and removes the SCIM mapping. Note that Okta
|
||||
typically uses PATCH active=false instead of DELETE.
|
||||
A second DELETE returns 404 per RFC 7644 §3.6.
|
||||
"""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
# If no SCIM mapping exists, the user was already deleted from
|
||||
# SCIM's perspective — return 404 per RFC 7644 §3.6.
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if not mapping:
|
||||
return _scim_error_response(404, f"User {user_id} not found")
|
||||
|
||||
dal.deactivate_user(user)
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
if mapping:
|
||||
dal.delete_user_mapping(mapping.id)
|
||||
|
||||
dal.commit()
|
||||
|
||||
@@ -656,7 +442,7 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
gid = int(group_id)
|
||||
@@ -711,13 +497,12 @@ def _validate_and_parse_members(
|
||||
@scim_router.get("/Groups", response_model=None)
|
||||
def list_groups(
|
||||
filter: str | None = Query(None),
|
||||
excludedAttributes: str | None = None,
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | ScimJSONResponse:
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -737,46 +522,37 @@ def list_groups(
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
return _build_list_response(
|
||||
resources,
|
||||
total,
|
||||
startIndex,
|
||||
count,
|
||||
excluded=_parse_excluded_attributes(excludedAttributes),
|
||||
return ScimListResponse(
|
||||
totalResults=total,
|
||||
startIndex=startIndex,
|
||||
itemsPerPage=count,
|
||||
Resources=resources,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Groups/{group_id}", response_model=None)
|
||||
def get_group(
|
||||
group_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
resource = provider.build_group_resource(
|
||||
return provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
@@ -784,7 +560,7 @@ def create_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
@@ -820,10 +596,7 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(db_group, members, external_id),
|
||||
status_code=201,
|
||||
)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -833,13 +606,13 @@ def replace_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -854,9 +627,7 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, group_resource.externalId)
|
||||
)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -866,7 +637,7 @@ def patch_group(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | ScimJSONResponse:
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
|
||||
Handles member add/remove operations from Okta and Azure AD.
|
||||
@@ -875,7 +646,7 @@ def patch_group(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
@@ -914,9 +685,7 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _scim_resource_response(
|
||||
provider.build_group_resource(group, members, patched.externalId)
|
||||
)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
@@ -924,13 +693,13 @@ def delete_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response | ScimJSONResponse:
|
||||
) -> Response | JSONResponse:
|
||||
"""Delete a group (RFC 7644 §3.6)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_group_or_404(group_id, dal)
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
if isinstance(result, JSONResponse):
|
||||
return result
|
||||
group = result
|
||||
|
||||
|
||||
@@ -7,14 +7,12 @@ SCIM protocol schemas follow the wire format defined in:
|
||||
Admin API schemas are internal to Onyx and used for SCIM token management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -33,9 +31,6 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -75,36 +70,6 @@ class ScimUserGroupRef(BaseModel):
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimManagerRef(BaseModel):
|
||||
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
|
||||
|
||||
value: str | None = None
|
||||
|
||||
|
||||
class ScimEnterpriseExtension(BaseModel):
|
||||
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
|
||||
|
||||
department: str | None = None
|
||||
manager: ScimManagerRef | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScimMappingFields:
|
||||
"""Stored SCIM mapping fields that need to round-trip through the IdP.
|
||||
|
||||
Entra ID sends structured name components, email metadata, and enterprise
|
||||
extension attributes that must be returned verbatim in subsequent GET
|
||||
responses. These fields are persisted on ScimUserMapping and threaded
|
||||
through the DAL, provider, and endpoint layers.
|
||||
"""
|
||||
|
||||
department: str | None = None
|
||||
manager: str | None = None
|
||||
given_name: str | None = None
|
||||
family_name: str | None = None
|
||||
scim_emails_json: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -113,8 +78,6 @@ class ScimUserResource(BaseModel):
|
||||
to match the SCIM wire format (not Python convention).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
|
||||
id: str | None = None # Onyx's internal user ID, set on responses
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
@@ -125,10 +88,6 @@ class ScimUserResource(BaseModel):
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
enterprise_extension: ScimEnterpriseExtension | None = Field(
|
||||
default=None,
|
||||
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
|
||||
)
|
||||
|
||||
|
||||
class ScimGroupMember(BaseModel):
|
||||
@@ -206,19 +165,6 @@ class ScimPatchOperation(BaseModel):
|
||||
path: str | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
@field_validator("op", mode="before")
|
||||
@classmethod
|
||||
def normalize_operation(cls, v: object) -> object:
|
||||
"""Normalize op to lowercase for case-insensitive matching.
|
||||
|
||||
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
|
||||
instead of ``"replace"``. This is safe for all providers since the
|
||||
enum values are lowercase. If a future provider requires other
|
||||
pre-processing quirks, move patch deserialization into the provider
|
||||
subclass instead of adding more special cases here.
|
||||
"""
|
||||
return v.lower() if isinstance(v, str) else v
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
"""PATCH request body (RFC 7644 §3.5.2).
|
||||
|
||||
@@ -14,13 +14,8 @@ responsible for persisting changes.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
@@ -29,55 +24,6 @@ from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lowercased enterprise extension URN for case-insensitive matching
|
||||
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
|
||||
|
||||
# Pattern for email filter paths, e.g.:
|
||||
# emails[primary eq true].value (Okta)
|
||||
# emails[type eq "work"].value (Azure AD / Entra ID)
|
||||
_EMAIL_FILTER_RE = re.compile(
|
||||
r"^emails\[.+\]\.value$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispatch tables for user PATCH paths
|
||||
#
|
||||
# Maps lowercased SCIM path → (camelCase key, target dict name).
|
||||
# "data" writes to the top-level resource dict, "name" writes to the
|
||||
# name sub-object dict. This replaces the elif chains for simple fields.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"active": ("active", "data"),
|
||||
"username": ("userName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
}
|
||||
|
||||
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"externalid": ("externalId", "data"),
|
||||
"name.givenname": ("givenName", "name"),
|
||||
"name.familyname": ("familyName", "name"),
|
||||
"name.formatted": ("formatted", "name"),
|
||||
"displayname": ("displayName", "data"),
|
||||
}
|
||||
|
||||
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
|
||||
"displayname": ("displayName", "data"),
|
||||
"externalid": ("externalId", "data"),
|
||||
}
|
||||
|
||||
|
||||
class ScimPatchError(Exception):
|
||||
"""Raised when a PATCH operation cannot be applied."""
|
||||
@@ -88,25 +34,18 @@ class ScimPatchError(Exception):
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _UserPatchCtx:
|
||||
"""Bundles the mutable state for user PATCH operations."""
|
||||
|
||||
data: dict[str, Any]
|
||||
name_data: dict[str, Any]
|
||||
ent_data: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern for member removal path: members[value eq "user-id"]
|
||||
_MEMBER_FILTER_RE = re.compile(
|
||||
r'^members\[value\s+eq\s+"([^"]+)"\]$',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> tuple[ScimUserResource, dict[str, str | None]]:
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
@@ -114,185 +53,79 @@ def apply_user_patch(
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns:
|
||||
A tuple of (modified user resource, enterprise extension data dict).
|
||||
The enterprise dict has keys ``"department"`` and ``"manager"``
|
||||
with values set only when a PATCH operation touched them.
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
Raises:
|
||||
ScimPatchError: If an operation targets an unsupported path.
|
||||
"""
|
||||
data = current.model_dump()
|
||||
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
|
||||
name_data = data.get("name") or {}
|
||||
|
||||
for op in operations:
|
||||
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
|
||||
_apply_user_replace(op, ctx, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
_apply_user_remove(op, ctx, ignored_paths)
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
)
|
||||
|
||||
ctx.data["name"] = ctx.name_data
|
||||
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
|
||||
data["name"] = name_data
|
||||
return ScimUserResource.model_validate(data)
|
||||
|
||||
|
||||
def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a resource dict of top-level attributes to set.
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, ctx, ignored_paths)
|
||||
|
||||
|
||||
def _apply_user_remove(
|
||||
op: ScimPatchOperation,
|
||||
ctx: _UserPatchCtx,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a remove operation to user data — clears the target field."""
|
||||
path = (op.path or "").lower()
|
||||
if not path:
|
||||
raise ScimPatchError("Remove operation requires a path")
|
||||
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _USER_REMOVE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = None
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ctx: _UserPatchCtx,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
*,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path.
|
||||
|
||||
Args:
|
||||
strict: When ``False`` (path-less replace), unknown attributes are
|
||||
silently skipped. When ``True`` (explicit path), they raise.
|
||||
"""
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
# Simple field writes handled by the dispatch table
|
||||
entry = _USER_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, target = entry
|
||||
target_dict = ctx.data if target == "data" else ctx.name_data
|
||||
target_dict[key] = value
|
||||
return
|
||||
|
||||
# displayName sets both the top-level field and the name.formatted sub-field
|
||||
if path == "displayname":
|
||||
ctx.data["displayName"] = value
|
||||
ctx.name_data["formatted"] = value
|
||||
elif path == "name":
|
||||
if isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
ctx.name_data[k] = v
|
||||
elif path == "emails":
|
||||
if isinstance(value, list):
|
||||
ctx.data["emails"] = value
|
||||
elif _EMAIL_FILTER_RE.match(path):
|
||||
_update_primary_email(ctx.data, value)
|
||||
elif path.startswith(_ENTERPRISE_URN_LOWER):
|
||||
_set_enterprise_field(path, value, ctx.ent_data)
|
||||
elif not strict:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
elif path == "name.givenname":
|
||||
name_data["givenName"] = value
|
||||
elif path == "name.familyname":
|
||||
name_data["familyName"] = value
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
|
||||
|
||||
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
|
||||
"""Update the primary email entry via an email filter path."""
|
||||
emails: list[dict] = data.get("emails") or []
|
||||
for email_entry in emails:
|
||||
if email_entry.get("primary"):
|
||||
email_entry["value"] = value
|
||||
break
|
||||
else:
|
||||
emails.append({"value": value, "type": "work", "primary": True})
|
||||
data["emails"] = emails
|
||||
|
||||
|
||||
def _to_dict(value: ScimPatchValue) -> dict | None:
|
||||
"""Coerce a SCIM patch value to a plain dict if possible.
|
||||
|
||||
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
|
||||
``extra="allow"``), so we also dump those back to a dict.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, ScimPatchResourceValue):
|
||||
return value.model_dump(exclude_unset=True)
|
||||
return None
|
||||
|
||||
|
||||
def _set_enterprise_field(
|
||||
path: str,
|
||||
value: ScimPatchValue,
|
||||
ent_data: dict[str, str | None],
|
||||
) -> None:
|
||||
"""Handle enterprise extension URN paths or value dicts."""
|
||||
# Full URN as key with dict value (path-less PATCH)
|
||||
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
|
||||
if path == _ENTERPRISE_URN_LOWER:
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
if "department" in d:
|
||||
ent_data["department"] = d["department"]
|
||||
if "manager" in d:
|
||||
mgr = d["manager"]
|
||||
if isinstance(mgr, dict):
|
||||
ent_data["manager"] = mgr.get("value")
|
||||
return
|
||||
|
||||
# Dotted URN path, e.g. "urn:...:user:department"
|
||||
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
|
||||
if suffix == "department":
|
||||
ent_data["department"] = str(value) if value is not None else None
|
||||
elif suffix == "manager":
|
||||
d = _to_dict(value)
|
||||
if d is not None:
|
||||
ent_data["manager"] = d.get("value")
|
||||
elif isinstance(value, str):
|
||||
ent_data["manager"] = value
|
||||
else:
|
||||
# Unknown enterprise attributes are silently ignored rather than
|
||||
# rejected — IdPs may send attributes we don't model yet.
|
||||
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group PATCH
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
@@ -402,14 +235,12 @@ def _set_group_field(
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path in ignored_paths:
|
||||
return
|
||||
|
||||
entry = _GROUP_REPLACE_PATHS.get(path)
|
||||
if entry:
|
||||
key, _ = entry
|
||||
data[key] = value
|
||||
return
|
||||
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
|
||||
|
||||
|
||||
def _apply_group_add(
|
||||
|
||||
@@ -2,22 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
@@ -26,17 +17,6 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"id",
|
||||
"schemas",
|
||||
"meta",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
@@ -61,22 +41,12 @@ class ScimProvider(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
"""Schema URIs to include in User resource responses.
|
||||
|
||||
Override in subclasses to advertise additional schemas (e.g. the
|
||||
enterprise extension for Entra ID).
|
||||
"""
|
||||
return [SCIM_USER_SCHEMA]
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
fields: ScimMappingFields | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
@@ -88,48 +58,27 @@ class ScimProvider(ABC):
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
fields: Stored mapping fields that the IdP expects round-tripped.
|
||||
"""
|
||||
f = fields or ScimMappingFields()
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
# Build enterprise extension when at least one value is present.
|
||||
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
|
||||
enterprise_ext: ScimEnterpriseExtension | None = None
|
||||
schemas = list(self.user_schemas)
|
||||
if f.department is not None or f.manager is not None:
|
||||
manager_ref = (
|
||||
ScimManagerRef(value=f.manager) if f.manager is not None else None
|
||||
)
|
||||
enterprise_ext = ScimEnterpriseExtension(
|
||||
department=f.department,
|
||||
manager=manager_ref,
|
||||
)
|
||||
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
|
||||
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
|
||||
name = self.build_scim_name(user, f)
|
||||
emails = _deserialize_emails(f.scim_emails_json, username)
|
||||
|
||||
resource = ScimUserResource(
|
||||
schemas=schemas,
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=name,
|
||||
name=self._build_scim_name(user),
|
||||
displayName=user.personal_name,
|
||||
emails=emails,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
resource.enterprise_extension = enterprise_ext
|
||||
return resource
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
@@ -149,24 +98,9 @@ class ScimProvider(ABC):
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
def build_scim_name(
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName | None:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
@@ -177,27 +111,6 @@ class ScimProvider(ABC):
|
||||
)
|
||||
|
||||
|
||||
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
|
||||
"""Deserialize stored email entries or build a default work email."""
|
||||
if stored_json:
|
||||
try:
|
||||
entries = json.loads(stored_json)
|
||||
if isinstance(entries, list) and entries:
|
||||
return [ScimEmail(**e) for e in entries]
|
||||
except (json.JSONDecodeError, TypeError, ValidationError):
|
||||
logger.warning(
|
||||
"Corrupt scim_emails_json, falling back to default: %s", stored_json
|
||||
)
|
||||
return [ScimEmail(value=username, type="work", primary=True)]
|
||||
|
||||
|
||||
def serialize_emails(emails: list[ScimEmail]) -> str | None:
|
||||
"""Serialize SCIM email entries to JSON for storage."""
|
||||
if not emails:
|
||||
return None
|
||||
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Entra ID (Azure AD) SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
|
||||
class EntraProvider(ScimProvider):
|
||||
"""Entra ID (Azure AD) SCIM provider.
|
||||
|
||||
Entra behavioral notes:
|
||||
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
|
||||
— handled by ``ScimPatchOperation.normalize_op`` validator.
|
||||
- Sends the enterprise extension URN as a key in path-less PATCH value
|
||||
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
|
||||
store department/manager values.
|
||||
- Expects the enterprise extension schema in ``schemas`` arrays and
|
||||
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "entra"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return _ENTRA_IGNORED_PATCH_PATHS
|
||||
|
||||
@property
|
||||
def user_schemas(self) -> list[str]:
|
||||
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
@@ -23,4 +22,4 @@ class OktaProvider(ScimProvider):
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return COMMON_IGNORED_PATCH_PATHS
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
|
||||
@@ -4,7 +4,6 @@ Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -21,9 +20,6 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
"schemaExtensions": [
|
||||
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -108,31 +104,6 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
],
|
||||
)
|
||||
|
||||
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_ENTERPRISE_USER_SCHEMA,
|
||||
name="EnterpriseUser",
|
||||
description="Enterprise User extension (RFC 7643 §4.3)",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="department",
|
||||
type="string",
|
||||
description="Department.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="manager",
|
||||
type="complex",
|
||||
description="The user's manager.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Manager user ID.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
|
||||
@@ -59,7 +58,6 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
|
||||
METADATA_SUFFIX,
|
||||
DOCUMENT_SETS,
|
||||
USER_PROJECT,
|
||||
PERSONAS,
|
||||
PRIMARY_OWNERS,
|
||||
SECONDARY_OWNERS,
|
||||
ACCESS_CONTROL_LIST,
|
||||
@@ -278,7 +276,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
)
|
||||
)
|
||||
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
|
||||
personas: list[int] | None = vespa_chunk.get(PERSONAS)
|
||||
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
|
||||
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
|
||||
|
||||
@@ -328,7 +325,6 @@ def transform_vespa_chunks_to_opensearch_chunks(
|
||||
metadata_suffix=metadata_suffix,
|
||||
document_sets=document_sets,
|
||||
user_projects=user_projects,
|
||||
personas=personas,
|
||||
primary_owners=primary_owners,
|
||||
secondary_owners=secondary_owners,
|
||||
tenant_id=tenant_state,
|
||||
|
||||
@@ -12,7 +12,6 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from retry import retry
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
@@ -713,10 +712,7 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
db_session.execute(
|
||||
select(UserFile.id).where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
@@ -776,11 +772,7 @@ def process_single_user_file_project_sync(
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.execute(
|
||||
select(UserFile)
|
||||
.where(UserFile.id == _as_uuid(user_file_id))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
).scalar_one_or_none()
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
@@ -808,17 +800,13 @@ def process_single_user_file_project_sync(
|
||||
]
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=project_ids,
|
||||
personas=persona_ids,
|
||||
),
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
@@ -826,7 +814,6 @@ def process_single_user_file_project_sync(
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
user_file.needs_persona_sync = False
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
|
||||
@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
)
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
|
||||
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER") or ""
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
|
||||
|
||||
@@ -16,22 +16,6 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
|
||||
|
||||
|
||||
def _is_rate_limit_error(error: HttpError) -> bool:
|
||||
"""Google sometimes returns rate-limit errors as 403 with reason
|
||||
'userRateLimitExceeded' instead of 429. This helper detects both."""
|
||||
if error.resp.status == 429:
|
||||
return True
|
||||
if error.resp.status != 403:
|
||||
return False
|
||||
error_details = getattr(error, "error_details", None) or []
|
||||
for detail in error_details:
|
||||
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
|
||||
return True
|
||||
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
|
||||
|
||||
|
||||
# Google Drive APIs are quite flakey and may 500 for an
|
||||
# extended period of time. This is now addressed by checkpointing.
|
||||
@@ -73,7 +57,7 @@ def _execute_with_retry(request: Any) -> Any:
|
||||
except HttpError as error:
|
||||
attempt += 1
|
||||
|
||||
if _is_rate_limit_error(error):
|
||||
if error.resp.status == 429:
|
||||
# Attempt to get 'Retry-After' from headers
|
||||
retry_after = error.resp.get("Retry-After")
|
||||
if retry_after:
|
||||
@@ -156,16 +140,16 @@ def _execute_single_retrieval(
|
||||
)
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
results = {}
|
||||
else:
|
||||
raise e
|
||||
elif e.resp.status == 429:
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
else:
|
||||
logger.exception("Error executing request:")
|
||||
raise e
|
||||
|
||||
@@ -147,9 +147,7 @@ class DriveItemData(BaseModel):
|
||||
self.id,
|
||||
ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))),
|
||||
)
|
||||
item = DriveItem(graph_client, path)
|
||||
item.set_property("id", self.id)
|
||||
return item
|
||||
return DriveItem(graph_client, path)
|
||||
|
||||
|
||||
# The office365 library's ClientContext caches the access token from its
|
||||
|
||||
@@ -11,7 +11,6 @@ from dateutil import parser
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -259,21 +258,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Very basic validation, we could do more here
|
||||
"""
|
||||
if not self.base_url.startswith("https://") and not self.base_url.startswith(
|
||||
"http://"
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
"Base URL must start with https:// or http://"
|
||||
)
|
||||
|
||||
try:
|
||||
get_all_post_ids(self.slab_bot_token)
|
||||
except ConnectorMissingCredentialError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")
|
||||
|
||||
@@ -72,7 +72,6 @@ class BaseFilters(BaseModel):
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -40,7 +40,6 @@ def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
@@ -119,7 +118,6 @@ def _build_index_filters(
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -267,8 +265,6 @@ def search_pipeline(
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
@@ -303,7 +299,6 @@ def search_pipeline(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
|
||||
@@ -4270,9 +4270,6 @@ class UserFile(Base):
|
||||
needs_project_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
needs_persona_sync: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
@@ -765,9 +765,6 @@ def mark_persona_as_deleted(
|
||||
) -> None:
|
||||
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
|
||||
persona.deleted = True
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -779,13 +776,11 @@ def mark_persona_as_not_deleted(
|
||||
persona = get_persona_by_id(
|
||||
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
|
||||
)
|
||||
if not persona.deleted:
|
||||
if persona.deleted:
|
||||
persona.deleted = False
|
||||
db_session.commit()
|
||||
else:
|
||||
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
|
||||
persona.deleted = False
|
||||
affected_file_ids = [uf.id for uf in persona.user_files]
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, affected_file_ids)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_delete_persona_by_name(
|
||||
@@ -851,20 +846,6 @@ def update_personas_display_priority(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _mark_files_need_persona_sync(
|
||||
db_session: Session,
|
||||
user_file_ids: list[UUID],
|
||||
) -> None:
|
||||
"""Flag the given UserFile rows so the background sync task picks them up
|
||||
and updates their persona metadata in the vector DB."""
|
||||
if not user_file_ids:
|
||||
return
|
||||
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
|
||||
{UserFile.needs_persona_sync: True},
|
||||
synchronize_session=False,
|
||||
)
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -1053,13 +1034,8 @@ def upsert_persona(
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
if user_file_ids is not None:
|
||||
old_file_ids = {uf.id for uf in existing_persona.user_files}
|
||||
new_file_ids = {uf.id for uf in (user_files or [])}
|
||||
affected_file_ids = old_file_ids | new_file_ids
|
||||
existing_persona.user_files.clear()
|
||||
existing_persona.user_files = user_files or []
|
||||
if affected_file_ids:
|
||||
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
|
||||
|
||||
if hierarchy_node_ids is not None:
|
||||
existing_persona.hierarchy_nodes.clear()
|
||||
@@ -1113,8 +1089,6 @@ def upsert_persona(
|
||||
attached_documents=attached_documents or [],
|
||||
)
|
||||
db_session.add(new_persona)
|
||||
if user_files:
|
||||
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
|
||||
persona = new_persona
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
@@ -2,7 +2,6 @@ import random
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from logging import getLogger
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import create_chat_session
|
||||
@@ -14,26 +13,18 @@ from onyx.db.models import ChatSession
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def seed_chat_history(
|
||||
num_sessions: int,
|
||||
num_messages: int,
|
||||
days: int,
|
||||
user_id: UUID | None = None,
|
||||
persona_id: int | None = None,
|
||||
) -> None:
|
||||
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||
"""Utility function to seed chat history for testing.
|
||||
|
||||
num_sessions: the number of sessions to seed
|
||||
num_messages: the number of messages to seed per sessions
|
||||
days: the number of days looking backwards from the current time over which to randomize
|
||||
the times.
|
||||
user_id: optional user to associate with sessions
|
||||
persona_id: optional persona/assistant to associate with sessions
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.info(f"Seeding {num_sessions} sessions.")
|
||||
for y in range(0, num_sessions):
|
||||
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
|
||||
create_chat_session(db_session, f"pytest_session_{y}", None, None)
|
||||
|
||||
# randomize all session times
|
||||
logger.info(f"Seeding {num_messages} messages per session.")
|
||||
|
||||
@@ -3,7 +3,6 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import UserFile
|
||||
@@ -65,23 +64,6 @@ def fetch_user_project_ids_for_user_files(
|
||||
}
|
||||
|
||||
|
||||
def fetch_persona_ids_for_user_files(
|
||||
user_file_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> dict[str, list[int]]:
|
||||
"""Fetch persona (assistant) ids for specified user files."""
|
||||
stmt = (
|
||||
select(UserFile)
|
||||
.where(UserFile.id.in_(user_file_ids))
|
||||
.options(selectinload(UserFile.assistants))
|
||||
)
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return {
|
||||
str(user_file.id): [persona.id for persona in user_file.assistants]
|
||||
for user_file in results
|
||||
}
|
||||
|
||||
|
||||
def update_last_accessed_at_for_user_files(
|
||||
user_file_ids: list[UUID],
|
||||
db_session: Session,
|
||||
|
||||
@@ -121,7 +121,6 @@ class VespaDocumentUserFields:
|
||||
"""
|
||||
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -148,7 +148,6 @@ class MetadataUpdateRequest(BaseModel):
|
||||
hidden: bool | None = None
|
||||
secondary_index_updated: bool | None = None
|
||||
project_ids: set[int] | None = None
|
||||
persona_ids: set[int] | None = None
|
||||
|
||||
|
||||
class IndexRetrievalFilters(BaseModel):
|
||||
|
||||
@@ -50,7 +50,6 @@ from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.opensearch.search import (
|
||||
@@ -216,7 +215,6 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
# OpenSearch and it will not store any data at all for this field, which
|
||||
# is different from supplying an empty list.
|
||||
user_projects=chunk.user_project or None,
|
||||
personas=chunk.personas or None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
@@ -364,11 +362,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
if user_fields and user_fields.user_projects
|
||||
else None
|
||||
),
|
||||
persona_ids=(
|
||||
set(user_fields.personas)
|
||||
if user_fields and user_fields.personas
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -716,10 +709,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
|
||||
update_request.project_ids
|
||||
)
|
||||
if update_request.persona_ids is not None:
|
||||
properties_to_update[PERSONAS_FIELD_NAME] = list(
|
||||
update_request.persona_ids
|
||||
)
|
||||
|
||||
if not properties_to_update:
|
||||
if len(update_request.document_ids) > 1:
|
||||
|
||||
@@ -41,7 +41,6 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
USER_PROJECTS_FIELD_NAME = "user_projects"
|
||||
PERSONAS_FIELD_NAME = "personas"
|
||||
DOCUMENT_ID_FIELD_NAME = "document_id"
|
||||
CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
@@ -157,7 +156,6 @@ class DocumentChunk(BaseModel):
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
user_projects: list[int] | None = None
|
||||
personas: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
@@ -487,7 +485,6 @@ class DocumentSchema:
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
|
||||
PERSONAS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
@@ -145,7 +144,6 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -204,7 +202,6 @@ class DocumentQuery:
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -270,7 +267,6 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -338,7 +334,6 @@ class DocumentQuery:
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -501,7 +496,6 @@ class DocumentQuery:
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -536,8 +530,6 @@ class DocumentQuery:
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -635,9 +627,6 @@ class DocumentQuery:
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
|
||||
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
@@ -791,9 +780,6 @@ class DocumentQuery:
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if persona_id is not None:
|
||||
filter_clauses.append(_get_persona_filter(persona_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
|
||||
@@ -181,11 +181,6 @@ schema {{ schema_name }} {
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
field personas type array<int> {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
attribute: fast-search
|
||||
}
|
||||
}
|
||||
|
||||
# If using different tokenization settings, the fieldset has to be removed, and the field must
|
||||
|
||||
@@ -689,9 +689,6 @@ class VespaIndex(DocumentIndex):
|
||||
project_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.user_projects is not None:
|
||||
project_ids = set(user_fields.user_projects)
|
||||
persona_ids: set[int] | None = None
|
||||
if user_fields is not None and user_fields.personas is not None:
|
||||
persona_ids = set(user_fields.personas)
|
||||
update_request = MetadataUpdateRequest(
|
||||
document_ids=[doc_id],
|
||||
doc_id_to_chunk_cnt={
|
||||
@@ -702,7 +699,6 @@ class VespaIndex(DocumentIndex):
|
||||
boost=fields.boost if fields is not None else None,
|
||||
hidden=fields.hidden if fields is not None else None,
|
||||
project_ids=project_ids,
|
||||
persona_ids=persona_ids,
|
||||
)
|
||||
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
@@ -46,7 +46,6 @@ from onyx.document_index.vespa_constants import METADATA
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import METADATA_SUFFIX
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
|
||||
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
|
||||
@@ -219,7 +218,6 @@ def _index_vespa_chunk(
|
||||
# still called `image_file_name` in Vespa for backwards compatibility
|
||||
IMAGE_FILE_NAME: chunk.image_file_id,
|
||||
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
|
||||
PERSONAS: chunk.personas if chunk.personas is not None else [],
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import PERSONAS
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
@@ -150,18 +149,6 @@ def build_vespa_filters(
|
||||
# Vespa YQL 'contains' expects a string literal; quote the integer
|
||||
return f'({USER_PROJECT} contains "{pid}") and '
|
||||
|
||||
def _build_persona_filter(
|
||||
persona_id: int | None,
|
||||
) -> str:
|
||||
if persona_id is None:
|
||||
return ""
|
||||
try:
|
||||
pid = int(persona_id)
|
||||
except Exception:
|
||||
logger.warning(f"Invalid persona ID: {persona_id}")
|
||||
return ""
|
||||
return f'({PERSONAS} contains "{pid}") and '
|
||||
|
||||
# Start building the filter string
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
@@ -205,9 +192,6 @@ def build_vespa_filters(
|
||||
# User project filter (array<int> attribute membership)
|
||||
filter_str += _build_user_project_filter(filters.project_id)
|
||||
|
||||
# Persona filter (array<int> attribute membership)
|
||||
filter_str += _build_persona_filter(filters.persona_id)
|
||||
|
||||
# Time filter
|
||||
filter_str += _build_time_filter(filters.time_cutoff)
|
||||
|
||||
|
||||
@@ -183,10 +183,6 @@ def _update_single_chunk(
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _Personas(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
class _VespaPutFields(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
# The names of these fields are based the Vespa schema. Changes to the
|
||||
@@ -197,7 +193,6 @@ def _update_single_chunk(
|
||||
access_control_list: _AccessControl | None = None
|
||||
hidden: _Hidden | None = None
|
||||
user_project: _UserProjects | None = None
|
||||
personas: _Personas | None = None
|
||||
|
||||
class _VespaPutRequest(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
@@ -232,11 +227,6 @@ def _update_single_chunk(
|
||||
if update_request.project_ids is not None
|
||||
else None
|
||||
)
|
||||
personas_update: _Personas | None = (
|
||||
_Personas(assign=list(update_request.persona_ids))
|
||||
if update_request.persona_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
vespa_put_fields = _VespaPutFields(
|
||||
boost=boost_update,
|
||||
@@ -244,7 +234,6 @@ def _update_single_chunk(
|
||||
access_control_list=access_update,
|
||||
hidden=hidden_update,
|
||||
user_project=user_projects_update,
|
||||
personas=personas_update,
|
||||
)
|
||||
|
||||
vespa_put_request = _VespaPutRequest(
|
||||
|
||||
@@ -58,7 +58,6 @@ DOCUMENT_SETS = "document_sets"
|
||||
USER_FILE = "user_file"
|
||||
USER_FOLDER = "user_folder"
|
||||
USER_PROJECT = "user_project"
|
||||
PERSONAS = "personas"
|
||||
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
|
||||
@@ -146,7 +146,6 @@ class DocumentIndexingBatchAdapter:
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
|
||||
@@ -20,7 +20,6 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
@@ -120,10 +119,6 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -187,7 +182,7 @@ class UserFileIndexingAdapter:
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
# we are going to index userfiles only once, so we just set the boost to the default
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
|
||||
@@ -112,7 +112,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
user_project: list[int]
|
||||
personas: list[int]
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
# Full ancestor path from root hierarchy node to document's parent.
|
||||
@@ -127,7 +126,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
user_project: list[int],
|
||||
personas: list[int],
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
@@ -139,7 +137,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
user_project=user_project,
|
||||
personas=personas,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -592,8 +592,11 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -605,6 +608,7 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,59 +1,10 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
_CITATION_LINK_PATTERN = re.compile(r"\[\[\d+\]\]\(")
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
@@ -78,21 +29,15 @@ def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int |
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
def _normalize_citation_link_destinations(message: str) -> str:
|
||||
"""Wrap citation URLs in angle brackets so markdown parsers handle parentheses safely."""
|
||||
if "[[" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
while match := _CITATION_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
@@ -112,38 +57,18 @@ def _normalize_link_destinations(message: str) -> str:
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
normalized_message = _normalize_citation_link_destinations(message)
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
return result
|
||||
|
||||
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -152,7 +77,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
return text
|
||||
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
return f"*{text}*\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
@@ -171,7 +96,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines) + "\n"
|
||||
return "\n".join(lines)
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
@@ -193,30 +118,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
return f"```\n{code}\n```\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n\n"
|
||||
return f"{text}\n"
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate AGENTS.md by scanning the files directory and populating the template.
|
||||
|
||||
This script runs during session setup, AFTER files have been synced from S3
|
||||
and the files symlink has been created. It reads an existing AGENTS.md (which
|
||||
contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the
|
||||
placeholder by scanning the knowledge source directory, and writes it back.
|
||||
This script runs at container startup, AFTER the init container has synced files
|
||||
from S3. It scans the /workspace/files directory to discover what knowledge sources
|
||||
are available and generates appropriate documentation.
|
||||
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
|
||||
Arguments:
|
||||
agents_md_path: Path to the AGENTS.md file to update in place
|
||||
files_path: Path to the files directory to scan for knowledge sources
|
||||
Environment variables:
|
||||
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -193,39 +189,49 @@ def build_knowledge_sources_section(files_path: Path) -> str:
|
||||
def main() -> None:
|
||||
"""Main entry point for container startup script.
|
||||
|
||||
Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}}
|
||||
placeholder by scanning the files directory, and writes it back.
|
||||
|
||||
Usage:
|
||||
python3 generate_agents_md.py <agents_md_path> <files_path>
|
||||
Is called by the container startup script to scan /workspace/files and populate
|
||||
the knowledge sources section.
|
||||
"""
|
||||
if len(sys.argv) != 3:
|
||||
print(
|
||||
f"Usage: {sys.argv[0]} <agents_md_path> <files_path>",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
# Read template from environment variable
|
||||
template = os.environ.get("AGENT_INSTRUCTIONS", "")
|
||||
if not template:
|
||||
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
|
||||
template = "# Agent Instructions\n\nNo instructions provided."
|
||||
|
||||
agents_md_path = Path(sys.argv[1])
|
||||
files_path = Path(sys.argv[2])
|
||||
# Scan files directory - check /workspace/files first, then /workspace/demo_data
|
||||
files_path = Path("/workspace/files")
|
||||
demo_data_path = Path("/workspace/demo_data")
|
||||
|
||||
if not agents_md_path.exists():
|
||||
print(f"Error: {agents_md_path} not found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
# Use demo_data if files doesn't exist or is empty
|
||||
if not files_path.exists() or not any(files_path.iterdir()):
|
||||
if demo_data_path.exists():
|
||||
files_path = demo_data_path
|
||||
|
||||
template = agents_md_path.read_text()
|
||||
knowledge_sources_section = build_knowledge_sources_section(files_path)
|
||||
|
||||
# Resolve symlinks (handles both direct symlinks and dirs containing symlinks)
|
||||
resolved_files_path = files_path.resolve()
|
||||
|
||||
knowledge_sources_section = build_knowledge_sources_section(resolved_files_path)
|
||||
|
||||
# Replace placeholder and write back
|
||||
content = template.replace(
|
||||
# Replace placeholders
|
||||
content = template
|
||||
content = content.replace(
|
||||
"{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section
|
||||
)
|
||||
agents_md_path.write_text(content)
|
||||
print(f"Populated knowledge sources in {agents_md_path}")
|
||||
|
||||
# Write AGENTS.md
|
||||
output_path = Path("/workspace/AGENTS.md")
|
||||
output_path.write_text(content)
|
||||
|
||||
# Log result
|
||||
source_count = 0
|
||||
if files_path.exists():
|
||||
source_count = len(
|
||||
[
|
||||
d
|
||||
for d in files_path.iterdir()
|
||||
if d.is_dir() and not d.name.startswith(".")
|
||||
]
|
||||
)
|
||||
print(
|
||||
f"Generated AGENTS.md with {source_count} knowledge sources from {files_path}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1352,9 +1352,6 @@ fi
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
@@ -1783,9 +1780,6 @@ ln -sf {symlink_target} {session_path}/files
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Populate knowledge sources by scanning the files directory
|
||||
python3 /usr/local/bin/generate_agents_md.py {session_path}/AGENTS.md {session_path}/files || true
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
|
||||
@@ -35,18 +35,6 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class EmailInviteStatus(str, Enum):
|
||||
SENT = "SENT"
|
||||
NOT_CONFIGURED = "NOT_CONFIGURED"
|
||||
SEND_FAILED = "SEND_FAILED"
|
||||
DISABLED = "DISABLED"
|
||||
|
||||
|
||||
class BulkInviteResponse(BaseModel):
|
||||
invited_count: int
|
||||
email_invite_status: EmailInviteStatus
|
||||
|
||||
|
||||
class VersionResponse(BaseModel):
|
||||
backend_version: str
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
@@ -79,10 +78,8 @@ from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.features.projects.models import UserFileSnapshot
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import BulkInviteResponse
|
||||
from onyx.server.manage.models import ChatBackgroundRequest
|
||||
from onyx.server.manage.models import DefaultAppModeRequest
|
||||
from onyx.server.manage.models import EmailInviteStatus
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import PersonalizationUpdateRequest
|
||||
from onyx.server.manage.models import TenantInfo
|
||||
@@ -371,7 +368,7 @@ def bulk_invite_users(
|
||||
emails: list[str] = Body(..., embed=True),
|
||||
current_user: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BulkInviteResponse:
|
||||
) -> int:
|
||||
"""emails are string validated. If any email fails validation, no emails are
|
||||
invited and an exception is raised."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -430,41 +427,34 @@ def bulk_invite_users(
|
||||
number_of_invited_users = write_invited_users(all_emails)
|
||||
|
||||
# send out email invitations only to new users (not already invited or existing)
|
||||
if not ENABLE_EMAIL_INVITES:
|
||||
email_invite_status = EmailInviteStatus.DISABLED
|
||||
elif not EMAIL_CONFIGURED:
|
||||
email_invite_status = EmailInviteStatus.NOT_CONFIGURED
|
||||
else:
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in emails_needing_seats:
|
||||
send_user_email_invite(email, current_user, AUTH_TYPE)
|
||||
email_invite_status = EmailInviteStatus.SENT
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
email_invite_status = EmailInviteStatus.SEND_FAILED
|
||||
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register tenant users: {str(e)}")
|
||||
logger.info(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
raise e
|
||||
if not MULTI_TENANT or DEV_MODE:
|
||||
return number_of_invited_users
|
||||
|
||||
return BulkInviteResponse(
|
||||
invited_count=number_of_invited_users,
|
||||
email_invite_status=email_invite_status,
|
||||
)
|
||||
# for billing purposes, write to the control plane about the number of new users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.billing", "register_tenant_users", None
|
||||
)(tenant_id, get_live_users_count(db_session))
|
||||
|
||||
return number_of_invited_users
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register tenant users: {str(e)}")
|
||||
logger.info(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
raise e
|
||||
|
||||
|
||||
@router.patch("/manage/admin/remove-invited-user", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -54,7 +54,6 @@ logger = setup_logger()
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -181,7 +180,6 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -429,7 +427,6 @@ def construct_tools(
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.configs.app_configs import CODE_INTERPRETER_BASE_URL
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_DEFAULT_TIMEOUT_MS
|
||||
from onyx.configs.app_configs import CODE_INTERPRETER_MAX_OUTPUT_LENGTH
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.file_store.utils import build_full_frontend_file_url
|
||||
from onyx.file_store.utils import get_default_file_store
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -104,10 +103,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
@override
|
||||
@classmethod
|
||||
def is_available(cls, db_session: Session) -> bool:
|
||||
if not CODE_INTERPRETER_BASE_URL:
|
||||
return False
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
return server.server_enabled
|
||||
is_available = bool(CODE_INTERPRETER_BASE_URL)
|
||||
return is_available
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
|
||||
@@ -247,8 +247,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -263,7 +261,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -459,7 +456,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
|
||||
@@ -317,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.6.2
|
||||
onyx-devtools==0.6.1
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -95,7 +95,6 @@ def generate_dummy_chunk(
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
user_project=[],
|
||||
personas=[],
|
||||
access=DocumentAccess.build(
|
||||
user_emails=user_emails,
|
||||
user_groups=user_groups,
|
||||
|
||||
@@ -144,8 +144,7 @@ def use_mock_search_pipeline(
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
persona_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel calls)
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
prefetched_federated_retrieval_infos: ( # noqa: ARG001
|
||||
|
||||
@@ -38,7 +38,6 @@ def _get_search_filters(
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
"""Tests that PythonTool.is_available() respects the server_enabled DB flag.
|
||||
|
||||
Uses a real DB session with CODE_INTERPRETER_BASE_URL mocked so the
|
||||
environment-variable check passes and the DB flag is the deciding factor.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.code_interpreter import fetch_code_interpreter_server
|
||||
from onyx.db.code_interpreter import update_code_interpreter_server_enabled
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""With a valid base URL, the tool should be unavailable when
|
||||
server_enabled is False in the DB."""
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
initial_enabled = server.server_enabled
|
||||
|
||||
try:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=False)
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://fake:8888",
|
||||
):
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
finally:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
|
||||
|
||||
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""With a valid base URL, the tool should be available when
|
||||
server_enabled is True in the DB."""
|
||||
server = fetch_code_interpreter_server(db_session)
|
||||
initial_enabled = server.server_enabled
|
||||
|
||||
try:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=True)
|
||||
|
||||
with patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://fake:8888",
|
||||
):
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
finally:
|
||||
update_code_interpreter_server_enabled(db_session, enabled=initial_enabled)
|
||||
@@ -38,5 +38,5 @@ COPY --from=openapi-client /local/onyx_openapi_client /app/generated/onyx_openap
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
ENTRYPOINT ["pytest", "-s", "-rs"]
|
||||
ENTRYPOINT ["pytest", "-s"]
|
||||
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from urllib.parse import urlencode
|
||||
from uuid import UUID
|
||||
@@ -9,10 +8,8 @@ from requests.models import CaseInsensitiveDict
|
||||
from ee.onyx.server.query_history.models import ChatSessionMinimal
|
||||
from ee.onyx.server.query_history.models import ChatSessionSnapshot
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import MAX_DELAY
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@@ -72,42 +69,9 @@ class QueryHistoryManager:
|
||||
if end_time:
|
||||
query_params["end"] = end_time.isoformat()
|
||||
|
||||
start_response = requests.post(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/start-export?{urlencode(query_params, doseq=True)}",
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history-csv?{urlencode(query_params, doseq=True)}",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
start_response.raise_for_status()
|
||||
request_id = start_response.json()["request_id"]
|
||||
|
||||
deadline = time.time() + MAX_DELAY
|
||||
while time.time() < deadline:
|
||||
status_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/export-status",
|
||||
params={"request_id": request_id},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
status_response.raise_for_status()
|
||||
status = status_response.json()["status"]
|
||||
if status == TaskStatus.SUCCESS:
|
||||
break
|
||||
if status == TaskStatus.FAILURE:
|
||||
raise RuntimeError("Query history export task failed")
|
||||
time.sleep(2)
|
||||
else:
|
||||
raise TimeoutError(
|
||||
f"Query history export not completed within {MAX_DELAY} seconds"
|
||||
)
|
||||
|
||||
download_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/admin/query-history/download",
|
||||
params={"request_id": request_id},
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
download_response.raise_for_status()
|
||||
|
||||
if not download_response.content:
|
||||
raise RuntimeError(
|
||||
"Query history CSV download returned zero-length content"
|
||||
)
|
||||
|
||||
return download_response.headers, download_response.content.decode()
|
||||
response.raise_for_status()
|
||||
return response.headers, response.content.decode()
|
||||
|
||||
@@ -6,26 +6,16 @@ import pytest
|
||||
from onyx.connectors.slack.models import ChannelType
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
SLACK_ADMIN_EMAIL = os.environ.get("SLACK_ADMIN_EMAIL", "evan@onyx.app")
|
||||
SLACK_TEST_USER_1_EMAIL = os.environ.get("SLACK_TEST_USER_1_EMAIL", "evan+1@onyx.app")
|
||||
SLACK_TEST_USER_2_EMAIL = os.environ.get("SLACK_TEST_USER_2_EMAIL", "justin@onyx.app")
|
||||
# from tests.load_env_vars import load_env_vars
|
||||
|
||||
# load_env_vars()
|
||||
|
||||
|
||||
def _provision_slack_channels(
|
||||
bot_token: str,
|
||||
) -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
|
||||
auth_info = slack_client.auth_test()
|
||||
print(f"\nSlack workspace: {auth_info.get('team')} ({auth_info.get('url')})")
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
if SLACK_ADMIN_EMAIL not in user_map:
|
||||
raise KeyError(
|
||||
f"'{SLACK_ADMIN_EMAIL}' not found in Slack workspace. "
|
||||
f"Available emails: {sorted(user_map.keys())}"
|
||||
)
|
||||
admin_user_id = user_map[SLACK_ADMIN_EMAIL]
|
||||
admin_user_id = user_map["admin@example.com"]
|
||||
|
||||
(
|
||||
public_channel,
|
||||
@@ -37,16 +27,5 @@ def _provision_slack_channels(
|
||||
|
||||
yield public_channel, private_channel
|
||||
|
||||
# This part will always run after the test, even if it fails
|
||||
SlackManager.cleanup_after_test(slack_client=slack_client, test_id=run_id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_test_setup() -> Generator[tuple[ChannelType, ChannelType], None, None]:
|
||||
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN"])
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def slack_perm_sync_test_setup() -> (
|
||||
Generator[tuple[ChannelType, ChannelType], None, None]
|
||||
):
|
||||
yield from _provision_slack_channels(os.environ["SLACK_BOT_TOKEN_TEST_SPACE"])
|
||||
|
||||
@@ -16,6 +16,7 @@ from uuid import uuid4
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.connectors.slack.connector import default_msg_filter
|
||||
from onyx.connectors.slack.connector import get_channel_messages
|
||||
from onyx.connectors.slack.models import ChannelType
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call
|
||||
@@ -112,6 +113,9 @@ def _delete_slack_conversation_messages(
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
for message_batch in get_channel_messages(slack_client, channel):
|
||||
for message in message_batch:
|
||||
if default_msg_filter(message):
|
||||
continue
|
||||
|
||||
if message_to_delete and message.get("text") != message_to_delete:
|
||||
continue
|
||||
print(" removing message: ", message.get("text"))
|
||||
|
||||
@@ -22,9 +22,6 @@ from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.slack.conftest import SLACK_ADMIN_EMAIL
|
||||
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_1_EMAIL
|
||||
from tests.integration.connector_job_tests.slack.conftest import SLACK_TEST_USER_2_EMAIL
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
@@ -37,24 +34,26 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackMan
|
||||
def test_slack_permission_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
|
||||
slack_test_setup: tuple[ChannelType, ChannelType],
|
||||
) -> None:
|
||||
public_channel, private_channel = slack_perm_sync_test_setup
|
||||
public_channel, private_channel = slack_test_setup
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email=SLACK_ADMIN_EMAIL,
|
||||
email="admin@example.com",
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(
|
||||
email=SLACK_TEST_USER_1_EMAIL,
|
||||
email="test_user_1@example.com",
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_2: DATestUser = UserManager.create(
|
||||
email=SLACK_TEST_USER_2_EMAIL,
|
||||
email="test_user_2@example.com",
|
||||
)
|
||||
|
||||
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = email_id_map[admin_user.email]
|
||||
|
||||
@@ -64,7 +63,7 @@ def test_slack_permission_sync(
|
||||
credential: DATestCredential = CredentialManager.create(
|
||||
source=DocumentSource.SLACK,
|
||||
credential_json={
|
||||
"slack_bot_token": bot_token,
|
||||
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
@@ -74,7 +73,6 @@ def test_slack_permission_sync(
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
"include_bot_messages": True,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
groups=[],
|
||||
@@ -104,11 +102,14 @@ def test_slack_permission_sync(
|
||||
public_message = "Steve's favorite number is 809752"
|
||||
private_message = "Sara's favorite number is 346794"
|
||||
|
||||
# Add messages to channels
|
||||
print(f"\n Adding public message to channel: {public_message}")
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=public_channel,
|
||||
message=public_message,
|
||||
)
|
||||
print(f"\n Adding private message to channel: {private_message}")
|
||||
SlackManager.add_message_to_channel(
|
||||
slack_client=slack_client,
|
||||
channel=private_channel,
|
||||
@@ -126,9 +127,7 @@ def test_slack_permission_sync(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync. Since initial_index_should_sync=True for Slack,
|
||||
# permissions were already set during indexing above — the explicit sync
|
||||
# should find no changes to apply.
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
@@ -136,38 +135,59 @@ def test_slack_permission_sync(
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=0,
|
||||
number_of_updated_docs=2,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
# Verify admin can see messages from both channels
|
||||
admin_docs = DocumentSearchManager.search_documents(
|
||||
# Search as admin with access to both channels
|
||||
print("\nSearching as admin user")
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert public_message in admin_docs
|
||||
assert private_message in admin_docs
|
||||
print(
|
||||
"\n documents retrieved by admin user: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
|
||||
# Verify test_user_2 can only see public channel messages
|
||||
user_2_docs = DocumentSearchManager.search_documents(
|
||||
# Ensure admin user can see messages from both channels
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message in onyx_doc_message_strings
|
||||
|
||||
# Search as test_user_2 with access to only the public channel
|
||||
print("\n Searching as test_user_2")
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_2,
|
||||
)
|
||||
assert public_message in user_2_docs
|
||||
assert private_message not in user_2_docs
|
||||
print(
|
||||
"\n documents retrieved by test_user_2: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
|
||||
# Verify test_user_1 can see both channels (member of private channel)
|
||||
user_1_docs = DocumentSearchManager.search_documents(
|
||||
# Ensure test_user_2 can only see messages from the public channel
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message not in onyx_doc_message_strings
|
||||
|
||||
# Search as test_user_1 with access to both channels
|
||||
print("\n Searching as test_user_1")
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_1,
|
||||
)
|
||||
assert public_message in user_1_docs
|
||||
assert private_message in user_1_docs
|
||||
print(
|
||||
"\n documents retrieved by test_user_1 before being removed from private channel: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
|
||||
# Remove test_user_1 from the private channel
|
||||
# Ensure test_user_1 can see messages from both channels
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message in onyx_doc_message_strings
|
||||
|
||||
# ----------------------MAKE THE CHANGES--------------------------
|
||||
print("\n Removing test_user_1 from the private channel")
|
||||
before = datetime.now(timezone.utc)
|
||||
# Remove test_user_1 from the private channel
|
||||
desired_channel_members = [admin_user]
|
||||
SlackManager.set_channel_members(
|
||||
slack_client=slack_client,
|
||||
@@ -186,16 +206,24 @@ def test_slack_permission_sync(
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
)
|
||||
|
||||
# Verify test_user_1 can no longer see private channel after removal
|
||||
user_1_docs = DocumentSearchManager.search_documents(
|
||||
# ----------------------------VERIFY THE CHANGES---------------------------
|
||||
# Ensure test_user_1 can no longer see messages from the private channel
|
||||
# Search as test_user_1 with access to only the public channel
|
||||
|
||||
onyx_doc_message_strings = DocumentSearchManager.search_documents(
|
||||
query="favorite number",
|
||||
user_performing_action=test_user_1,
|
||||
)
|
||||
assert public_message in user_1_docs
|
||||
assert private_message not in user_1_docs
|
||||
print(
|
||||
"\n documents retrieved by test_user_1 after being removed from private channel: ",
|
||||
onyx_doc_message_strings,
|
||||
)
|
||||
|
||||
# Ensure test_user_1 can only see messages from the public channel
|
||||
assert public_message in onyx_doc_message_strings
|
||||
assert private_message not in onyx_doc_message_strings
|
||||
|
||||
|
||||
# NOTE(rkuo): it isn't yet clear if the reason these were previously xfail'd
|
||||
@@ -207,19 +235,21 @@ def test_slack_permission_sync(
|
||||
def test_slack_group_permission_sync(
|
||||
reset: None, # noqa: ARG001
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
slack_perm_sync_test_setup: tuple[ChannelType, ChannelType],
|
||||
slack_test_setup: tuple[ChannelType, ChannelType],
|
||||
) -> None:
|
||||
"""
|
||||
This test ensures that permission sync overrides onyx group access.
|
||||
"""
|
||||
public_channel, private_channel = slack_perm_sync_test_setup
|
||||
public_channel, private_channel = slack_test_setup
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(
|
||||
email=SLACK_ADMIN_EMAIL,
|
||||
email="admin@example.com",
|
||||
)
|
||||
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(
|
||||
email=SLACK_TEST_USER_1_EMAIL,
|
||||
email="test_user_1@example.com",
|
||||
)
|
||||
|
||||
# Create a user group and adding the non-admin user to it
|
||||
@@ -234,8 +264,7 @@ def test_slack_group_permission_sync(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
bot_token = os.environ["SLACK_BOT_TOKEN_TEST_SPACE"]
|
||||
slack_client = SlackManager.get_slack_client(bot_token)
|
||||
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
|
||||
email_id_map = SlackManager.build_slack_user_email_id_map(slack_client)
|
||||
admin_user_id = email_id_map[admin_user.email]
|
||||
|
||||
@@ -253,7 +282,7 @@ def test_slack_group_permission_sync(
|
||||
credential = CredentialManager.create(
|
||||
source=DocumentSource.SLACK,
|
||||
credential_json={
|
||||
"slack_bot_token": bot_token,
|
||||
"slack_bot_token": os.environ["SLACK_BOT_TOKEN"],
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
@@ -265,7 +294,6 @@ def test_slack_group_permission_sync(
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"channels": [private_channel["name"]],
|
||||
"include_bot_messages": True,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
groups=[user_group.id],
|
||||
@@ -298,8 +326,7 @@ def test_slack_group_permission_sync(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync. Since initial_index_should_sync=True for Slack,
|
||||
# permissions were already set during indexing — no changes expected.
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
@@ -307,10 +334,8 @@ def test_slack_group_permission_sync(
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=0,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
# Verify admin can see the message
|
||||
|
||||
@@ -4,84 +4,75 @@ import time
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.settings import SettingsManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestSettings
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
RETENTION_SECONDS = 10
|
||||
|
||||
|
||||
def _run_ttl_cleanup(retention_days: int) -> None:
|
||||
"""Directly execute TTL cleanup logic, bypassing Celery task infrastructure."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
old_chat_sessions = get_chat_sessions_older_than(retention_days, db_session)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Chat retention tests are enterprise only",
|
||||
)
|
||||
def test_chat_retention(
|
||||
reset: None, admin_user: DATestUser, llm_provider: DATestLLMProvider # noqa: ARG001
|
||||
) -> None: # noqa: ARG001
|
||||
def test_chat_retention(reset: None, admin_user: DATestUser) -> None: # noqa: ARG001
|
||||
"""Test that chat sessions are deleted after the retention period expires."""
|
||||
|
||||
retention_days = RETENTION_SECONDS // 86400
|
||||
# Set chat retention period to 10 seconds
|
||||
retention_days = 10 / 86400 # 10 seconds in days (10 / 24 / 60 / 60)
|
||||
settings = DATestSettings(maximum_chat_retention_days=retention_days)
|
||||
SettingsManager.update_settings(settings, user_performing_action=admin_user)
|
||||
|
||||
# Create a chat session
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description="Test chat retention",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
# Send a message
|
||||
ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message="This message should be deleted soon",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert (
|
||||
response.error is None
|
||||
), f"Chat response should not have an error: {response.error}"
|
||||
|
||||
# Verify the chat session exists
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(chat_history) > 0, "Chat session should have messages"
|
||||
|
||||
# Wait for the retention period to elapse, then directly run TTL cleanup
|
||||
time.sleep(RETENTION_SECONDS + 2)
|
||||
_run_ttl_cleanup(retention_days)
|
||||
|
||||
# Verify the chat session was deleted
|
||||
# Wait for TTL task to run (give it ~60 seconds)
|
||||
print("Waiting for chat retention TTL task to run...")
|
||||
max_wait_time = 60 # maximum time to wait in seconds
|
||||
start_time = time.time()
|
||||
session_deleted = False
|
||||
try:
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
session_deleted = len(chat_history) == 0
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code in (404, 400):
|
||||
session_deleted = True
|
||||
else:
|
||||
raise
|
||||
|
||||
assert session_deleted, "Chat session was not deleted after retention period"
|
||||
while not session_deleted and (time.time() - start_time < max_wait_time):
|
||||
# Check if chat session is deleted
|
||||
try:
|
||||
# Attempt to get chat history - this should 404
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=chat_session,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# If we got no messages or an empty response, session might be deleted
|
||||
if not chat_history:
|
||||
session_deleted = True
|
||||
break
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
# If we get a 404 or other error, the session is gone
|
||||
if e.response.status_code in (404, 400):
|
||||
session_deleted = True
|
||||
break
|
||||
raise # Re-raise other errors
|
||||
|
||||
# Wait a bit before checking again
|
||||
time.sleep(5)
|
||||
print(f"Waited {time.time() - start_time:.1f} seconds for chat deletion...")
|
||||
|
||||
# Assert that the chat session was deleted
|
||||
assert session_deleted, "Chat session was not deleted within the expected time"
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
CODE_INTERPRETER_URL = f"{API_SERVER_URL}/admin/code-interpreter"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def preserve_code_interpreter_state(
|
||||
admin_user: DATestUser,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Capture the code interpreter enabled state before a test and restore it
|
||||
afterwards, so that tests that toggle the setting cannot leak state."""
|
||||
response = requests.get(
|
||||
CODE_INTERPRETER_URL,
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
initial_enabled = response.json()["enabled"]
|
||||
|
||||
yield
|
||||
|
||||
restore = requests.put(
|
||||
CODE_INTERPRETER_URL,
|
||||
json={"enabled": initial_enabled},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
restore.raise_for_status()
|
||||
@@ -37,7 +37,6 @@ def test_get_code_interpreter_status_as_admin(
|
||||
|
||||
def test_update_code_interpreter_disable_and_enable(
|
||||
admin_user: DATestUser,
|
||||
preserve_code_interpreter_state: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""PUT endpoint should update the enabled flag and persist across reads."""
|
||||
# Disable
|
||||
|
||||
195
backend/tests/integration/tests/dev_apis/test_knowledge_chat.py
Normal file
195
backend/tests/integration/tests/dev_apis/test_knowledge_chat.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history is enterprise only",
|
||||
)
|
||||
def test_all_stream_chat_message_objects_outputs(reset: None) -> None: # noqa: ARG001
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# create connector
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
# SEEDING DOCUMENTS
|
||||
cc_pair_1.documents = []
|
||||
cc_pair_1.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Pablo's favorite color is blue",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
cc_pair_1.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Chris's favorite color is red",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
cc_pair_1.documents.append(
|
||||
DocumentManager.seed_doc_with_content(
|
||||
cc_pair=cc_pair_1,
|
||||
content="Pika's favorite color is green",
|
||||
api_key=api_key,
|
||||
)
|
||||
)
|
||||
|
||||
# TESTING RESPONSE FOR QUESTION 1
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# check that the answer is correct
|
||||
answer_1 = response_json["answer"]
|
||||
assert "blue" in answer_1.lower()
|
||||
|
||||
# FLAKY - check that the llm selected a document
|
||||
# assert 0 in response_json["llm_selected_doc_indices"]
|
||||
|
||||
# check that the final context documents are correct
|
||||
# (it should contain all documents because there arent enough to exclude any)
|
||||
assert 0 in response_json["final_context_doc_indices"]
|
||||
assert 1 in response_json["final_context_doc_indices"]
|
||||
assert 2 in response_json["final_context_doc_indices"]
|
||||
|
||||
# FLAKY - check that the cited documents are correct
|
||||
# assert cc_pair_1.documents[0].id in response_json["cited_documents"].values()
|
||||
|
||||
# flakiness likely due to non-deterministic rephrasing
|
||||
# FLAKY - check that the top documents are correct
|
||||
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
|
||||
print("response 1/3 passed")
|
||||
|
||||
# TESTING RESPONSE FOR QUESTION 2
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
{
|
||||
"message": answer_1,
|
||||
"role": MessageType.ASSISTANT.value,
|
||||
},
|
||||
{
|
||||
"message": "What is Chris's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# check that the answer is correct
|
||||
answer_2 = response_json["answer"]
|
||||
assert "red" in answer_2.lower()
|
||||
|
||||
# FLAKY - check that the llm selected a document
|
||||
# assert 0 in response_json["llm_selected_doc_indices"]
|
||||
|
||||
# check that the final context documents are correct
|
||||
# (it should contain all documents because there arent enough to exclude any)
|
||||
assert 0 in response_json["final_context_doc_indices"]
|
||||
assert 1 in response_json["final_context_doc_indices"]
|
||||
assert 2 in response_json["final_context_doc_indices"]
|
||||
|
||||
# FLAKY - check that the cited documents are correct
|
||||
# assert cc_pair_1.documents[1].id in response_json["cited_documents"].values()
|
||||
|
||||
# flakiness likely due to non-deterministic rephrasing
|
||||
# FLAKY - check that the top documents are correct
|
||||
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[1].id
|
||||
print("response 2/3 passed")
|
||||
|
||||
# TESTING RESPONSE FOR QUESTION 3
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
{
|
||||
"message": answer_1,
|
||||
"role": MessageType.ASSISTANT.value,
|
||||
},
|
||||
{
|
||||
"message": "What is Chris's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
{
|
||||
"message": answer_2,
|
||||
"role": MessageType.ASSISTANT.value,
|
||||
},
|
||||
{
|
||||
"message": "What is Pika's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
},
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# check that the answer is correct
|
||||
answer_3 = response_json["answer"]
|
||||
assert "green" in answer_3.lower()
|
||||
|
||||
# FLAKY - check that the llm selected a document
|
||||
# assert 0 in response_json["llm_selected_doc_indices"]
|
||||
|
||||
# check that the final context documents are correct
|
||||
# (it should contain all documents because there arent enough to exclude any)
|
||||
assert 0 in response_json["final_context_doc_indices"]
|
||||
assert 1 in response_json["final_context_doc_indices"]
|
||||
assert 2 in response_json["final_context_doc_indices"]
|
||||
|
||||
# FLAKY - check that the cited documents are correct
|
||||
# assert cc_pair_1.documents[2].id in response_json["cited_documents"].values()
|
||||
|
||||
# flakiness likely due to non-deterministic rephrasing
|
||||
# FLAKY - check that the top documents are correct
|
||||
# assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
|
||||
print("response 3/3 passed")
|
||||
250
backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
Normal file
250
backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
Normal file
@@ -0,0 +1,250 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.conftest import DocumentBuilderType
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_send_message_simple_with_history(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
document_builder: DocumentBuilderType,
|
||||
) -> None:
|
||||
# create documents using the document builder
|
||||
# Create NUM_DOCS number of documents with dummy content
|
||||
content_list = [f"Document {i} content" for i in range(NUM_DOCS)]
|
||||
docs = document_builder(content_list)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": docs[0].content,
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Check that the top document is the correct document
|
||||
assert response_json["top_documents"][0]["document_id"] == docs[0].id
|
||||
|
||||
# assert that the metadata is correct
|
||||
for doc in docs:
|
||||
found_doc = next(
|
||||
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
|
||||
None,
|
||||
)
|
||||
assert found_doc
|
||||
assert found_doc["metadata"]["document_id"] == doc.id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_using_reference_docs_with_simple_with_history_api_flow(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
document_builder: DocumentBuilderType,
|
||||
) -> None:
|
||||
# SEEDING DOCUMENTS
|
||||
docs = document_builder(
|
||||
[
|
||||
"Chris's favorite color is blue",
|
||||
"Hagen's favorite color is red",
|
||||
"Pablo's favorite color is green",
|
||||
]
|
||||
)
|
||||
|
||||
# SEINDING MESSAGE 1
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# get the db_doc_id of the top document to use as a search doc id for second message
|
||||
first_db_doc_id = response_json["top_documents"][0]["db_doc_id"]
|
||||
|
||||
# SEINDING MESSAGE 2
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Pablo's favorite color?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
"search_doc_ids": [first_db_doc_id],
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
|
||||
# make sure there is an answer
|
||||
assert response_json["answer"]
|
||||
|
||||
# This ensures the the document we think we are referencing when we send the search_doc_ids in the second
|
||||
# message is the document that we expect it to be
|
||||
assert response_json["top_documents"][0]["document_id"] == docs[2].id
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="We don't support this anymore with the DR flow :(")
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_send_message_simple_with_history_strict_json(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
|
||||
json={
|
||||
# intentionally not relevant prompt to ensure that the
|
||||
# structured response format is actually used
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is green?",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
"structured_response_format": {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "presidents",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"presidents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of the first three US presidents",
|
||||
}
|
||||
},
|
||||
"required": ["presidents"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
"strict": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
|
||||
# Check that the answer is present
|
||||
assert "answer" in response_json
|
||||
assert response_json["answer"] is not None
|
||||
|
||||
# helper
|
||||
def clean_json_string(json_string: str) -> str:
|
||||
return json_string.strip().removeprefix("```json").removesuffix("```").strip()
|
||||
|
||||
# Attempt to parse the answer as JSON
|
||||
try:
|
||||
clean_answer = clean_json_string(response_json["answer"])
|
||||
parsed_answer = json.loads(clean_answer)
|
||||
|
||||
# NOTE: do not check content, just the structure
|
||||
assert isinstance(parsed_answer, dict)
|
||||
assert "presidents" in parsed_answer
|
||||
assert isinstance(parsed_answer["presidents"], list)
|
||||
for president in parsed_answer["presidents"]:
|
||||
assert isinstance(president, str)
|
||||
except json.JSONDecodeError:
|
||||
assert (
|
||||
False
|
||||
), f"The answer is not a valid JSON object - '{response_json['answer']}'"
|
||||
|
||||
# Check that the answer_citationless is also valid JSON
|
||||
assert "answer_citationless" in response_json
|
||||
assert response_json["answer_citationless"] is not None
|
||||
try:
|
||||
clean_answer_citationless = clean_json_string(
|
||||
response_json["answer_citationless"]
|
||||
)
|
||||
parsed_answer_citationless = json.loads(clean_answer_citationless)
|
||||
assert isinstance(parsed_answer_citationless, dict)
|
||||
except json.JSONDecodeError:
|
||||
assert False, "The answer_citationless is not a valid JSON object"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/query/answer-with-citation tests are enterprise only",
|
||||
)
|
||||
def test_answer_with_citation_api(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
document_builder: DocumentBuilderType,
|
||||
) -> None:
|
||||
|
||||
# create docs
|
||||
docs = document_builder(["Chris' favorite color is green"])
|
||||
|
||||
# send a message
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/query/answer-with-citation",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"message": "What is Chris' favorite color? Make sure to cite the document.",
|
||||
"role": MessageType.USER.value,
|
||||
}
|
||||
],
|
||||
"persona_id": 0,
|
||||
},
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
assert response_json["answer"]
|
||||
|
||||
has_correct_citation = False
|
||||
for citation in response_json["citations"]:
|
||||
if citation["document_id"] == docs[0].id:
|
||||
has_correct_citation = True
|
||||
break
|
||||
|
||||
assert has_correct_citation
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -11,7 +12,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_EMAILS
|
||||
from onyx.connectors.mock_connector.connector import EXTERNAL_USER_GROUP_IDS
|
||||
from onyx.connectors.mock_connector.connector import MockConnectorCheckpoint
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -25,16 +25,128 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.document import DocumentManager
|
||||
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
|
||||
from tests.integration.common_utils.test_document_utils import create_test_document
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
def _setup_mock_connector(
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync is enterprise only",
|
||||
)
|
||||
def test_mock_connector_initial_permission_sync(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> tuple[DATestCCPair, Document]:
|
||||
"""Common setup: create a test doc, configure mock server, create cc_pair, wait for indexing."""
|
||||
) -> None:
|
||||
"""Test that the MockConnector fetches and sets permissions during initial indexing when AccessType.SYNC is used"""
|
||||
|
||||
# Set up mock server behavior
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create CC Pair with SYNC access type to enable permissions during indexing
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-permissions-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
access_type=AccessType.SYNC, # This enables permissions during indexing
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for index attempt to start
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Wait for index attempt to finish
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Validate status
|
||||
finished_index_attempt = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.SUCCESS
|
||||
|
||||
# Verify document was indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == test_doc.id
|
||||
|
||||
# Verify no errors occurred
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 0
|
||||
|
||||
# Verify permissions were set during indexing by checking the document in the database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_docs = get_documents_by_ids(
|
||||
db_session=db_session,
|
||||
document_ids=[test_doc.id],
|
||||
)
|
||||
assert len(db_docs) == 1
|
||||
db_doc = db_docs[0]
|
||||
|
||||
assert db_doc.external_user_emails is not None
|
||||
assert db_doc.external_user_group_ids is not None
|
||||
|
||||
# Check the specific permissions that MockConnector sets
|
||||
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
|
||||
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
|
||||
|
||||
# Verify the document is not public (as set by MockConnector)
|
||||
assert db_doc.is_public is False
|
||||
|
||||
# Verify that the cc_pair was marked as permissions synced
|
||||
updated_cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair.id, user_performing_action=admin_user
|
||||
)
|
||||
assert updated_cc_pair_info is not None
|
||||
assert updated_cc_pair_info.last_full_permission_sync is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
)
|
||||
def test_permission_sync_attempt_tracking_integration(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are properly tracked during real sync workflows."""
|
||||
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
@@ -53,7 +165,7 @@ def _setup_mock_connector(
|
||||
assert response.status_code == 200
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-{uuid.uuid4()}",
|
||||
name=f"mock-connector-attempt-tracking-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
@@ -75,95 +187,6 @@ def _setup_mock_connector(
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
finished = IndexAttemptManager.get_index_attempt_by_id(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert finished.status == IndexingStatus.SUCCESS
|
||||
return cc_pair, test_doc
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync is enterprise only",
|
||||
)
|
||||
def test_mock_connector_initial_permission_sync(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture,
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that the MockConnector fetches and sets permissions during initial indexing
|
||||
when AccessType.SYNC is used."""
|
||||
|
||||
cc_pair, test_doc = _setup_mock_connector(mock_server_client, admin_user)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
assert len(documents) == 1
|
||||
assert documents[0].id == test_doc.id
|
||||
|
||||
errors = IndexAttemptManager.get_index_attempt_errors_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert len(errors) == 0
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_docs = get_documents_by_ids(
|
||||
db_session=db_session,
|
||||
document_ids=[test_doc.id],
|
||||
)
|
||||
assert len(db_docs) == 1
|
||||
db_doc = db_docs[0]
|
||||
|
||||
assert db_doc.external_user_emails is not None
|
||||
assert db_doc.external_user_group_ids is not None
|
||||
assert set(db_doc.external_user_emails) == EXTERNAL_USER_EMAILS
|
||||
assert set(db_doc.external_user_group_ids) == EXTERNAL_USER_GROUP_IDS
|
||||
assert db_doc.is_public is False
|
||||
|
||||
# After initial indexing, the beat task detects last_time_perm_sync is None
|
||||
# and triggers a doc permission sync. Explicitly trigger it to avoid
|
||||
# waiting for the 30s beat interval.
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
updated_cc_pair_info = CCPairManager.get_single(
|
||||
cc_pair.id, user_performing_action=admin_user
|
||||
)
|
||||
assert updated_cc_pair_info is not None
|
||||
assert updated_cc_pair_info.last_full_permission_sync is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
)
|
||||
def test_permission_sync_attempt_tracking_integration(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are properly tracked during real sync workflows."""
|
||||
|
||||
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
@@ -175,8 +198,6 @@ def test_permission_sync_attempt_tracking_integration(
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -198,6 +219,88 @@ def test_permission_sync_attempt_tracking_integration(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
)
|
||||
def test_permission_sync_attempt_tracking_with_mocked_failure(
|
||||
mock_server_client: httpx.Client,
|
||||
vespa_client: vespa_fixture, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are properly tracked when sync fails."""
|
||||
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-attempt-failure-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Mock the permission sync to force a failure and verify attempt tracking
|
||||
with patch(
|
||||
"ee.onyx.background.celery.tasks.doc_permission_syncing.tasks.validate_ccpair_for_user"
|
||||
) as mock_validate:
|
||||
mock_validate.side_effect = Exception("Validation failed for testing")
|
||||
|
||||
try:
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=0,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = db_session.execute(
|
||||
select(DocPermissionSyncAttempt).where(
|
||||
DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair.id
|
||||
)
|
||||
).scalar_one()
|
||||
|
||||
assert attempt.status == PermissionSyncStatus.FAILED
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission sync attempt tracking is enterprise only",
|
||||
@@ -208,8 +311,45 @@ def test_permission_sync_attempt_status_success(
|
||||
admin_user: DATestUser,
|
||||
) -> None:
|
||||
"""Test that permission sync attempts are marked as SUCCESS when sync completes without errors."""
|
||||
doc_uuid = uuid.uuid4()
|
||||
test_doc = create_test_document(doc_id=f"test-doc-{doc_uuid}")
|
||||
|
||||
cc_pair, _test_doc = _setup_mock_connector(mock_server_client, admin_user)
|
||||
response = mock_server_client.post(
|
||||
"/set-behavior",
|
||||
json=[
|
||||
{
|
||||
"documents": [test_doc.model_dump(mode="json")],
|
||||
"checkpoint": MockConnectorCheckpoint(has_more=False).model_dump(
|
||||
mode="json"
|
||||
),
|
||||
"failures": [],
|
||||
}
|
||||
],
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
cc_pair = CCPairManager.create_from_scratch(
|
||||
name=f"mock-connector-success-{uuid.uuid4()}",
|
||||
source=DocumentSource.MOCK_CONNECTOR,
|
||||
input_type=InputType.POLL,
|
||||
connector_specific_config={
|
||||
"mock_server_host": MOCK_CONNECTOR_SERVER_HOST,
|
||||
"mock_server_port": MOCK_CONNECTOR_SERVER_PORT,
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
index_attempt_id=index_attempt.id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
CCPairManager.sync(
|
||||
@@ -222,8 +362,6 @@ def test_permission_sync_attempt_status_success(
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
should_wait_for_group_sync=False,
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
@@ -6,14 +6,11 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
from onyx.db.models import LLMModelFlow
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import LLMProvider__Persona
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import ModelConfiguration
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
@@ -270,24 +267,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
provider_name=restricted_provider.name,
|
||||
)
|
||||
|
||||
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
|
||||
# resolve the default provider when the fallback path is triggered.
|
||||
default_model_config = ModelConfiguration(
|
||||
llm_provider_id=default_provider.id,
|
||||
name=default_provider.default_model_name,
|
||||
is_visible=True,
|
||||
)
|
||||
db_session.add(default_model_config)
|
||||
db_session.flush()
|
||||
db_session.add(
|
||||
LLMModelFlow(
|
||||
model_configuration_id=default_model_config.id,
|
||||
llm_model_flow_type=LLMModelFlowType.CHAT,
|
||||
is_default=True,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
access_group = UserGroup(name="persona-group")
|
||||
db_session.add(access_group)
|
||||
db_session.flush()
|
||||
|
||||
@@ -6,7 +6,7 @@ the permissions of the curator manipulating connector-credential pairs.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from onyx_openapi_client.exceptions import ApiException # type: ignore[import-untyped,unused-ignore,import-not-found]
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
@@ -93,9 +93,20 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should not be able to do"""
|
||||
|
||||
# Curators should not be able to create a public cc pair
|
||||
with pytest.raises(HTTPError):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_1.id,
|
||||
name="invalid_cc_pair_1",
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[user_group_1.id],
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should not be able to create a cc
|
||||
# pair for a user group they are not a curator of
|
||||
with pytest.raises(ApiException):
|
||||
with pytest.raises(HTTPError):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_1.id,
|
||||
@@ -107,7 +118,7 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
# Curators should not be able to create a cc
|
||||
# pair without an attached user group
|
||||
with pytest.raises(ApiException):
|
||||
with pytest.raises(HTTPError):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_1.id,
|
||||
@@ -133,7 +144,7 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
# Curators should not be able to create a cc
|
||||
# pair for a user group that the credential does not belong to
|
||||
with pytest.raises(ApiException):
|
||||
with pytest.raises(HTTPError):
|
||||
CCPairManager.create(
|
||||
connector_id=connector_1.id,
|
||||
credential_id=credential_2.id,
|
||||
@@ -145,16 +156,6 @@ def test_cc_pair_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should be able to do"""
|
||||
|
||||
# Re-create connector since the credential_2 validation error above
|
||||
# triggers connector deletion in the exception handler
|
||||
connector_1 = ConnectorManager.create(
|
||||
name="admin_owned_connector_2",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PRIVATE,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Curators should be able to create a private
|
||||
# cc pair for a user group they are a curator of
|
||||
valid_cc_pair = CCPairManager.create(
|
||||
|
||||
@@ -59,7 +59,17 @@ def test_connector_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should not be able to do"""
|
||||
|
||||
# Curators should not be able to create a connector for a
|
||||
# Curators should not be able to create a public connector
|
||||
with pytest.raises(HTTPError):
|
||||
ConnectorManager.create(
|
||||
name="invalid_connector_1",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should not be able to create a cc pair for a
|
||||
# user group they are not a curator of
|
||||
with pytest.raises(HTTPError):
|
||||
ConnectorManager.create(
|
||||
@@ -123,12 +133,12 @@ def test_connector_permissions(reset: None) -> None: # noqa: ARG001
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should be able to create a public connector
|
||||
public_connector = ConnectorManager.create(
|
||||
name="curator_public_connector",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
assert public_connector.id is not None
|
||||
# Test that curator cannot create a public connector
|
||||
with pytest.raises(HTTPError):
|
||||
ConnectorManager.create(
|
||||
name="invalid_connector_4",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -58,6 +58,16 @@ def test_credential_permissions(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators should not be able to do"""
|
||||
|
||||
# Curators should not be able to create a public credential
|
||||
with pytest.raises(HTTPError):
|
||||
CredentialManager.create(
|
||||
name="invalid_credential_1",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
curator_public=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should not be able to create a credential for a user group they are not a curator of
|
||||
with pytest.raises(HTTPError):
|
||||
CredentialManager.create(
|
||||
@@ -103,16 +113,3 @@ def test_credential_permissions(reset: None) -> None: # noqa: ARG001
|
||||
verify_deleted=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
# Curators should be able to create a public credential
|
||||
public_credential = CredentialManager.create(
|
||||
name="curator_public_credential",
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
groups=[user_group_1.id],
|
||||
curator_public=True,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
CredentialManager.verify(
|
||||
credential=public_credential,
|
||||
user_performing_action=curator,
|
||||
)
|
||||
|
||||
@@ -70,11 +70,10 @@ def test_doc_set_permissions_setup(reset: None) -> None: # noqa: ARG001
|
||||
|
||||
"""Tests for things Curators/Admins should not be able to do"""
|
||||
|
||||
# Test that curator cannot create a non-public document set for the group they don't curate
|
||||
# Test that curator cannot create a document set for the group they don't curate
|
||||
with pytest.raises(HTTPError):
|
||||
DocumentSetManager.create(
|
||||
name="Invalid Document Set 1",
|
||||
is_public=False,
|
||||
groups=[user_group_2.id],
|
||||
cc_pair_ids=[public_cc_pair.id],
|
||||
user_performing_action=curator,
|
||||
|
||||
@@ -6,14 +6,12 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
from uuid import UUID
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from ee.onyx.db.usage_export import UsageReportMetadata
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.seeding.chat_history_seeding import seed_chat_history
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -28,13 +26,7 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# Seed some chat history data for the report
|
||||
seed_chat_history(
|
||||
num_sessions=10,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
seed_chat_history(num_sessions=10, num_messages=4, days=30)
|
||||
|
||||
# Get initial list of reports
|
||||
initial_response = requests.get(
|
||||
@@ -84,13 +76,7 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# Seed some chat history data
|
||||
seed_chat_history(
|
||||
num_sessions=20,
|
||||
num_messages=4,
|
||||
days=60,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
seed_chat_history(num_sessions=20, num_messages=4, days=60)
|
||||
|
||||
# Get initial list of reports
|
||||
initial_response = requests.get(
|
||||
@@ -162,13 +148,7 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# First generate a report to ensure we have at least one
|
||||
seed_chat_history(
|
||||
num_sessions=5,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
seed_chat_history(num_sessions=5, num_messages=4, days=30)
|
||||
|
||||
# Get initial count
|
||||
initial_response = requests.get(
|
||||
@@ -224,13 +204,7 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# First generate a report
|
||||
seed_chat_history(
|
||||
num_sessions=5,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
seed_chat_history(num_sessions=5, num_messages=4, days=30)
|
||||
|
||||
# Get initial reports count
|
||||
initial_response = requests.get(
|
||||
@@ -378,13 +352,7 @@ class TestUsageExportAPI:
|
||||
self, reset: None, admin_user: DATestUser # noqa: ARG002
|
||||
) -> None:
|
||||
# Seed some data
|
||||
seed_chat_history(
|
||||
num_sessions=10,
|
||||
num_messages=4,
|
||||
days=30,
|
||||
user_id=UUID(admin_user.id),
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
)
|
||||
seed_chat_history(num_sessions=10, num_messages=4, days=30)
|
||||
|
||||
# Get initial count of reports
|
||||
initial_response = requests.get(
|
||||
|
||||
@@ -25,11 +25,6 @@ def test_add_users_to_group(reset: None) -> None: # noqa: ARG001
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_performing_action=admin_user,
|
||||
user_groups_to_check=[user_group],
|
||||
)
|
||||
|
||||
updated_user_group = UserGroupManager.add_users(
|
||||
user_group=user_group,
|
||||
user_ids=[user_to_add.id],
|
||||
|
||||
@@ -3,8 +3,6 @@ from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
_enumerate_ad_groups_paginated,
|
||||
)
|
||||
@@ -17,9 +15,6 @@ from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
AD_GROUP_ENUMERATION_THRESHOLD,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_external_access_from_sharepoint,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
@@ -271,65 +266,3 @@ def test_enumerate_all_without_token_skips(
|
||||
|
||||
assert results == []
|
||||
mock_enum.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_external_access_from_sharepoint – site page URL handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"site_base_url, web_url, expected_relative_url",
|
||||
[
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/Evan%27sSite",
|
||||
"https://tenant.sharepoint.com/sites/Evan%27sSite/SitePages/Home.aspx",
|
||||
"/sites/Evan%27sSite/SitePages/Home.aspx",
|
||||
),
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/NormalSite",
|
||||
"https://tenant.sharepoint.com/sites/NormalSite/SitePages/Page.aspx",
|
||||
"/sites/NormalSite/SitePages/Page.aspx",
|
||||
),
|
||||
(
|
||||
"https://tenant.sharepoint.com/sites/Site%20With%20Spaces",
|
||||
"https://tenant.sharepoint.com/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
|
||||
"/sites/Site%20With%20Spaces/SitePages/Doc.aspx",
|
||||
),
|
||||
],
|
||||
ids=["apostrophe-encoded", "no-special-chars", "space-encoded"],
|
||||
)
|
||||
@patch(f"{MODULE}._get_groups_and_members_recursively")
|
||||
@patch(f"{MODULE}.sleep_and_retry")
|
||||
def test_site_page_url_not_duplicated(
|
||||
mock_sleep: MagicMock, # noqa: ARG001
|
||||
mock_recursive: MagicMock,
|
||||
site_base_url: str,
|
||||
web_url: str,
|
||||
expected_relative_url: str,
|
||||
) -> None:
|
||||
"""Regression: the server-relative URL passed to
|
||||
get_file_by_server_relative_url must preserve percent-encoding so the
|
||||
Office365 library's SPResPath.create_relative() recognises the site prefix
|
||||
and doesn't duplicate it."""
|
||||
mock_recursive.return_value = GroupsResult(
|
||||
groups_to_emails={},
|
||||
found_public_group=False,
|
||||
)
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.base_url = site_base_url
|
||||
|
||||
site_page = {"webUrl": web_url}
|
||||
|
||||
get_external_access_from_sharepoint(
|
||||
client_context=ctx,
|
||||
graph_client=MagicMock(),
|
||||
drive_name=None,
|
||||
drive_item=None,
|
||||
site_page=site_page,
|
||||
)
|
||||
|
||||
ctx.web.get_file_by_server_relative_url.assert_called_once_with(
|
||||
expected_relative_url
|
||||
)
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
import json
|
||||
|
||||
import httplib2 # type: ignore[import-untyped]
|
||||
from googleapiclient.errors import HttpError # type: ignore[import-untyped]
|
||||
|
||||
from onyx.connectors.google_utils.google_utils import _is_rate_limit_error
|
||||
|
||||
|
||||
def _make_http_error(
|
||||
status: int,
|
||||
reason: str = "unknown",
|
||||
error_reason: str = "",
|
||||
) -> HttpError:
|
||||
resp = httplib2.Response({"status": status})
|
||||
if error_reason:
|
||||
body = json.dumps(
|
||||
{
|
||||
"error": {
|
||||
"message": reason,
|
||||
"errors": [{"reason": error_reason, "message": reason}],
|
||||
}
|
||||
}
|
||||
).encode()
|
||||
else:
|
||||
body = json.dumps({"error": {"message": reason}}).encode()
|
||||
return HttpError(resp, body)
|
||||
|
||||
|
||||
def test_429_is_rate_limit() -> None:
|
||||
assert _is_rate_limit_error(_make_http_error(429))
|
||||
|
||||
|
||||
def test_403_user_rate_limit_exceeded() -> None:
|
||||
err = _make_http_error(
|
||||
403,
|
||||
reason="User rate limit exceeded.",
|
||||
error_reason="userRateLimitExceeded",
|
||||
)
|
||||
assert _is_rate_limit_error(err)
|
||||
|
||||
|
||||
def test_403_rate_limit_exceeded() -> None:
|
||||
err = _make_http_error(
|
||||
403,
|
||||
reason="Rate limit exceeded.",
|
||||
error_reason="rateLimitExceeded",
|
||||
)
|
||||
assert _is_rate_limit_error(err)
|
||||
|
||||
|
||||
def test_403_permission_denied_is_not_rate_limit() -> None:
|
||||
err = _make_http_error(
|
||||
403,
|
||||
reason="The caller does not have permission",
|
||||
error_reason="forbidden",
|
||||
)
|
||||
assert not _is_rate_limit_error(err)
|
||||
|
||||
|
||||
def test_404_is_not_rate_limit() -> None:
|
||||
assert not _is_rate_limit_error(_make_http_error(404))
|
||||
|
||||
|
||||
def test_500_is_not_rate_limit() -> None:
|
||||
assert not _is_rate_limit_error(_make_http_error(500))
|
||||
@@ -1,34 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.slab.connector import SlabConnector
|
||||
|
||||
|
||||
def _build_connector(base_url: str = "https://myteam.slab.com") -> SlabConnector:
|
||||
connector = SlabConnector(base_url=base_url)
|
||||
connector.load_credentials({"slab_bot_token": "fake-token"})
|
||||
return connector
|
||||
|
||||
|
||||
def test_validate_rejects_missing_scheme() -> None:
|
||||
connector = _build_connector(base_url="myteam.slab.com")
|
||||
with pytest.raises(ConnectorValidationError, match="https://"):
|
||||
connector.validate_connector_settings()
|
||||
|
||||
|
||||
@patch("onyx.connectors.slab.connector.get_all_post_ids", return_value=["id1"])
|
||||
def test_validate_success(mock_get_posts: object) -> None: # noqa: ARG001
|
||||
connector = _build_connector()
|
||||
connector.validate_connector_settings()
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.connectors.slab.connector.get_all_post_ids",
|
||||
side_effect=Exception("401 Unauthorized"),
|
||||
)
|
||||
def test_validate_bad_token_raises(mock_get_posts: object) -> None: # noqa: ARG001
|
||||
connector = _build_connector()
|
||||
with pytest.raises(ConnectorValidationError, match="Failed to fetch posts"):
|
||||
connector.validate_connector_settings()
|
||||
@@ -98,11 +98,6 @@ class TestScimDALUserMappings:
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
"department": None,
|
||||
"manager": None,
|
||||
"given_name": None,
|
||||
"family_name": None,
|
||||
"scim_emails_json": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
"""Tests that persona IDs are correctly propagated through the indexing pipeline.
|
||||
|
||||
Covers Phase 1 (schema plumbing) and Phase 2 (write at index time) of the
|
||||
unify-assistant-project-files plan.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_index_chunk(
|
||||
doc_id: str = "test-file-id",
|
||||
content: str = "test content",
|
||||
) -> IndexChunk:
|
||||
embedding = [0.1] * 10
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_file.txt",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.USER_FILE,
|
||||
metadata={},
|
||||
)
|
||||
return IndexChunk(
|
||||
chunk_id=0,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=embedding,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_access() -> DocumentAccess:
|
||||
return DocumentAccess.build(
|
||||
user_emails=["user@example.com"],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def test_from_index_chunk_propagates_personas() -> None:
|
||||
"""Personas list passed to from_index_chunk appears on the result."""
|
||||
chunk = _make_index_chunk()
|
||||
persona_ids = [10, 20, 30]
|
||||
|
||||
aware_chunk = DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=_make_access(),
|
||||
document_sets=set(),
|
||||
user_project=[1],
|
||||
personas=persona_ids,
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
assert aware_chunk.personas == persona_ids
|
||||
assert aware_chunk.user_project == [1]
|
||||
|
||||
|
||||
def test_from_index_chunk_empty_personas() -> None:
|
||||
"""An empty personas list is preserved (not turned into None or omitted)."""
|
||||
chunk = _make_index_chunk()
|
||||
|
||||
aware_chunk = DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=_make_access(),
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
|
||||
def _make_document(doc_id: str) -> Document:
|
||||
return Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_file.txt",
|
||||
sections=[TextSection(text="test content", link=None)],
|
||||
source=DocumentSource.USER_FILE,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
def _run_adapter_build(
|
||||
file_id: str,
|
||||
project_ids_map: dict[str, list[int]],
|
||||
persona_ids_map: dict[str, list[int]],
|
||||
) -> list[DocMetadataAwareIndexChunk]:
|
||||
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
|
||||
with all external dependencies mocked."""
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import (
|
||||
UserFileIndexingAdapter,
|
||||
)
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
|
||||
chunk = _make_index_chunk(doc_id=file_id)
|
||||
doc = _make_document(doc_id=file_id)
|
||||
|
||||
context = DocumentBatchPrepareContext(
|
||||
updatable_docs=[doc],
|
||||
id_to_boost_map={},
|
||||
)
|
||||
|
||||
adapter = UserFileIndexingAdapter(tenant_id="test_tenant", db_session=MagicMock())
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.fetch_user_project_ids_for_user_files",
|
||||
return_value=project_ids_map,
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.fetch_persona_ids_for_user_files",
|
||||
return_value=persona_ids_map,
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_access_for_user_files",
|
||||
return_value={file_id: _make_access()},
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.fetch_chunk_counts_for_user_files",
|
||||
return_value=[(file_id, 0)],
|
||||
),
|
||||
patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
side_effect=Exception("no LLM in tests"),
|
||||
),
|
||||
):
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id="test_tenant",
|
||||
context=context,
|
||||
)
|
||||
|
||||
return result.chunks
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
|
||||
fetched from the DB into each chunk's metadata."""
|
||||
file_id = str(uuid4())
|
||||
persona_ids = [5, 12]
|
||||
project_ids = [3]
|
||||
|
||||
chunks = _run_adapter_build(
|
||||
file_id=file_id,
|
||||
project_ids_map={file_id: project_ids},
|
||||
persona_ids_map={file_id: persona_ids},
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].personas == persona_ids
|
||||
assert chunks[0].user_project == project_ids
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
|
||||
"""When a file has no persona or project associations in the DB, the
|
||||
adapter should default to empty lists (not KeyError or None)."""
|
||||
file_id = str(uuid4())
|
||||
|
||||
chunks = _run_adapter_build(
|
||||
file_id=file_id,
|
||||
project_ids_map={},
|
||||
persona_ids_map={},
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].personas == []
|
||||
assert chunks[0].user_project == []
|
||||
@@ -1,7 +1,4 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import _normalize_citation_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
@@ -12,7 +9,7 @@ def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
@@ -23,7 +20,7 @@ def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
@@ -34,7 +31,7 @@ def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
normalized = _normalize_citation_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
@@ -53,54 +50,3 @@ def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> Non
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""Test bulk invite limit for free trial tenants."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.manage.models import EmailInviteStatus
|
||||
from onyx.server.manage.users import bulk_invite_users
|
||||
|
||||
|
||||
@@ -35,7 +33,6 @@ def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: None) -> None:
|
||||
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.get_all_users", return_value=[])
|
||||
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
|
||||
@patch("onyx.server.manage.users.enforce_seat_limit")
|
||||
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
|
||||
@patch(
|
||||
"onyx.server.manage.users.fetch_ee_implementation_or_noop",
|
||||
@@ -47,69 +44,4 @@ def test_trial_tenant_can_invite_within_limit(*_mocks: None) -> None:
|
||||
|
||||
result = bulk_invite_users(emails=emails)
|
||||
|
||||
assert result.invited_count == 3
|
||||
assert result.email_invite_status == EmailInviteStatus.DISABLED
|
||||
|
||||
|
||||
# --- email_invite_status tests ---
|
||||
|
||||
_COMMON_PATCHES = [
|
||||
patch("onyx.server.manage.users.MULTI_TENANT", False),
|
||||
patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant"),
|
||||
patch("onyx.server.manage.users.get_invited_users", return_value=[]),
|
||||
patch("onyx.server.manage.users.get_all_users", return_value=[]),
|
||||
patch("onyx.server.manage.users.write_invited_users", return_value=1),
|
||||
patch("onyx.server.manage.users.enforce_seat_limit"),
|
||||
]
|
||||
|
||||
|
||||
def _with_common_patches(fn: object) -> object:
|
||||
for p in reversed(_COMMON_PATCHES):
|
||||
fn = p(fn) # type: ignore
|
||||
return fn
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
|
||||
def test_email_invite_status_disabled(*_mocks: None) -> None:
|
||||
"""When email invites are disabled, status is disabled."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.DISABLED
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", False)
|
||||
def test_email_invite_status_not_configured(*_mocks: None) -> None:
|
||||
"""When email invites are enabled but no server is configured, status is not_configured."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.NOT_CONFIGURED
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", True)
|
||||
@patch("onyx.server.manage.users.send_user_email_invite")
|
||||
def test_email_invite_status_sent(mock_send: MagicMock, *_mocks: None) -> None:
|
||||
"""When email invites are enabled and configured, status is sent."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
mock_send.assert_called_once()
|
||||
assert result.email_invite_status == EmailInviteStatus.SENT
|
||||
|
||||
|
||||
@_with_common_patches
|
||||
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", True)
|
||||
@patch("onyx.server.manage.users.EMAIL_CONFIGURED", True)
|
||||
@patch(
|
||||
"onyx.server.manage.users.send_user_email_invite",
|
||||
side_effect=Exception("SMTP auth failed"),
|
||||
)
|
||||
def test_email_invite_status_send_failed(*_mocks: None) -> None:
|
||||
"""When email sending throws, status is send_failed and invite is still saved."""
|
||||
result = bulk_invite_users(emails=["user@example.com"])
|
||||
|
||||
assert result.email_invite_status == EmailInviteStatus.SEND_FAILED
|
||||
assert result.invited_count == 1
|
||||
assert result == 3
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
@@ -13,9 +12,7 @@ import pytest
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
@@ -118,11 +115,6 @@ def make_user_mapping(**kwargs: Any) -> MagicMock:
|
||||
mapping.external_id = kwargs.get("external_id", "ext-default")
|
||||
mapping.user_id = kwargs.get("user_id", uuid4())
|
||||
mapping.scim_username = kwargs.get("scim_username", None)
|
||||
mapping.department = kwargs.get("department", None)
|
||||
mapping.manager = kwargs.get("manager", None)
|
||||
mapping.given_name = kwargs.get("given_name", None)
|
||||
mapping.family_name = kwargs.get("family_name", None)
|
||||
mapping.scim_emails_json = kwargs.get("scim_emails_json", None)
|
||||
return mapping
|
||||
|
||||
|
||||
@@ -130,35 +122,3 @@ def assert_scim_error(result: object, expected_status: int) -> None:
|
||||
"""Assert *result* is a JSONResponse with the given status code."""
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == expected_status
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response parsing helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def parse_scim_user(result: object, *, status: int = 200) -> ScimUserResource:
|
||||
"""Assert *result* is a ScimJSONResponse and parse as ScimUserResource."""
|
||||
assert isinstance(
|
||||
result, ScimJSONResponse
|
||||
), f"Expected ScimJSONResponse, got {type(result).__name__}"
|
||||
assert result.status_code == status
|
||||
return ScimUserResource.model_validate(json.loads(result.body))
|
||||
|
||||
|
||||
def parse_scim_group(result: object, *, status: int = 200) -> ScimGroupResource:
|
||||
"""Assert *result* is a ScimJSONResponse and parse as ScimGroupResource."""
|
||||
assert isinstance(
|
||||
result, ScimJSONResponse
|
||||
), f"Expected ScimJSONResponse, got {type(result).__name__}"
|
||||
assert result.status_code == status
|
||||
return ScimGroupResource.model_validate(json.loads(result.body))
|
||||
|
||||
|
||||
def parse_scim_list(result: object) -> ScimListResponse:
|
||||
"""Assert *result* is a ScimJSONResponse and parse as ScimListResponse."""
|
||||
assert isinstance(
|
||||
result, ScimJSONResponse
|
||||
), f"Expected ScimJSONResponse, got {type(result).__name__}"
|
||||
assert result.status_code == 200
|
||||
return ScimListResponse.model_validate(json.loads(result.body))
|
||||
|
||||
@@ -1,983 +0,0 @@
|
||||
"""Comprehensive Entra ID (Azure AD) SCIM compatibility tests.
|
||||
|
||||
Covers the full Entra provisioning lifecycle: service discovery, user CRUD
|
||||
with enterprise extension schema, group CRUD with excludedAttributes, and
|
||||
all Entra-specific behavioral quirks (PascalCase ops, enterprise URN in
|
||||
PATCH value dicts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_group
|
||||
from ee.onyx.server.scim.api import get_resource_types
|
||||
from ee.onyx.server.scim.api import get_schemas
|
||||
from ee.onyx.server.scim.api import get_service_provider_config
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_groups
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.api import ScimJSONResponse
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEnterpriseExtension
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimManagerRef
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entra_provider() -> ScimProvider:
|
||||
"""An EntraProvider instance for Entra-specific endpoint tests."""
|
||||
return EntraProvider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraServiceDiscovery:
|
||||
"""Entra expects enterprise extension in discovery endpoints."""
|
||||
|
||||
def test_service_provider_config_advertises_patch(self) -> None:
|
||||
config = get_service_provider_config()
|
||||
assert config.patch.supported is True
|
||||
|
||||
def test_resource_types_include_enterprise_extension(self) -> None:
|
||||
result = get_resource_types()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "Resources" in parsed
|
||||
user_type = next(rt for rt in parsed["Resources"] if rt["id"] == "User")
|
||||
extension_schemas = [ext["schema"] for ext in user_type["schemaExtensions"]]
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in extension_schemas
|
||||
|
||||
def test_schemas_include_enterprise_user(self) -> None:
|
||||
result = get_schemas()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
schema_ids = [s["id"] for s in parsed["Resources"]]
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in schema_ids
|
||||
|
||||
def test_enterprise_schema_has_expected_attributes(self) -> None:
|
||||
result = get_schemas()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
enterprise = next(
|
||||
s for s in parsed["Resources"] if s["id"] == SCIM_ENTERPRISE_USER_SCHEMA
|
||||
)
|
||||
attr_names = {a["name"] for a in enterprise["attributes"]}
|
||||
assert "department" in attr_names
|
||||
assert "manager" in attr_names
|
||||
|
||||
def test_service_discovery_content_type(self) -> None:
|
||||
"""SCIM responses must use application/scim+json content type."""
|
||||
result = get_resource_types()
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
assert result.media_type == "application/scim+json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# User Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraUserLifecycle:
|
||||
"""Test user CRUD through Entra's lens: enterprise schemas, PascalCase ops."""
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="alice@contoso.com")
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
assert SCIM_USER_SCHEMA in resource.schemas
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_create_user_with_enterprise_extension(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Enterprise extension department/manager should round-trip on create."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
enterprise_extension=ScimEnterpriseExtension(
|
||||
department="Engineering",
|
||||
manager=ScimManagerRef(value="mgr-uuid-123"),
|
||||
),
|
||||
)
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.enterprise_extension is not None
|
||||
assert resource.enterprise_extension.department == "Engineering"
|
||||
assert resource.enterprise_extension.manager is not None
|
||||
assert resource.enterprise_extension.manager.value == "mgr-uuid-123"
|
||||
|
||||
# Verify DAL received the enterprise fields
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
call_kwargs = mock_dal.create_user_mapping.call_args[1]
|
||||
assert call_kwargs["fields"] == ScimMappingFields(
|
||||
department="Engineering",
|
||||
manager="mgr-uuid-123",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
)
|
||||
|
||||
def test_get_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_get_user_returns_enterprise_extension_data(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""GET should return stored enterprise extension data."""
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mapping.department = "Sales"
|
||||
mapping.manager = "mgr-456"
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.enterprise_extension is not None
|
||||
assert resource.enterprise_extension.department == "Sales"
|
||||
assert resource.enterprise_extension.manager is not None
|
||||
assert resource.enterprise_extension.manager.value == "mgr-456"
|
||||
|
||||
def test_list_users_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mapping = make_user_mapping(external_id="entra-ext-1", user_id=user.id)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_patch_user_deactivate_with_pascal_case_replace(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Replace"`` (PascalCase) instead of ``"replace"``."""
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Replace", # type: ignore[arg-type]
|
||||
path="active",
|
||||
value=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Mock doesn't propagate the change, so verify via the DAL call
|
||||
mock_dal.update_user.assert_called_once()
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["is_active"] is False
|
||||
|
||||
def test_patch_user_add_external_id_with_pascal_case(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Add"`` (PascalCase) instead of ``"add"``."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Add", # type: ignore[arg-type]
|
||||
path="externalId",
|
||||
value="entra-ext-999",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Verify the patched externalId was synced to the DAL
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
call_args = mock_dal.sync_user_external_id.call_args
|
||||
assert call_args[0][1] == "entra-ext-999"
|
||||
|
||||
def test_patch_user_enterprise_extension_in_value_dict(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends enterprise extension URN as key in path-less PATCH value
|
||||
dicts — enterprise data should be stored, not ignored."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path=None,
|
||||
value=value,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# Verify active=False was applied
|
||||
mock_dal.update_user.assert_called_once()
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["is_active"] is False
|
||||
# Verify enterprise data was passed to DAL
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="Engineering",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
|
||||
)
|
||||
|
||||
def test_patch_user_remove_external_id(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH remove op should clear the target field."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mapping.external_id = "ext-to-remove"
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = mapping
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REMOVE,
|
||||
path="externalId",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
# externalId should be cleared (None)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
call_args = mock_dal.sync_user_external_id.call_args
|
||||
assert call_args[0][1] is None
|
||||
|
||||
def test_patch_user_emails_primary_eq_true_value(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH with path emails[primary eq true].value should update
|
||||
the primary email entry, not userName."""
|
||||
user = make_db_user(email="old@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="emails[primary eq true].value",
|
||||
value="new@contoso.com",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
assert resource.userName == "old@contoso.com"
|
||||
# Primary email should be updated
|
||||
primary_emails = [e for e in resource.emails if e.primary]
|
||||
assert len(primary_emails) == 1
|
||||
assert primary_emails[0].value == "new@contoso.com"
|
||||
|
||||
def test_patch_user_enterprise_urn_department_path(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PATCH with dotted enterprise URN path should store department."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.REPLACE,
|
||||
path="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
value="Marketing",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_user(
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="Marketing",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
scim_emails_json='[{"value": "test@example.com", "type": "work", "primary": true}]',
|
||||
)
|
||||
|
||||
def test_replace_user_includes_enterprise_schema(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="old@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
resource = make_scim_user(
|
||||
userName="new@contoso.com",
|
||||
name=ScimName(givenName="New", familyName="Name"),
|
||||
)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas
|
||||
|
||||
def test_replace_user_with_enterprise_extension(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""PUT with enterprise extension should store the fields."""
|
||||
user = make_db_user(email="alice@contoso.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
enterprise_extension=ScimEnterpriseExtension(
|
||||
department="HR",
|
||||
manager=ScimManagerRef(value="boss-id"),
|
||||
),
|
||||
)
|
||||
|
||||
result = replace_user(
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
mock_dal.sync_user_external_id.assert_called_once()
|
||||
sync_kwargs = mock_dal.sync_user_external_id.call_args[1]
|
||||
assert sync_kwargs["fields"] == ScimMappingFields(
|
||||
department="HR",
|
||||
manager="boss-id",
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
)
|
||||
|
||||
def test_delete_user_returns_204(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = MagicMock(id=1)
|
||||
|
||||
result = delete_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, Response)
|
||||
assert result.status_code == 204
|
||||
|
||||
def test_double_delete_returns_404(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
) -> None:
|
||||
"""Second DELETE should return 404 — the SCIM mapping is gone."""
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
# No mapping — user was already deleted from SCIM's perspective
|
||||
mock_dal.get_user_mapping_by_user_id.return_value = None
|
||||
|
||||
result = delete_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
assert result.status_code == 404
|
||||
|
||||
def test_name_formatted_preserved_on_create(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""When name.formatted is provided, it should be used as personal_name."""
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(
|
||||
userName="alice@contoso.com",
|
||||
name=ScimName(
|
||||
givenName="Alice",
|
||||
familyName="Smith",
|
||||
formatted="Dr. Alice Smith",
|
||||
),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"ee.onyx.server.scim.api._check_seat_availability", return_value=None
|
||||
):
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result, status=201)
|
||||
# The User constructor should have received the formatted name
|
||||
mock_dal.add_user.assert_called_once()
|
||||
created_user = mock_dal.add_user.call_args[0][0]
|
||||
assert created_user.personal_name == "Dr. Alice Smith"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Group Lifecycle (Entra-specific)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraGroupLifecycle:
|
||||
"""Test group CRUD with Entra-specific behaviors."""
|
||||
|
||||
def test_get_group_standard_response(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=10, name="Contoso Engineering")
|
||||
mock_dal.get_group.return_value = group
|
||||
uid = uuid4()
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="10",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert resource.displayName == "Contoso Engineering"
|
||||
assert len(resource.members) == 1
|
||||
|
||||
def test_list_groups_with_excluded_attributes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ?excludedAttributes=members on group list queries."""
|
||||
group = make_db_group(id=10, name="Engineering")
|
||||
uid = uuid4()
|
||||
mock_dal.list_groups.return_value = ([(group, "ext-g-1")], 1)
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
excludedAttributes="members",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert parsed["totalResults"] == 1
|
||||
resource = parsed["Resources"][0]
|
||||
assert "members" not in resource
|
||||
assert resource["displayName"] == "Engineering"
|
||||
|
||||
def test_get_group_with_excluded_attributes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ?excludedAttributes=members on single group GET."""
|
||||
group = make_db_group(id=10, name="Engineering")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "alice@contoso.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="10",
|
||||
excludedAttributes="members",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert parsed["displayName"] == "Engineering"
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_patch_group_add_members_with_pascal_case(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Add"`` (PascalCase) for group member additions."""
|
||||
group = make_db_group(id=10)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
mock_dal.validate_member_ids.return_value = []
|
||||
|
||||
uid = str(uuid4())
|
||||
patched = ScimGroupResource(
|
||||
id="10",
|
||||
displayName="Engineering",
|
||||
members=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
mock_apply.return_value = (patched, [uid], [])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Add", # type: ignore[arg-type]
|
||||
path="members",
|
||||
value=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="10",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
mock_dal.upsert_group_members.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_group_patch")
|
||||
def test_patch_group_remove_member_with_pascal_case(
|
||||
self,
|
||||
mock_apply: MagicMock,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra sends ``"Remove"`` (PascalCase) for group member removals."""
|
||||
group = make_db_group(id=10)
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
uid = str(uuid4())
|
||||
patched = ScimGroupResource(id="10", displayName="Engineering", members=[])
|
||||
mock_apply.return_value = (patched, [], [uid])
|
||||
|
||||
patch_req = ScimPatchRequest(
|
||||
Operations=[
|
||||
ScimPatchOperation(
|
||||
op="Remove", # type: ignore[arg-type]
|
||||
path=f'members[value eq "{uid}"]',
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = patch_group(
|
||||
group_id="10",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# excludedAttributes (RFC 7644 §3.4.2.5)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExcludedAttributes:
|
||||
"""Test excludedAttributes query parameter on GET endpoints."""
|
||||
|
||||
def test_list_groups_excludes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.list_groups.return_value = ([(group, None)], 1)
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = list_groups(
|
||||
filter=None,
|
||||
excludedAttributes="members",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
resource = parsed["Resources"][0]
|
||||
assert "members" not in resource
|
||||
assert "displayName" in resource
|
||||
|
||||
def test_get_group_excludes_members(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
excludedAttributes="members",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert "displayName" in parsed
|
||||
|
||||
def test_list_users_excludes_groups(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mapping = make_user_mapping(user_id=user.id)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
mock_dal.get_users_groups_batch.return_value = {user.id: [(1, "Engineering")]}
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
excludedAttributes="groups",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
resource = parsed["Resources"][0]
|
||||
assert "groups" not in resource
|
||||
assert "userName" in resource
|
||||
|
||||
def test_get_user_excludes_groups(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
mock_dal.get_user_groups.return_value = [(1, "Engineering")]
|
||||
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
excludedAttributes="groups",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "groups" not in parsed
|
||||
assert "userName" in parsed
|
||||
|
||||
def test_multiple_excluded_attributes(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = []
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
excludedAttributes="members,externalId",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimJSONResponse)
|
||||
parsed = json.loads(result.body)
|
||||
assert "members" not in parsed
|
||||
assert "externalId" not in parsed
|
||||
assert "displayName" in parsed
|
||||
|
||||
def test_no_excluded_attributes_returns_full_response(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=1, name="Team")
|
||||
uid = uuid4()
|
||||
mock_dal.get_group.return_value = group
|
||||
mock_dal.get_group_members.return_value = [(uid, "user@example.com")]
|
||||
|
||||
result = get_group(
|
||||
group_id="1",
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert len(resource.members) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entra Connection Probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEntraConnectionProbe:
|
||||
"""Entra sends a probe request during initial SCIM setup."""
|
||||
|
||||
def test_filter_for_nonexistent_user_returns_empty_list(
|
||||
self,
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
entra_provider: ScimProvider,
|
||||
) -> None:
|
||||
"""Entra probes with: GET /Users?filter=userName eq "non-existent"&count=1"""
|
||||
mock_dal.list_users.return_value = ([], 0)
|
||||
|
||||
result = list_users(
|
||||
filter='userName eq "non-existent@contoso.com"',
|
||||
startIndex=1,
|
||||
count=1,
|
||||
_token=mock_token,
|
||||
provider=entra_provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
@@ -16,6 +16,7 @@ from ee.onyx.server.scim.api import patch_group
|
||||
from ee.onyx.server.scim.api import replace_group
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
@@ -24,8 +25,6 @@ from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_group
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
|
||||
|
||||
class TestListGroups:
|
||||
@@ -49,9 +48,9 @@ class TestListGroups:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
|
||||
def test_unsupported_filter_returns_400(
|
||||
self,
|
||||
@@ -96,9 +95,9 @@ class TestListGroups:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 1
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimGroupResource)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.externalId == "ext-g-1"
|
||||
@@ -127,9 +126,9 @@ class TestGetGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result)
|
||||
assert resource.displayName == "Engineering"
|
||||
assert resource.id == "5"
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "Engineering"
|
||||
assert result.id == "5"
|
||||
|
||||
def test_non_integer_id_returns_404(
|
||||
self,
|
||||
@@ -191,8 +190,8 @@ class TestCreateGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_group(result, status=201)
|
||||
assert resource.displayName == "New Group"
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
assert result.displayName == "New Group"
|
||||
mock_dal.add_group.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -284,7 +283,7 @@ class TestCreateGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result, status=201)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.create_group_mapping.assert_called_once()
|
||||
|
||||
|
||||
@@ -315,7 +314,7 @@ class TestReplaceGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
mock_dal.replace_group_members.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
@@ -428,7 +427,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.update_group.assert_called_once_with(group, name="New Name")
|
||||
|
||||
def test_not_found_returns_404(
|
||||
@@ -535,7 +534,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.validate_member_ids.assert_called_once()
|
||||
mock_dal.upsert_group_members.assert_called_once()
|
||||
|
||||
@@ -615,7 +614,7 @@ class TestPatchGroup:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_group(result)
|
||||
assert isinstance(result, ScimGroupResource)
|
||||
mock_dal.remove_group_members.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
@@ -13,11 +12,9 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
|
||||
_ENTRA_IGNORED = EntraProvider().ignored_patch_paths
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
|
||||
@@ -59,36 +56,36 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_deactivate_user(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("active", False)], user)
|
||||
result = apply_user_patch([_replace_op("active", False)], user)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_activate_user(self) -> None:
|
||||
user = _make_user(active=False)
|
||||
result, _ = apply_user_patch([_replace_op("active", True)], user)
|
||||
result = apply_user_patch([_replace_op("active", True)], user)
|
||||
assert result.active is True
|
||||
|
||||
def test_replace_given_name(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
result = apply_user_patch([_replace_op("name.givenName", "NewFirst")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.givenName == "NewFirst"
|
||||
assert result.name.familyName == "User"
|
||||
|
||||
def test_replace_family_name(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
result = apply_user_patch([_replace_op("name.familyName", "NewLast")], user)
|
||||
assert result.name is not None
|
||||
assert result.name.familyName == "NewLast"
|
||||
|
||||
def test_replace_username(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
result = apply_user_patch([_replace_op("userName", "new@example.com")], user)
|
||||
assert result.userName == "new@example.com"
|
||||
|
||||
def test_replace_without_path_uses_dict(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -102,7 +99,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_multiple_operations(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op("active", False),
|
||||
_replace_op("name.givenName", "Updated"),
|
||||
@@ -115,7 +112,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_case_insensitive_path(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_replace_op("Active", False)], user)
|
||||
result = apply_user_patch([_replace_op("Active", False)], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_original_not_mutated(self) -> None:
|
||||
@@ -128,22 +125,15 @@ class TestApplyUserPatch:
|
||||
with pytest.raises(ScimPatchError, match="Unsupported path"):
|
||||
apply_user_patch([_replace_op("unknownField", "value")], user)
|
||||
|
||||
def test_remove_op_clears_field(self) -> None:
|
||||
"""Remove op should clear the target field (not raise)."""
|
||||
user = _make_user(externalId="ext-123")
|
||||
result, _ = apply_user_patch([_remove_op("externalId")], user)
|
||||
assert result.externalId is None
|
||||
|
||||
def test_remove_unsupported_path_raises(self) -> None:
|
||||
"""Remove op on unsupported path (e.g. 'active') should raise."""
|
||||
def test_remove_op_on_user_raises(self) -> None:
|
||||
user = _make_user()
|
||||
with pytest.raises(ScimPatchError, match="Unsupported remove path"):
|
||||
with pytest.raises(ScimPatchError, match="Unsupported operation"):
|
||||
apply_user_patch([_remove_op("active")], user)
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
@@ -153,7 +143,7 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
"""The 'schemas' key in a value dict should be silently ignored."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -171,7 +161,7 @@ class TestApplyUserPatch:
|
||||
def test_okta_deactivation_payload(self) -> None:
|
||||
"""Exact Okta deactivation payload: path-less replace with id + active."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -186,7 +176,7 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_replace_displayname(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[_replace_op("displayName", "New Display Name")], user
|
||||
)
|
||||
assert result.displayName == "New Display Name"
|
||||
@@ -197,7 +187,7 @@ class TestApplyUserPatch:
|
||||
"""Okta sends id/schemas/meta alongside actual changes — complex types
|
||||
(lists, nested dicts) must not cause Pydantic validation errors."""
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch(
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
@@ -217,101 +207,9 @@ class TestApplyUserPatch:
|
||||
|
||||
def test_add_operation_works_like_replace(self) -> None:
|
||||
user = _make_user()
|
||||
result, _ = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
assert result.externalId == "ext-456"
|
||||
|
||||
def test_entra_capitalized_replace_op(self) -> None:
|
||||
"""Entra ID sends ``"Replace"`` instead of ``"replace"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Replace", path="active", value=False) # type: ignore[arg-type]
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.active is False
|
||||
|
||||
def test_entra_capitalized_add_op(self) -> None:
|
||||
"""Entra ID sends ``"Add"`` instead of ``"add"``."""
|
||||
user = _make_user()
|
||||
op = ScimPatchOperation(op="Add", path="externalId", value="ext-999") # type: ignore[arg-type]
|
||||
result, _ = apply_user_patch([op], user)
|
||||
assert result.externalId == "ext-999"
|
||||
|
||||
def test_entra_enterprise_extension_handled(self) -> None:
|
||||
"""Entra sends the enterprise extension URN as a key in path-less
|
||||
PATCH value dicts — enterprise data should be captured in ent_data."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
# Simulate Entra including the enterprise extension URN as extra data
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_ENTRA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_okta_handles_enterprise_extension_urn(self) -> None:
|
||||
"""Enterprise extension URN paths are handled universally, even
|
||||
for Okta — the data is captured in the enterprise data dict."""
|
||||
user = _make_user()
|
||||
value = ScimPatchResourceValue(active=False)
|
||||
assert value.__pydantic_extra__ is not None
|
||||
value.__pydantic_extra__[
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
] = {"department": "Engineering"}
|
||||
result, ent_data = apply_user_patch(
|
||||
[_replace_op(None, value)],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert ent_data["department"] == "Engineering"
|
||||
|
||||
def test_emails_primary_eq_true_value(self) -> None:
|
||||
"""emails[primary eq true].value should update the primary email entry."""
|
||||
user = _make_user(
|
||||
emails=[ScimEmail(value="old@example.com", type="work", primary=True)]
|
||||
)
|
||||
result, _ = apply_user_patch(
|
||||
[_replace_op("emails[primary eq true].value", "new@example.com")], user
|
||||
)
|
||||
# userName should remain unchanged — emails and userName are separate
|
||||
assert result.userName == "test@example.com"
|
||||
assert len(result.emails) == 1
|
||||
assert result.emails[0].value == "new@example.com"
|
||||
assert result.emails[0].primary is True
|
||||
|
||||
def test_enterprise_urn_department_path(self) -> None:
|
||||
"""Dotted enterprise URN path should set department in ent_data."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:department",
|
||||
"Marketing",
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["department"] == "Marketing"
|
||||
|
||||
def test_enterprise_urn_manager_path(self) -> None:
|
||||
"""Dotted enterprise URN path for manager should set manager."""
|
||||
user = _make_user()
|
||||
_, ent_data = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User:manager",
|
||||
ScimPatchResourceValue.model_validate({"value": "boss-id"}),
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert ent_data["manager"] == "boss-id"
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
|
||||
@@ -2,8 +2,6 @@ from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
@@ -11,10 +9,7 @@ from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.entra import _ENTRA_IGNORED_PATCH_PATHS
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
|
||||
@@ -44,7 +39,9 @@ class TestOktaProvider:
|
||||
assert OktaProvider().name == "okta"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
assert OktaProvider().ignored_patch_paths == COMMON_IGNORED_PATCH_PATHS
|
||||
assert OktaProvider().ignored_patch_paths == frozenset(
|
||||
{"id", "schemas", "meta"}
|
||||
)
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = OktaProvider()
|
||||
@@ -63,12 +60,6 @@ class TestOktaProvider:
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def test_build_user_resource_has_core_schema_only(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
assert result.schemas == [SCIM_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_with_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
@@ -170,42 +161,6 @@ class TestOktaProvider:
|
||||
assert result.members == []
|
||||
|
||||
|
||||
class TestEntraProvider:
|
||||
def test_name(self) -> None:
|
||||
assert EntraProvider().name == "entra"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
paths = EntraProvider().ignored_patch_paths
|
||||
assert paths == _ENTRA_IGNORED_PATCH_PATHS
|
||||
# Enterprise extension URN is now handled (not ignored)
|
||||
assert paths >= COMMON_IGNORED_PATCH_PATHS
|
||||
|
||||
def test_build_user_resource_includes_enterprise_schema(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result.schemas == [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = EntraProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-entra-1")
|
||||
|
||||
assert result == ScimUserResource(
|
||||
schemas=[SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA],
|
||||
id=str(user.id),
|
||||
externalId="ext-entra-1",
|
||||
userName="test@example.com",
|
||||
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
|
||||
displayName="Test User",
|
||||
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
|
||||
active=True,
|
||||
groups=[],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
class TestGetDefaultProvider:
|
||||
def test_returns_okta(self) -> None:
|
||||
provider = get_default_provider()
|
||||
|
||||
@@ -16,7 +16,7 @@ from ee.onyx.server.scim.api import get_user
|
||||
from ee.onyx.server.scim.api import list_users
|
||||
from ee.onyx.server.scim.api import patch_user
|
||||
from ee.onyx.server.scim.api import replace_user
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
@@ -28,8 +28,6 @@ from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_list
|
||||
from tests.unit.onyx.server.scim.conftest import parse_scim_user
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
@@ -53,9 +51,9 @@ class TestListUsers:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 0
|
||||
assert parsed.Resources == []
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 0
|
||||
assert result.Resources == []
|
||||
|
||||
def test_returns_users_with_scim_shape(
|
||||
self,
|
||||
@@ -79,10 +77,10 @@ class TestListUsers:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_list(result)
|
||||
assert parsed.totalResults == 1
|
||||
assert len(parsed.Resources) == 1
|
||||
resource = parsed.Resources[0]
|
||||
assert isinstance(result, ScimListResponse)
|
||||
assert result.totalResults == 1
|
||||
assert len(result.Resources) == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert resource.userName == "Alice@example.com"
|
||||
assert resource.externalId == "ext-abc"
|
||||
@@ -148,9 +146,9 @@ class TestGetUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "alice@example.com"
|
||||
assert resource.id == str(user.id)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "alice@example.com"
|
||||
assert result.id == str(user.id)
|
||||
|
||||
def test_invalid_uuid_returns_404(
|
||||
self,
|
||||
@@ -209,8 +207,8 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.userName == "new@example.com"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "new@example.com"
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -316,8 +314,8 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.externalId == "ext-123"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.externalId == "ext-123"
|
||||
mock_dal.create_user_mapping.assert_called_once()
|
||||
|
||||
|
||||
@@ -346,7 +344,7 @@ class TestReplaceUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@@ -414,15 +412,9 @@ class TestReplaceUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(
|
||||
user.id,
|
||||
None,
|
||||
scim_username="test@example.com",
|
||||
fields=ScimMappingFields(
|
||||
given_name="Test",
|
||||
family_name="User",
|
||||
),
|
||||
user.id, None, scim_username="test@example.com"
|
||||
)
|
||||
|
||||
|
||||
@@ -456,7 +448,7 @@ class TestPatchUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.update_user.assert_called_once()
|
||||
|
||||
def test_not_found_returns_404(
|
||||
@@ -515,7 +507,7 @@ class TestPatchUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parse_scim_user(result)
|
||||
assert isinstance(result, ScimUserResource)
|
||||
# Verify the update_user call received the new display name
|
||||
call_kwargs = mock_dal.update_user.call_args
|
||||
assert call_kwargs[1]["personal_name"] == "New Display Name"
|
||||
@@ -613,12 +605,10 @@ class TestDeleteUser:
|
||||
class TestScimNameToStr:
|
||||
"""Tests for _scim_name_to_str helper."""
|
||||
|
||||
def test_prefers_formatted_over_components(self) -> None:
|
||||
"""When client provides formatted, use it — the client knows what it wants."""
|
||||
name = ScimName(
|
||||
givenName="Jane", familyName="Smith", formatted="Dr. Jane Smith"
|
||||
)
|
||||
assert _scim_name_to_str(name) == "Dr. Jane Smith"
|
||||
def test_prefers_given_family_over_formatted(self) -> None:
|
||||
"""Okta may send stale formatted while updating givenName/familyName."""
|
||||
name = ScimName(givenName="Jane", familyName="Smith", formatted="Old Name")
|
||||
assert _scim_name_to_str(name) == "Jane Smith"
|
||||
|
||||
def test_given_name_only(self) -> None:
|
||||
name = ScimName(givenName="Jane")
|
||||
@@ -663,9 +653,9 @@ class TestEmailCasePreservation:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result, status=201)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
def test_get_preserves_username_case(
|
||||
self,
|
||||
@@ -691,6 +681,6 @@ class TestEmailCasePreservation:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
resource = parse_scim_user(result)
|
||||
assert resource.userName == "Alice@Example.COM"
|
||||
assert resource.emails[0].value == "Alice@Example.COM"
|
||||
assert isinstance(result, ScimUserResource)
|
||||
assert result.userName == "Alice@Example.COM"
|
||||
assert result.emails[0].value == "Alice@Example.COM"
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Tests for PythonTool availability based on server_enabled flag.
|
||||
|
||||
Verifies that PythonTool reports itself as unavailable when either:
|
||||
- CODE_INTERPRETER_BASE_URL is not set, or
|
||||
- CodeInterpreterServer.server_enabled is False in the database.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when CODE_INTERPRETER_BASE_URL is not set
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
None,
|
||||
)
|
||||
def test_python_tool_unavailable_without_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"",
|
||||
)
|
||||
def test_python_tool_unavailable_with_empty_base_url() -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unavailable when server_enabled is False
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_unavailable_when_server_disabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = False
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Available when both conditions are met
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.CODE_INTERPRETER_BASE_URL",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.python.python_tool.fetch_code_interpreter_server",
|
||||
)
|
||||
def test_python_tool_available_when_server_enabled(
|
||||
mock_fetch: MagicMock,
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
|
||||
mock_server = MagicMock()
|
||||
mock_server.server_enabled = True
|
||||
mock_fetch.return_value = mock_server
|
||||
|
||||
db_session = MagicMock(spec=Session)
|
||||
assert PythonTool.is_available(db_session) is True
|
||||
@@ -144,7 +144,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.6.2",
|
||||
"onyx-devtools==0.6.1",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs~=2.3.3",
|
||||
"pre-commit==3.2.2",
|
||||
|
||||
@@ -170,30 +170,6 @@ ods pull
|
||||
ods pull --tag edge
|
||||
```
|
||||
|
||||
### `web` - Run Frontend Scripts
|
||||
|
||||
Run npm scripts from `web/package.json` without manually changing directories.
|
||||
|
||||
```shell
|
||||
ods web <script> [args...]
|
||||
```
|
||||
|
||||
Script names are available via shell completion (for supported shells via
|
||||
`ods completion`), and are read from `web/package.json`.
|
||||
|
||||
**Examples:**
|
||||
|
||||
```shell
|
||||
# Start the Next.js dev server
|
||||
ods web dev
|
||||
|
||||
# Run web lint task
|
||||
ods web lint
|
||||
|
||||
# Forward extra args to the script
|
||||
ods web test --watch
|
||||
```
|
||||
|
||||
### `db` - Database Administration
|
||||
|
||||
Manage PostgreSQL database dumps, restores, and migrations.
|
||||
|
||||
@@ -50,7 +50,6 @@ func NewRootCommand() *cobra.Command {
|
||||
cmd.AddCommand(NewPullCommand())
|
||||
cmd.AddCommand(NewRunCICommand())
|
||||
cmd.AddCommand(NewScreenshotDiffCommand())
|
||||
cmd.AddCommand(NewWebCommand())
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -1,144 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/tools/ods/internal/paths"
|
||||
)
|
||||
|
||||
type webPackageJSON struct {
|
||||
Scripts map[string]string `json:"scripts"`
|
||||
}
|
||||
|
||||
// NewWebCommand creates a command that runs npm scripts from the web directory.
|
||||
func NewWebCommand() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "web <script> [args...]",
|
||||
Short: "Run web/package.json npm scripts",
|
||||
Long: webHelpDescription(),
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
if len(args) > 0 {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
return webScriptNames(), cobra.ShellCompDirectiveNoFileComp
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
runWebScript(args)
|
||||
},
|
||||
}
|
||||
cmd.Flags().SetInterspersed(false)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runWebScript(args []string) {
|
||||
webDir, err := webDir()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to find web directory: %v", err)
|
||||
}
|
||||
|
||||
scriptName := args[0]
|
||||
scriptArgs := args[1:]
|
||||
if len(scriptArgs) > 0 && scriptArgs[0] == "--" {
|
||||
scriptArgs = scriptArgs[1:]
|
||||
}
|
||||
|
||||
npmArgs := []string{"run", scriptName}
|
||||
if len(scriptArgs) > 0 {
|
||||
// npm requires "--" to forward flags to the underlying script.
|
||||
npmArgs = append(npmArgs, "--")
|
||||
npmArgs = append(npmArgs, scriptArgs...)
|
||||
}
|
||||
log.Debugf("Running in %s: npm %v", webDir, npmArgs)
|
||||
|
||||
webCmd := exec.Command("npm", npmArgs...)
|
||||
webCmd.Dir = webDir
|
||||
webCmd.Stdout = os.Stdout
|
||||
webCmd.Stderr = os.Stderr
|
||||
webCmd.Stdin = os.Stdin
|
||||
|
||||
if err := webCmd.Run(); err != nil {
|
||||
// For wrapped commands, preserve the child process's exit code and
|
||||
// avoid duplicating already-printed stderr output.
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
if code := exitErr.ExitCode(); code != -1 {
|
||||
os.Exit(code)
|
||||
}
|
||||
}
|
||||
log.Fatalf("Failed to run npm: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func webScriptNames() []string {
|
||||
scripts, err := loadWebScripts()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
names := make([]string, 0, len(scripts))
|
||||
for name := range scripts {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
func webHelpDescription() string {
|
||||
description := `Run npm scripts from web/package.json.
|
||||
|
||||
Examples:
|
||||
ods web dev
|
||||
ods web lint
|
||||
ods web test --watch`
|
||||
|
||||
scripts := webScriptNames()
|
||||
if len(scripts) == 0 {
|
||||
return description + "\n\nAvailable scripts: (unable to load)"
|
||||
}
|
||||
|
||||
return description + "\n\nAvailable scripts:\n " + strings.Join(scripts, "\n ")
|
||||
}
|
||||
|
||||
func loadWebScripts() (map[string]string, error) {
|
||||
webDir, err := webDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
packageJSONPath := filepath.Join(webDir, "package.json")
|
||||
data, err := os.ReadFile(packageJSONPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read %s: %w", packageJSONPath, err)
|
||||
}
|
||||
|
||||
var pkg webPackageJSON
|
||||
if err := json.Unmarshal(data, &pkg); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse %s: %w", packageJSONPath, err)
|
||||
}
|
||||
|
||||
if pkg.Scripts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return pkg.Scripts, nil
|
||||
}
|
||||
|
||||
func webDir() (string, error) {
|
||||
root, err := paths.GitRoot()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(root, "web"), nil
|
||||
}
|
||||
18
uv.lock
generated
18
uv.lock
generated
@@ -4654,7 +4654,7 @@ requires-dist = [
|
||||
{ name = "numpy", marker = "extra == 'model-server'", specifier = "==2.4.1" },
|
||||
{ name = "oauthlib", marker = "extra == 'backend'", specifier = "==3.2.2" },
|
||||
{ name = "office365-rest-python-client", marker = "extra == 'backend'", specifier = "==2.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.2" },
|
||||
{ name = "onyx-devtools", marker = "extra == 'dev'", specifier = "==0.6.1" },
|
||||
{ name = "openai", specifier = "==2.14.0" },
|
||||
{ name = "openapi-generator-cli", marker = "extra == 'dev'", specifier = "==7.17.0" },
|
||||
{ name = "openinference-instrumentation", marker = "extra == 'backend'", specifier = "==0.1.42" },
|
||||
@@ -4759,20 +4759,20 @@ requires-dist = [{ name = "onyx", extras = ["backend", "dev", "ee"], editable =
|
||||
|
||||
[[package]]
|
||||
name = "onyx-devtools"
|
||||
version = "0.6.2"
|
||||
version = "0.6.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "fastapi" },
|
||||
{ name = "openapi-generator-cli" },
|
||||
]
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/20/d9f6089616044b0fb6e097cbae82122de24f3acd97820be4868d5c28ee3f/onyx_devtools-0.6.2-py3-none-any.whl", hash = "sha256:e48d14695d39d62ec3247a4c76ea56604bc5fb635af84c4ff3e9628bcc67b4fb", size = 3785941, upload-time = "2026-02-25T22:33:43.585Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/f5/f754a717f6b011050eb52ef09895cfa2f048f567f4aa3d5e0f773657dea4/onyx_devtools-0.6.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:505f9910a04868ab62d99bb483dc37c9f4ad94fa80e6ac0e6a10b86351c31420", size = 3832182, upload-time = "2026-02-25T22:33:43.283Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/35/6e653398c62078e87ebb0d03dc944df6691d92ca427c92867309d2d803b7/onyx_devtools-0.6.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:edec98e3acc0fa22cf9102c2070409ea7bcf99d7ded72bd8cb184ece8171c36a", size = 3576948, upload-time = "2026-02-25T22:33:42.962Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/97/cff707c5c3d2acd714365b1023f0100676abc99816a29558319e8ef01d5f/onyx_devtools-0.6.2-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:97abab61216866cdccd8c0a7e27af328776083756ce4fb57c4bd723030449e3b", size = 3439359, upload-time = "2026-02-25T22:33:44.684Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/98/3b768d18e5599178834b966b447075626d224e048d6eb264d89d19abacb4/onyx_devtools-0.6.2-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:681b038ab6f1457409d14b2490782c7a8014fc0f0f1b9cd69bb2b7199f99aef1", size = 3785959, upload-time = "2026-02-25T22:33:44.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/38/9b047f9e61c14ccf22b8f386c7a57da3965f90737453f3a577a97da45cdf/onyx_devtools-0.6.2-py3-none-win_amd64.whl", hash = "sha256:a2063be6be104b50a7538cf0d26c7f7ab9159d53327dd6f3e91db05d793c95f3", size = 3878776, upload-time = "2026-02-25T22:33:45.229Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/0f/742f644bae84f5f8f7b500094a2f58da3ff8027fc739944622577e2e2850/onyx_devtools-0.6.2-py3-none-win_arm64.whl", hash = "sha256:00fb90a49a15c932b5cacf818b1b4918e5b5c574bde243dc1828b57690dd5046", size = 3501112, upload-time = "2026-02-25T22:33:41.512Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bf/3c/fc0c152ecc403b8d4c929eacc7ea4c3d6cba2094f3cfa51d9e5c4d3bda3d/onyx_devtools-0.6.1-py3-none-any.whl", hash = "sha256:a9ad90ca4536ebe9aaeb604f82c418f3fd148100f14cca7749df0d076ee5c4b0", size = 3781440, upload-time = "2026-02-25T00:59:03.565Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/fd/1c/2df5a06eed5490057f0852153940142f9987ff9b865c9c185b733fa360b1/onyx_devtools-0.6.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:769a656737e2389312e8e24bf3e9dd559dcb00160f323228dfe34d005ab47af3", size = 3827421, upload-time = "2026-02-25T00:58:59.672Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/a2/e3/389644eb9ba0a3cfa975cc015a48140702b05abc9093542b2a3ba6cc5cc1/onyx_devtools-0.6.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93886332e97e6efa5f3d7a1d1e4facf1442d301df379f65dfc2a328ed43c8f39", size = 3573060, upload-time = "2026-02-25T00:59:02.582Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/fe/dd0f32e08f7e7fb1861a28b82431e0a43cf6ab33e04fb2938f4ee20c891b/onyx_devtools-0.6.1-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:cf896e420c78c08c541135473627ffcab0a0156e0e462e71bcb476f560c324fa", size = 3435936, upload-time = "2026-02-25T00:59:02.313Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/3a/4376cba6adcf86b9fc55f146493450955497d988920eaa37a8aec9f9f897/onyx_devtools-0.6.1-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:4cb5a1b44a4e74c2fc68164a5caa34bce3f6d2dd5639e48438c1d04f09c4c7c6", size = 3781457, upload-time = "2026-02-25T00:59:02.126Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/0d/d2ecf7edc02354d16d9a1d9bd7d8d35f46cdde08b86635ba02075e4d3c7c/onyx_devtools-0.6.1-py3-none-win_amd64.whl", hash = "sha256:0c6c6a667851b9ab215980f1b391216bc2f157c8a29d0cfa96c32c6d10116a5c", size = 3875146, upload-time = "2026-02-25T00:59:02.364Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/c3/04783dcfad36b18f48befb6d85bf4f9a9f36fd4cd6e08077676c72c9c504/onyx_devtools-0.6.1-py3-none-win_arm64.whl", hash = "sha256:f095e58b4dad0671c7127a452c5d5f411f55070ebf586a2e47f9193ab753ce44", size = 3496971, upload-time = "2026-02-25T00:59:17.98Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -59,21 +59,18 @@ type ButtonContentProps =
|
||||
icon: IconFunctionComponent;
|
||||
children: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: never;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon?: IconFunctionComponent;
|
||||
children: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: never;
|
||||
}
|
||||
| {
|
||||
foldable?: false;
|
||||
icon: IconFunctionComponent;
|
||||
children?: string;
|
||||
rightIcon?: IconFunctionComponent;
|
||||
responsiveHideText?: boolean;
|
||||
};
|
||||
|
||||
type ButtonProps = InteractiveBaseProps &
|
||||
@@ -111,7 +108,6 @@ function Button({
|
||||
width,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
responsiveHideText = false,
|
||||
...interactiveBaseProps
|
||||
}: ButtonProps) {
|
||||
const isLarge = size === "lg";
|
||||
@@ -120,8 +116,7 @@ function Button({
|
||||
<span
|
||||
className={cn(
|
||||
"opal-button-label",
|
||||
isLarge ? "font-main-ui-body " : "font-secondary-body",
|
||||
responsiveHideText && "hidden md:inline"
|
||||
isLarge ? "font-main-ui-body " : "font-secondary-body"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
@@ -151,25 +146,13 @@ function Button({
|
||||
<div className="opal-button-foldable">
|
||||
<div className="opal-button-foldable-inner">
|
||||
{labelEl}
|
||||
{responsiveHideText ? (
|
||||
<span className="hidden md:inline-flex">
|
||||
{iconWrapper(RightIcon, size, !!children)}
|
||||
</span>
|
||||
) : (
|
||||
iconWrapper(RightIcon, size, !!children)
|
||||
)}
|
||||
{iconWrapper(RightIcon, size, !!children)}
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{labelEl}
|
||||
{responsiveHideText ? (
|
||||
<span className="hidden md:inline-flex">
|
||||
{iconWrapper(RightIcon, size, !!children)}
|
||||
</span>
|
||||
) : (
|
||||
iconWrapper(RightIcon, size, !!children)
|
||||
)}
|
||||
{iconWrapper(RightIcon, size, !!children)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
@@ -177,13 +160,7 @@ function Button({
|
||||
</Interactive.Base>
|
||||
);
|
||||
|
||||
const resolvedTooltip =
|
||||
tooltip ??
|
||||
(foldable && interactiveBaseProps.disabled && children
|
||||
? children
|
||||
: undefined);
|
||||
|
||||
if (!resolvedTooltip) return button;
|
||||
if (!tooltip) return button;
|
||||
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
@@ -194,7 +171,7 @@ function Button({
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{resolvedTooltip}
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgAzure = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 52 52"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M18.3281 3.40002C17.3811 3.40002 16.5394 4.00473 16.2763 4.89865L3.12373 43.8366C2.72915 44.9672 3.30787 46.2029 4.43899 46.5973C4.67574 46.6761 4.91248 46.7287 5.17554 46.7287H16.0396C16.855 46.571 17.539 45.9926 17.8283 45.2038L32.0068 3.40002H18.3281Z"
|
||||
fill="url(#paint0_linear_9_943)"
|
||||
/>
|
||||
<path
|
||||
d="M38.136 31.4795H16.5394C15.987 31.4795 15.5398 31.9264 15.5398 32.4786C15.5398 32.7678 15.645 33.0307 15.8555 33.2147L29.7446 46.1503C30.1392 46.5183 30.6916 46.7287 31.244 46.7287H43.4759L38.136 31.4795Z"
|
||||
fill="#0078D4"
|
||||
/>
|
||||
<path
|
||||
d="M18.3281 3.40002C17.3811 3.40002 16.5394 4.00473 16.2763 4.89865L3.12373 43.8366C2.72915 44.9672 3.30787 46.2029 4.43899 46.5973C4.67574 46.6761 4.91248 46.7287 5.17554 46.7287H16.0396C16.855 46.571 17.539 45.9926 17.8283 45.2038L20.4589 37.4741L29.8235 46.2555C30.2181 46.571 30.7179 46.755 31.2177 46.755H43.397L38.057 31.4796H22.4844L32.0068 3.40002H18.3281Z"
|
||||
fill="url(#paint1_linear_9_943)"
|
||||
/>
|
||||
<path
|
||||
d="M35.7422 4.87236C35.4528 3.97844 34.611 3.40002 33.6904 3.40002H18.5123C19.4329 3.40002 20.2747 4.00473 20.5641 4.87236L33.7167 43.8892C34.1112 45.0198 33.4799 46.2555 32.3488 46.6236C32.1384 46.7024 31.9016 46.7287 31.6649 46.7287H46.843C48.053 46.7287 49 45.7559 49 44.5728C49 44.3362 48.9737 44.0996 48.8948 43.8892L35.7422 4.87236Z"
|
||||
fill="url(#paint2_linear_9_943)"
|
||||
/>
|
||||
<defs>
|
||||
<linearGradient
|
||||
id="paint0_linear_9_943"
|
||||
x1={23.3411}
|
||||
y1={6.61094}
|
||||
x2={9.24122}
|
||||
y2={48.3769}
|
||||
gradientUnits="userSpaceOnUse"
|
||||
>
|
||||
<stop stopColor="#114A8B" />
|
||||
<stop offset={1} stopColor="#0765B6" />
|
||||
</linearGradient>
|
||||
<linearGradient
|
||||
id="paint1_linear_9_943"
|
||||
x1={27.7206}
|
||||
y1={26.0775}
|
||||
x2={24.4488}
|
||||
y2={27.1844}
|
||||
gradientUnits="userSpaceOnUse"
|
||||
>
|
||||
<stop stopOpacity={0.3} />
|
||||
<stop offset={0.071} stopOpacity={0.2} />
|
||||
<stop offset={0.321} stopOpacity={0.1} />
|
||||
<stop offset={0.623} stopOpacity={0.05} />
|
||||
<stop offset={1} stopOpacity={0} />
|
||||
</linearGradient>
|
||||
<linearGradient
|
||||
id="paint2_linear_9_943"
|
||||
x1={26.0229}
|
||||
y1={5.35655}
|
||||
x2={41.5367}
|
||||
y2={46.7094}
|
||||
gradientUnits="userSpaceOnUse"
|
||||
>
|
||||
<stop stopColor="#3BC9F3" />
|
||||
<stop offset={1} stopColor="#2892DF" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
);
|
||||
export default SvgAzure;
|
||||
@@ -10,6 +10,7 @@ const SvgClaude = ({ size, ...props }: IconProps) => {
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
stroke="currentColor"
|
||||
{...props}
|
||||
>
|
||||
<g clipPath={`url(#${clipId})`}>
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import type { IconProps } from "@opal/types";
|
||||
const SvgGemini = ({ size, ...props }: IconProps) => (
|
||||
<svg
|
||||
width={size}
|
||||
height={size}
|
||||
viewBox="0 0 52 52"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M26 2C26.5034 2 26.9412 2.34378 27.064 2.83212C27.4405 4.3258 27.9315 5.78274 28.5426 7.20002C30.1345 10.8981 32.3187 14.1349 35.092 16.9081C37.8664 19.6813 41.102 21.8655 44.8 23.4574C46.2186 24.0685 47.6743 24.5595 49.1679 24.936C49.6562 25.0588 49.9999 25.4967 50 26C50 26.5034 49.6563 26.9413 49.1679 27.064C47.6743 27.4405 46.2172 27.9315 44.8 28.5426C41.1019 30.1345 37.8651 32.3187 35.092 35.092C32.3187 37.8665 30.1345 41.1019 28.5426 44.8C27.9315 46.2186 27.4405 47.6743 27.064 49.1679C26.9413 49.6563 26.5034 50 26 50C25.4967 49.9999 25.0588 49.6562 24.936 49.1679C24.5595 47.6743 24.0685 46.2172 23.4574 44.8C21.8655 41.102 19.6826 37.8651 16.9081 35.092C14.1335 32.3187 10.8981 30.1345 7.20002 28.5426C5.78137 27.9315 4.3258 27.4405 2.83212 27.064C2.34378 26.9412 2 26.5034 2 26C2.00006 25.4967 2.34381 25.0588 2.83212 24.936C4.32581 24.5595 5.78273 24.0686 7.20002 23.4574C10.8981 21.8655 14.1349 19.6813 16.9081 16.9081C19.6813 14.1349 21.8655 10.8981 23.4574 7.20002C24.0686 5.78137 24.5595 4.32581 24.936 2.83212C25.0588 2.34381 25.4967 2.00006 26 2Z"
|
||||
fill="url(#paint0_linear_9_973)"
|
||||
/>
|
||||
<defs>
|
||||
<linearGradient
|
||||
id="paint0_linear_9_973"
|
||||
x1={15.6448}
|
||||
y1={34.1163}
|
||||
x2={40.5754}
|
||||
y2={13.0975}
|
||||
gradientUnits="userSpaceOnUse"
|
||||
>
|
||||
<stop stopColor="#4893FC" />
|
||||
<stop offset={0.27} stopColor="#4893FC" />
|
||||
<stop offset={0.776981} stopColor="#969DFF" />
|
||||
<stop offset={1} stopColor="#BD99FE" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
);
|
||||
export default SvgGemini;
|
||||
@@ -19,7 +19,6 @@ export { default as SvgArrowUpRight } from "@opal/icons/arrow-up-right";
|
||||
export { default as SvgArrowWallRight } from "@opal/icons/arrow-wall-right";
|
||||
export { default as SvgAudioEqSmall } from "@opal/icons/audio-eq-small";
|
||||
export { default as SvgAws } from "@opal/icons/aws";
|
||||
export { default as SvgAzure } from "@opal/icons/azure";
|
||||
export { default as SvgBarChart } from "@opal/icons/bar-chart";
|
||||
export { default as SvgBarChartSmall } from "@opal/icons/bar-chart-small";
|
||||
export { default as SvgBell } from "@opal/icons/bell";
|
||||
@@ -75,7 +74,6 @@ export { default as SvgFolderIn } from "@opal/icons/folder-in";
|
||||
export { default as SvgFolderOpen } from "@opal/icons/folder-open";
|
||||
export { default as SvgFolderPartialOpen } from "@opal/icons/folder-partial-open";
|
||||
export { default as SvgFolderPlus } from "@opal/icons/folder-plus";
|
||||
export { default as SvgGemini } from "@opal/icons/gemini";
|
||||
export { default as SvgGlobe } from "@opal/icons/globe";
|
||||
export { default as SvgHardDrive } from "@opal/icons/hard-drive";
|
||||
export { default as SvgHashSmall } from "@opal/icons/hash-small";
|
||||
@@ -86,7 +84,6 @@ export { default as SvgHourglass } from "@opal/icons/hourglass";
|
||||
export { default as SvgImage } from "@opal/icons/image";
|
||||
export { default as SvgImageSmall } from "@opal/icons/image-small";
|
||||
export { default as SvgImport } from "@opal/icons/import";
|
||||
export { default as SvgInfo } from "@opal/icons/info";
|
||||
export { default as SvgInfoSmall } from "@opal/icons/info-small";
|
||||
export { default as SvgKey } from "@opal/icons/key";
|
||||
export { default as SvgKeystroke } from "@opal/icons/keystroke";
|
||||
@@ -94,7 +91,6 @@ export { default as SvgLightbulbSimple } from "@opal/icons/lightbulb-simple";
|
||||
export { default as SvgLineChartUp } from "@opal/icons/line-chart-up";
|
||||
export { default as SvgLink } from "@opal/icons/link";
|
||||
export { default as SvgLinkedDots } from "@opal/icons/linked-dots";
|
||||
export { default as SvgLitellm } from "@opal/icons/litellm";
|
||||
export { default as SvgLoader } from "@opal/icons/loader";
|
||||
export { default as SvgLock } from "@opal/icons/lock";
|
||||
export { default as SvgLogOut } from "@opal/icons/log-out";
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user