mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 14:45:46 +00:00
Compare commits
2 Commits
tokenizer
...
nik/teams-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32766066b4 | ||
|
|
6562e63ab8 |
26
.github/workflows/deployment.yml
vendored
26
.github/workflows/deployment.yml
vendored
@@ -426,9 +426,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
|
||||
@@ -500,9 +499,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
|
||||
@@ -648,8 +646,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
|
||||
@@ -730,8 +728,8 @@ jobs:
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
NODE_OPTIONS=--max-old-space-size=8192
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
|
||||
@@ -864,9 +862,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
|
||||
@@ -937,9 +934,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
|
||||
@@ -1076,8 +1072,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
|
||||
@@ -1149,8 +1145,8 @@ jobs:
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
ENABLE_CRAFT=true
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
|
||||
@@ -1291,9 +1287,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
|
||||
@@ -1371,9 +1366,8 @@ jobs:
|
||||
build-args: |
|
||||
ONYX_VERSION=${{ github.ref_name }}
|
||||
cache-from: |
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
|
||||
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
|
||||
cache-to: |
|
||||
type=inline
|
||||
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
|
||||
|
||||
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
13
.github/workflows/nightly-llm-provider-chat.yml
vendored
@@ -15,9 +15,6 @@ permissions:
|
||||
jobs:
|
||||
provider-chat-test:
|
||||
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
with:
|
||||
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
|
||||
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
|
||||
@@ -28,6 +25,16 @@ jobs:
|
||||
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
|
||||
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
|
||||
strict: true
|
||||
secrets:
|
||||
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
|
||||
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
|
||||
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
|
||||
azure_api_key: ${{ secrets.AZURE_API_KEY }}
|
||||
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
|
||||
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
notify-slack-on-failure:
|
||||
needs: [provider-chat-test]
|
||||
|
||||
21
.github/workflows/pr-python-checks.yml
vendored
21
.github/workflows/pr-python-checks.yml
vendored
@@ -8,7 +8,7 @@ on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
- 'release/**'
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
@@ -21,13 +21,7 @@ jobs:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
@@ -58,14 +52,21 @@ jobs:
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
working-directory: ./backend
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
|
||||
- name: Run MyPy (tools/)
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy tools/
|
||||
|
||||
@@ -48,10 +48,28 @@ on:
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
secrets:
|
||||
openai_api_key:
|
||||
required: false
|
||||
anthropic_api_key:
|
||||
required: false
|
||||
bedrock_api_key:
|
||||
required: false
|
||||
vertex_ai_custom_config_json:
|
||||
required: false
|
||||
azure_api_key:
|
||||
required: false
|
||||
ollama_api_key:
|
||||
required: false
|
||||
openrouter_api_key:
|
||||
required: false
|
||||
DOCKER_USERNAME:
|
||||
required: true
|
||||
DOCKER_TOKEN:
|
||||
required: true
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
build-backend-image:
|
||||
@@ -63,7 +81,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -72,19 +89,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build backend image
|
||||
uses: ./.github/actions/build-backend-image
|
||||
with:
|
||||
@@ -93,8 +97,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
|
||||
|
||||
build-model-server-image:
|
||||
@@ -106,7 +110,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -115,19 +118,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build model server image
|
||||
uses: ./.github/actions/build-model-server-image
|
||||
with:
|
||||
@@ -136,8 +126,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
build-integration-image:
|
||||
runs-on:
|
||||
@@ -148,7 +138,6 @@ jobs:
|
||||
"extras=ecr-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -157,19 +146,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
|
||||
- name: Build integration image
|
||||
uses: ./.github/actions/build-integration-image
|
||||
with:
|
||||
@@ -178,8 +154,8 @@ jobs:
|
||||
pr-number: ${{ github.event.pull_request.number }}
|
||||
github-sha: ${{ github.sha }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
provider-chat-test:
|
||||
needs:
|
||||
@@ -194,56 +170,56 @@ jobs:
|
||||
include:
|
||||
- provider: openai
|
||||
models: ${{ inputs.openai_models }}
|
||||
api_key_env: OPENAI_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: openai_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: anthropic
|
||||
models: ${{ inputs.anthropic_models }}
|
||||
api_key_env: ANTHROPIC_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: anthropic_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: true
|
||||
- provider: bedrock
|
||||
models: ${{ inputs.bedrock_models }}
|
||||
api_key_env: BEDROCK_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: bedrock_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: vertex_ai
|
||||
models: ${{ inputs.vertex_ai_models }}
|
||||
api_key_env: ""
|
||||
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
|
||||
api_key_secret: ""
|
||||
custom_config_secret: vertex_ai_custom_config_json
|
||||
api_base: ""
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: azure
|
||||
models: ${{ inputs.azure_models }}
|
||||
api_key_env: AZURE_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: azure_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: ${{ inputs.azure_api_base }}
|
||||
api_version: "2025-04-01-preview"
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: ollama_chat
|
||||
models: ${{ inputs.ollama_models }}
|
||||
api_key_env: OLLAMA_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: ollama_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://ollama.com"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
required: false
|
||||
- provider: openrouter
|
||||
models: ${{ inputs.openrouter_models }}
|
||||
api_key_env: OPENROUTER_API_KEY
|
||||
custom_config_env: ""
|
||||
api_key_secret: openrouter_api_key
|
||||
custom_config_secret: ""
|
||||
api_base: "https://openrouter.ai/api/v1"
|
||||
api_version: ""
|
||||
deployment_name: ""
|
||||
@@ -254,7 +230,6 @@ jobs:
|
||||
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 45
|
||||
environment: ci-protected
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -263,43 +238,21 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
|
||||
parse-json-secrets: false
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, test/docker-username
|
||||
DOCKER_TOKEN, test/docker-token
|
||||
OPENAI_API_KEY, test/openai-api-key
|
||||
ANTHROPIC_API_KEY, test/anthropic-api-key
|
||||
BEDROCK_API_KEY, test/bedrock-api-key
|
||||
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
|
||||
AZURE_API_KEY, test/azure-api-key
|
||||
OLLAMA_API_KEY, test/ollama-api-key
|
||||
OPENROUTER_API_KEY, test/openrouter-api-key
|
||||
|
||||
- name: Run nightly provider chat test
|
||||
uses: ./.github/actions/run-nightly-provider-chat-test
|
||||
with:
|
||||
provider: ${{ matrix.provider }}
|
||||
models: ${{ matrix.models }}
|
||||
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
|
||||
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
|
||||
strict: ${{ inputs.strict && 'true' || 'false' }}
|
||||
api-base: ${{ matrix.api_base }}
|
||||
api-version: ${{ matrix.api_version }}
|
||||
deployment-name: ${{ matrix.deployment_name }}
|
||||
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
|
||||
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
|
||||
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
|
||||
run-id: ${{ github.run_id }}
|
||||
docker-username: ${{ env.DOCKER_USERNAME }}
|
||||
docker-token: ${{ env.DOCKER_TOKEN }}
|
||||
docker-username: ${{ secrets.DOCKER_USERNAME }}
|
||||
docker-token: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
|
||||
105
backend/alembic/versions/a1b2c3d4e5f6_add_teams_bot_tables.py
Normal file
105
backend/alembic/versions/a1b2c3d4e5f6_add_teams_bot_tables.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""add teams bot tables
|
||||
|
||||
Revision ID: a1b2c3d4e5f6
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-03-02 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f6"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"teams_bot_config",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.String(),
|
||||
server_default=sa.text("'SINGLETON'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("app_id", sa.String(), nullable=False),
|
||||
sa.Column("app_secret", sa.LargeBinary(), nullable=False),
|
||||
sa.Column("azure_tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.CheckConstraint("id = 'SINGLETON'", name="ck_teams_bot_config_singleton"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"teams_team_config",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("team_id", sa.String(), nullable=True),
|
||||
sa.Column("team_name", sa.String(length=256), nullable=True),
|
||||
sa.Column("registration_key", sa.String(), nullable=False),
|
||||
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("default_persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"enabled",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["default_persona_id"],
|
||||
["persona.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("team_id"),
|
||||
sa.UniqueConstraint("registration_key"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"teams_channel_config",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("team_config_id", sa.Integer(), nullable=False),
|
||||
sa.Column("channel_id", sa.String(), nullable=False),
|
||||
sa.Column("channel_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"require_bot_mention",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("persona_override_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"enabled",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["team_config_id"],
|
||||
["teams_team_config.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_override_id"],
|
||||
["persona.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"team_config_id", "channel_id", name="uq_teams_channel_team_channel"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("teams_channel_config")
|
||||
op.drop_table("teams_team_config")
|
||||
op.drop_table("teams_bot_config")
|
||||
@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.scim.api import register_scim_exception_handlers
|
||||
from ee.onyx.server.scim.api import scim_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
from ee.onyx.server.tenants.api import router as tenants_router
|
||||
@@ -168,7 +167,6 @@ def get_application() -> FastAPI:
|
||||
# they use their own SCIM bearer token auth).
|
||||
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
|
||||
application.include_router(scim_router)
|
||||
register_scim_exception_handlers(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_ee_router_auth(application)
|
||||
|
||||
@@ -15,9 +15,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi_users.password import PasswordHelper
|
||||
@@ -26,7 +24,6 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
|
||||
@@ -80,22 +77,6 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def register_scim_exception_handlers(app: FastAPI) -> None:
|
||||
"""Register SCIM-specific exception handlers on the FastAPI app.
|
||||
|
||||
Call this after ``app.include_router(scim_router)`` so that auth
|
||||
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
|
||||
envelopes (with ``schemas`` and ``status`` fields) instead of
|
||||
FastAPI's default ``{"detail": "..."}`` format.
|
||||
"""
|
||||
|
||||
@app.exception_handler(ScimAuthError)
|
||||
async def _handle_scim_auth_error(
|
||||
_request: Request, exc: ScimAuthError
|
||||
) -> ScimJSONResponse:
|
||||
return _scim_error_response(exc.status_code, exc.detail)
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
@@ -423,6 +404,12 @@ def create_user(
|
||||
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
|
||||
# Enforce seat limit
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
@@ -449,19 +436,16 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping when externalId is provided — this is how the IdP
|
||||
# correlates this user on subsequent requests. Per RFC 7643, externalId
|
||||
# is optional and assigned by the provisioning client.
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
if external_id:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import hashlib
|
||||
import secrets
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -27,21 +28,6 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import ScimToken
|
||||
|
||||
|
||||
class ScimAuthError(Exception):
|
||||
"""Raised when SCIM bearer token authentication fails.
|
||||
|
||||
Unlike HTTPException, this carries the status and detail so the SCIM
|
||||
exception handler can wrap them in an RFC 7644 §3.12 error envelope
|
||||
with ``schemas`` and ``status`` fields.
|
||||
"""
|
||||
|
||||
def __init__(self, status_code: int, detail: str) -> None:
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
super().__init__(detail)
|
||||
|
||||
|
||||
SCIM_TOKEN_PREFIX = "onyx_scim_"
|
||||
SCIM_TOKEN_LENGTH = 48
|
||||
|
||||
@@ -96,14 +82,23 @@ def verify_scim_token(
|
||||
"""
|
||||
hashed = _get_hashed_scim_token_from_request(request)
|
||||
if not hashed:
|
||||
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing or invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
token = dal.get_token_by_hash(hashed)
|
||||
|
||||
if not token:
|
||||
raise ScimAuthError(401, "Invalid SCIM bearer token")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid SCIM bearer token",
|
||||
)
|
||||
|
||||
if not token.is_active:
|
||||
raise ScimAuthError(401, "SCIM token has been revoked")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="SCIM token has been revoked",
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
@@ -153,28 +153,26 @@ class ScimProvider(ABC):
|
||||
self,
|
||||
user: User,
|
||||
fields: ScimMappingFields,
|
||||
) -> ScimName:
|
||||
) -> ScimName | None:
|
||||
"""Build SCIM name components for the response.
|
||||
|
||||
Round-trips stored ``given_name``/``family_name`` when available (so
|
||||
the IdP gets back what it sent). Falls back to splitting
|
||||
``personal_name`` for users provisioned before we stored components.
|
||||
Always returns a ScimName — Okta's spec tests expect ``name``
|
||||
(with ``givenName``/``familyName``) on every user resource.
|
||||
Providers may override for custom behavior.
|
||||
"""
|
||||
if fields.given_name is not None or fields.family_name is not None:
|
||||
return ScimName(
|
||||
givenName=fields.given_name or "",
|
||||
familyName=fields.family_name or "",
|
||||
formatted=user.personal_name or "",
|
||||
givenName=fields.given_name,
|
||||
familyName=fields.family_name,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
if not user.personal_name:
|
||||
return ScimName(givenName="", familyName="", formatted="")
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else "",
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
@@ -32,16 +32,13 @@ PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
_NEVER_RAN: float = -1e18
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=_NEVER_RAN)
|
||||
last_run_at: float = field(default=0.0)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
|
||||
45
backend/onyx/cache/factory.py
vendored
45
backend/onyx/cache/factory.py
vendored
@@ -1,45 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
|
||||
|
||||
def _build_redis_backend(tenant_id: str) -> CacheBackend:
|
||||
from onyx.cache.redis_backend import RedisCacheBackend
|
||||
from onyx.redis.redis_pool import redis_pool
|
||||
|
||||
return RedisCacheBackend(redis_pool.get_client(tenant_id))
|
||||
|
||||
|
||||
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
|
||||
CacheBackendType.REDIS: _build_redis_backend,
|
||||
# CacheBackendType.POSTGRES will be added in a follow-up PR.
|
||||
}
|
||||
|
||||
|
||||
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
|
||||
"""Return a tenant-aware ``CacheBackend``.
|
||||
|
||||
If *tenant_id* is ``None``, the current tenant is read from the
|
||||
thread-local context variable (same behaviour as ``get_redis_client``).
|
||||
"""
|
||||
if tenant_id is None:
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
|
||||
if builder is None:
|
||||
raise ValueError(
|
||||
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
|
||||
f"Supported values: {[t.value for t in CacheBackendType]}"
|
||||
)
|
||||
return builder(tenant_id)
|
||||
|
||||
|
||||
def get_shared_cache_backend() -> CacheBackend:
|
||||
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
|
||||
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)
|
||||
89
backend/onyx/cache/interface.py
vendored
89
backend/onyx/cache/interface.py
vendored
@@ -1,89 +0,0 @@
|
||||
import abc
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CacheBackendType(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
|
||||
class CacheLock(abc.ABC):
|
||||
"""Abstract distributed lock returned by CacheBackend.lock()."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def release(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def owned(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CacheBackend(abc.ABC):
|
||||
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
|
||||
|
||||
Covers the subset of Redis operations used outside of Celery. When
|
||||
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
|
||||
"""
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def get(self, key: str) -> bytes | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self, key: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def ttl(self, key: str) -> int:
|
||||
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
|
||||
raise NotImplementedError
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
raise NotImplementedError
|
||||
|
||||
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
|
||||
|
||||
@abc.abstractmethod
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
"""Block until a value is available on one of *keys*, or *timeout* expires.
|
||||
|
||||
Returns ``(key, value)`` or ``None`` on timeout.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
92
backend/onyx/cache/redis_backend.py
vendored
92
backend/onyx/cache/redis_backend.py
vendored
@@ -1,92 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.cache.interface import CacheBackend
|
||||
from onyx.cache.interface import CacheLock
|
||||
|
||||
|
||||
class RedisCacheLock(CacheLock):
|
||||
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
|
||||
|
||||
def __init__(self, lock: RedisLock) -> None:
|
||||
self._lock = lock
|
||||
|
||||
def acquire(
|
||||
self,
|
||||
blocking: bool = True,
|
||||
blocking_timeout: float | None = None,
|
||||
) -> bool:
|
||||
return bool(
|
||||
self._lock.acquire(
|
||||
blocking=blocking,
|
||||
blocking_timeout=blocking_timeout,
|
||||
)
|
||||
)
|
||||
|
||||
def release(self) -> None:
|
||||
self._lock.release()
|
||||
|
||||
def owned(self) -> bool:
|
||||
return bool(self._lock.owned())
|
||||
|
||||
|
||||
class RedisCacheBackend(CacheBackend):
|
||||
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
|
||||
|
||||
This is a thin pass-through — every method maps 1-to-1 to the underlying
|
||||
Redis command. ``TenantRedis`` key-prefixing is handled by the client
|
||||
itself (provided by ``get_redis_client``).
|
||||
"""
|
||||
|
||||
def __init__(self, redis_client: Redis) -> None:
|
||||
self._r = redis_client
|
||||
|
||||
# -- basic key/value ---------------------------------------------------
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
val = self._r.get(key)
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val
|
||||
return str(val).encode()
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: str,
|
||||
value: str | bytes | int | float,
|
||||
ex: int | None = None,
|
||||
) -> None:
|
||||
self._r.set(key, value, ex=ex)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
self._r.delete(key)
|
||||
|
||||
def exists(self, key: str) -> bool:
|
||||
return bool(self._r.exists(key))
|
||||
|
||||
# -- TTL ---------------------------------------------------------------
|
||||
|
||||
def expire(self, key: str, seconds: int) -> None:
|
||||
self._r.expire(key, seconds)
|
||||
|
||||
def ttl(self, key: str) -> int:
|
||||
return cast(int, self._r.ttl(key))
|
||||
|
||||
# -- distributed lock --------------------------------------------------
|
||||
|
||||
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
|
||||
return RedisCacheLock(self._r.lock(name, timeout=timeout))
|
||||
|
||||
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
|
||||
|
||||
def rpush(self, key: str, value: str | bytes) -> None:
|
||||
self._r.rpush(key, value)
|
||||
|
||||
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
|
||||
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
|
||||
if result is None:
|
||||
return None
|
||||
return (result[0], result[1])
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -55,12 +54,6 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
|
||||
# are disabled but core chat, tools, user file uploads, and Projects still work.
|
||||
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
|
||||
|
||||
# Which backend to use for caching, locks, and ephemeral state.
|
||||
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
|
||||
CACHE_BACKEND = CacheBackendType(
|
||||
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
|
||||
)
|
||||
|
||||
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
|
||||
# Defaults to 100k tokens (or 10M when vector DB is disabled).
|
||||
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
|
||||
@@ -1130,6 +1123,13 @@ DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
|
||||
## Teams Bot Configuration
|
||||
TEAMS_BOT_APP_ID = os.environ.get("TEAMS_BOT_APP_ID")
|
||||
TEAMS_BOT_APP_SECRET = os.environ.get("TEAMS_BOT_APP_SECRET")
|
||||
TEAMS_BOT_AZURE_TENANT_ID = os.environ.get("TEAMS_BOT_AZURE_TENANT_ID")
|
||||
TEAMS_BOT_PORT = int(os.environ.get("TEAMS_BOT_PORT") or "3978")
|
||||
|
||||
|
||||
## Stripe Configuration
|
||||
# URL to fetch the Stripe publishable key from a public S3 bucket.
|
||||
# Publishable keys are safe to expose publicly - they can only initialize
|
||||
|
||||
@@ -99,6 +99,7 @@ DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
TEAMS_SERVICE_API_KEY_NAME = "teams-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
|
||||
@@ -20,6 +20,7 @@ from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import CheckConstraint
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import Enum
|
||||
@@ -3668,6 +3669,115 @@ class DiscordChannelConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class TeamsBotConfig(Base):
|
||||
"""Global Teams bot configuration (one per tenant).
|
||||
|
||||
Stores the Azure Bot Service credentials when not provided via env vars.
|
||||
Uses a fixed ID with check constraint to enforce only one row per tenant.
|
||||
"""
|
||||
|
||||
__tablename__ = "teams_bot_config"
|
||||
__table_args__ = (
|
||||
CheckConstraint("id = 'SINGLETON'", name="ck_teams_bot_config_singleton"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'SINGLETON'")
|
||||
)
|
||||
app_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
app_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
azure_tenant_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class TeamsTeamConfig(Base):
|
||||
"""Configuration for a Teams team connected to this tenant.
|
||||
|
||||
registration_key is a one-time key used to link a Teams team to this tenant.
|
||||
Format: teams_<tenant_id>.<random_token>
|
||||
team_id is NULL until the Teams admin runs @bot register with the key.
|
||||
"""
|
||||
|
||||
__tablename__ = "teams_team_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# Teams team ID (GUID string) - NULL until registered via command in Teams
|
||||
team_id: Mapped[str | None] = mapped_column(String, nullable=True, unique=True)
|
||||
team_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
|
||||
# One-time registration key: teams_<tenant_id>.<random_token>
|
||||
registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
|
||||
registered_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Configuration
|
||||
default_persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
default_persona: Mapped["Persona | None"] = relationship(
|
||||
"Persona", foreign_keys=[default_persona_id]
|
||||
)
|
||||
channels: Mapped[list["TeamsChannelConfig"]] = relationship(
|
||||
back_populates="team_config", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class TeamsChannelConfig(Base):
|
||||
"""Per-channel configuration for Teams bot behavior.
|
||||
|
||||
Used to whitelist specific channels and configure per-channel behavior.
|
||||
"""
|
||||
|
||||
__tablename__ = "teams_channel_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
team_config_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("teams_team_config.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Teams channel ID (string)
|
||||
channel_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
channel_name: Mapped[str] = mapped_column(String(), nullable=False)
|
||||
|
||||
# If true (default), bot only responds when @mentioned
|
||||
# If false, bot responds to ALL messages in this channel
|
||||
require_bot_mention: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
|
||||
# Override the team's default persona for this channel
|
||||
persona_override_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
team_config: Mapped["TeamsTeamConfig"] = relationship(back_populates="channels")
|
||||
persona_override: Mapped["Persona | None"] = relationship()
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"team_config_id", "channel_id", name="uq_teams_channel_team_channel"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Milestone(Base):
|
||||
# This table is used to track significant events for a deployment towards finding value
|
||||
# The table is currently not used for features but it may be used in the future to inform
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_user_files(
|
||||
) -> CategorizedFilesResult:
|
||||
|
||||
# Categorize the files
|
||||
categorized_files = categorize_uploaded_files(files, db_session)
|
||||
categorized_files = categorize_uploaded_files(files)
|
||||
# NOTE: At the moment, zip metadata is not used for user files.
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
|
||||
331
backend/onyx/db/teams_bot.py
Normal file
331
backend/onyx/db/teams_bot.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""CRUD operations for Teams bot models."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import TEAMS_SERVICE_API_KEY_NAME
|
||||
from onyx.db.api_key import insert_api_key
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import TeamsBotConfig
|
||||
from onyx.db.models import TeamsChannelConfig
|
||||
from onyx.db.models import TeamsTeamConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# === TeamsBotConfig ===
|
||||
|
||||
|
||||
def get_teams_bot_config(db_session: Session) -> TeamsBotConfig | None:
|
||||
"""Get the Teams bot config for this tenant (at most one)."""
|
||||
return db_session.scalar(select(TeamsBotConfig).limit(1))
|
||||
|
||||
|
||||
def create_teams_bot_config(
|
||||
db_session: Session,
|
||||
app_id: str,
|
||||
app_secret: str,
|
||||
azure_tenant_id: str | None = None,
|
||||
) -> TeamsBotConfig:
|
||||
"""Create the Teams bot config. Raises ValueError if already exists.
|
||||
|
||||
The check constraint on id='SINGLETON' ensures only one config per tenant.
|
||||
"""
|
||||
existing = get_teams_bot_config(db_session)
|
||||
if existing:
|
||||
raise ValueError("Teams bot config already exists")
|
||||
|
||||
config = TeamsBotConfig(
|
||||
app_id=app_id,
|
||||
app_secret=app_secret,
|
||||
azure_tenant_id=azure_tenant_id,
|
||||
)
|
||||
db_session.add(config)
|
||||
try:
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
raise ValueError("Teams bot config already exists")
|
||||
return config
|
||||
|
||||
|
||||
def delete_teams_bot_config(db_session: Session) -> bool:
|
||||
"""Delete the Teams bot config. Returns True if deleted."""
|
||||
result = db_session.execute(delete(TeamsBotConfig))
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# === Teams Service API Key ===
|
||||
|
||||
|
||||
def get_teams_service_api_key(db_session: Session) -> ApiKey | None:
|
||||
"""Get the Teams service API key if it exists."""
|
||||
return db_session.scalar(
|
||||
select(ApiKey).where(ApiKey.name == TEAMS_SERVICE_API_KEY_NAME)
|
||||
)
|
||||
|
||||
|
||||
def provision_teams_service_api_key(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
) -> str:
|
||||
"""Create or regenerate the Teams service API key, returning the raw key.
|
||||
|
||||
The database only stores the hashed key. When the cache is cold
|
||||
(e.g. after a pod restart), the raw key is unrecoverable, so we
|
||||
regenerate a new one and update the stored hash. This is safe because
|
||||
the bot is the sole consumer of this key.
|
||||
|
||||
This function is **not** idempotent — it mutates the stored hash on
|
||||
every call when a key already exists. Only call it on cache miss.
|
||||
"""
|
||||
existing = get_teams_service_api_key(db_session)
|
||||
if existing:
|
||||
logger.debug(
|
||||
f"Regenerating Teams service API key for tenant {tenant_id} "
|
||||
"(raw key unrecoverable from hash)"
|
||||
)
|
||||
new_api_key = generate_api_key(tenant_id)
|
||||
existing.hashed_api_key = hash_api_key(new_api_key)
|
||||
existing.api_key_display = build_displayable_api_key(new_api_key)
|
||||
db_session.flush()
|
||||
return new_api_key
|
||||
|
||||
logger.info(f"Creating Teams service API key for tenant {tenant_id}")
|
||||
api_key_args = APIKeyArgs(
|
||||
name=TEAMS_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED,
|
||||
)
|
||||
api_key_descriptor = insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=api_key_args,
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
if not api_key_descriptor.api_key:
|
||||
raise RuntimeError(
|
||||
f"Failed to create Teams service API key for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
return api_key_descriptor.api_key
|
||||
|
||||
|
||||
def delete_teams_service_api_key(db_session: Session) -> bool:
|
||||
"""Delete the Teams service API key for a tenant.
|
||||
|
||||
Called when:
|
||||
- Bot config is deleted (self-hosted)
|
||||
- All team configs are deleted (Cloud)
|
||||
"""
|
||||
existing_key = get_teams_service_api_key(db_session)
|
||||
if not existing_key:
|
||||
return False
|
||||
|
||||
api_key_user = db_session.scalar(
|
||||
select(User).where(User.id == existing_key.user_id) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
db_session.delete(existing_key)
|
||||
if api_key_user:
|
||||
db_session.delete(api_key_user)
|
||||
|
||||
db_session.flush()
|
||||
logger.info("Deleted Teams service API key")
|
||||
return True
|
||||
|
||||
|
||||
# === TeamsTeamConfig ===
|
||||
|
||||
|
||||
def get_team_configs(
|
||||
db_session: Session,
|
||||
include_channels: bool = False,
|
||||
) -> list[TeamsTeamConfig]:
|
||||
"""Get all team configs for this tenant."""
|
||||
stmt = select(TeamsTeamConfig)
|
||||
if include_channels:
|
||||
stmt = stmt.options(joinedload(TeamsTeamConfig.channels))
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
def get_team_config_by_internal_id(
|
||||
db_session: Session,
|
||||
internal_id: int,
|
||||
) -> TeamsTeamConfig | None:
|
||||
"""Get a specific team config by its ID."""
|
||||
return db_session.scalar(
|
||||
select(TeamsTeamConfig).where(TeamsTeamConfig.id == internal_id)
|
||||
)
|
||||
|
||||
|
||||
def get_team_config_by_teams_id(
|
||||
db_session: Session,
|
||||
team_id: str,
|
||||
) -> TeamsTeamConfig | None:
|
||||
"""Get a team config by Teams team ID."""
|
||||
return db_session.scalar(
|
||||
select(TeamsTeamConfig).where(TeamsTeamConfig.team_id == team_id)
|
||||
)
|
||||
|
||||
|
||||
def get_team_config_by_registration_key(
|
||||
db_session: Session,
|
||||
registration_key: str,
|
||||
for_update: bool = False,
|
||||
) -> TeamsTeamConfig | None:
|
||||
"""Get a team config by its registration key.
|
||||
|
||||
Use ``for_update=True`` to acquire a row-level lock, preventing
|
||||
concurrent registration races.
|
||||
"""
|
||||
stmt = select(TeamsTeamConfig).where(
|
||||
TeamsTeamConfig.registration_key == registration_key
|
||||
)
|
||||
if for_update:
|
||||
stmt = stmt.with_for_update()
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def create_team_config(
|
||||
db_session: Session,
|
||||
registration_key: str,
|
||||
) -> TeamsTeamConfig:
|
||||
"""Create a new team config with a registration key (team_id=NULL)."""
|
||||
config = TeamsTeamConfig(registration_key=registration_key)
|
||||
db_session.add(config)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def register_team(
|
||||
db_session: Session,
|
||||
config: TeamsTeamConfig,
|
||||
team_id: str,
|
||||
team_name: str,
|
||||
) -> TeamsTeamConfig:
|
||||
"""Complete registration by setting team_id and team_name."""
|
||||
config.team_id = team_id
|
||||
config.team_name = team_name
|
||||
config.registered_at = datetime.now(timezone.utc)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def update_team_config(
|
||||
db_session: Session,
|
||||
config: TeamsTeamConfig,
|
||||
enabled: bool,
|
||||
default_persona_id: int | None = None,
|
||||
) -> TeamsTeamConfig:
|
||||
"""Update team config fields."""
|
||||
config.enabled = enabled
|
||||
config.default_persona_id = default_persona_id
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def delete_team_config(
|
||||
db_session: Session,
|
||||
internal_id: int,
|
||||
) -> bool:
|
||||
"""Delete team config (cascades to channel configs). Returns True if deleted."""
|
||||
result = db_session.execute(
|
||||
delete(TeamsTeamConfig).where(TeamsTeamConfig.id == internal_id)
|
||||
)
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# === TeamsChannelConfig ===
|
||||
|
||||
|
||||
def get_channel_configs(
|
||||
db_session: Session,
|
||||
team_config_id: int,
|
||||
) -> list[TeamsChannelConfig]:
|
||||
"""Get all channel configs for a team."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(TeamsChannelConfig).where(
|
||||
TeamsChannelConfig.team_config_id == team_config_id
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def get_channel_config_by_teams_ids(
|
||||
db_session: Session,
|
||||
team_id: str,
|
||||
channel_id: str,
|
||||
) -> TeamsChannelConfig | None:
|
||||
"""Get a specific channel config by team_id and channel_id."""
|
||||
return db_session.scalar(
|
||||
select(TeamsChannelConfig)
|
||||
.join(TeamsTeamConfig)
|
||||
.where(
|
||||
TeamsTeamConfig.team_id == team_id,
|
||||
TeamsChannelConfig.channel_id == channel_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_channel_config_by_internal_ids(
|
||||
db_session: Session,
|
||||
team_config_id: int,
|
||||
channel_config_id: int,
|
||||
) -> TeamsChannelConfig | None:
|
||||
"""Get a specific channel config by team_config_id and channel_config_id."""
|
||||
return db_session.scalar(
|
||||
select(TeamsChannelConfig).where(
|
||||
TeamsChannelConfig.team_config_id == team_config_id,
|
||||
TeamsChannelConfig.id == channel_config_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def update_teams_channel_config(
|
||||
db_session: Session,
|
||||
config: TeamsChannelConfig,
|
||||
channel_name: str,
|
||||
require_bot_mention: bool,
|
||||
enabled: bool,
|
||||
persona_override_id: int | None = None,
|
||||
) -> TeamsChannelConfig:
|
||||
"""Update channel config fields."""
|
||||
config.channel_name = channel_name
|
||||
config.require_bot_mention = require_bot_mention
|
||||
config.persona_override_id = persona_override_id
|
||||
config.enabled = enabled
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def create_channel_config(
|
||||
db_session: Session,
|
||||
team_config_id: int,
|
||||
channel_id: str,
|
||||
channel_name: str,
|
||||
) -> TeamsChannelConfig:
|
||||
"""Create a new channel config with default settings (disabled by default)."""
|
||||
config = TeamsChannelConfig(
|
||||
team_config_id=team_config_id,
|
||||
channel_id=channel_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
db_session.add(config)
|
||||
db_session.flush()
|
||||
return config
|
||||
@@ -6,6 +6,7 @@ import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -562,7 +563,12 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
if not self._client.index_exists():
|
||||
index_settings = DocumentSchema.get_index_settings_based_on_environment()
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=index_settings,
|
||||
|
||||
@@ -12,7 +12,6 @@ from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -526,7 +525,7 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
@@ -547,41 +546,3 @@ class DocumentSchema:
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
|
||||
@@ -67,18 +67,6 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
|
||||
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
|
||||
usage as a dict with chat completion format instead of keeping it as
|
||||
ResponseAPIUsage. Our patch creates a deep copy before modification.
|
||||
|
||||
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
|
||||
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
|
||||
to check for router calls, but when metadata is explicitly None (key exists with
|
||||
value None), the default {} is not used
|
||||
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
|
||||
the real exception (e.g. AuthenticationError for wrong API key)
|
||||
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
|
||||
not iterable
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
|
||||
against metadata being explicitly None. Triggered when Responses API bridge
|
||||
passes **litellm_params containing metadata=None.
|
||||
"""
|
||||
|
||||
import time
|
||||
@@ -737,44 +725,6 @@ def _patch_logging_assembled_streaming_response() -> None:
|
||||
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_responses_metadata_none() -> None:
|
||||
"""
|
||||
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
|
||||
|
||||
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
|
||||
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
|
||||
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
|
||||
None (the key exists, so the default is not used), causing:
|
||||
TypeError: argument of type 'NoneType' is not iterable
|
||||
|
||||
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
|
||||
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
|
||||
|
||||
This happens when the Responses API bridge calls litellm.responses() with
|
||||
**litellm_params which may contain metadata=None.
|
||||
|
||||
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
|
||||
which does not guard against metadata being explicitly None. Same pattern exists
|
||||
on line 1407 for async path.
|
||||
"""
|
||||
import litellm as _litellm
|
||||
from functools import wraps
|
||||
|
||||
original_responses = _litellm.responses
|
||||
|
||||
if getattr(original_responses, "_metadata_patched", False):
|
||||
return
|
||||
|
||||
@wraps(original_responses)
|
||||
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
|
||||
if kwargs.get("metadata") is None:
|
||||
kwargs["metadata"] = {}
|
||||
return original_responses(*args, **kwargs)
|
||||
|
||||
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
|
||||
_litellm.responses = _patched_responses
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -786,7 +736,6 @@ def apply_monkey_patches() -> None:
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
|
||||
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
|
||||
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
|
||||
"""
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_parallel_tool_calls()
|
||||
@@ -794,4 +743,3 @@ def apply_monkey_patches() -> None:
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
_patch_responses_api_usage_format()
|
||||
_patch_logging_assembled_streaming_response()
|
||||
_patch_responses_metadata_none()
|
||||
|
||||
@@ -32,13 +32,11 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import create_onyx_oauth_router
|
||||
from onyx.auth.users import fastapi_users
|
||||
from onyx.cache.interface import CacheBackendType
|
||||
from onyx.configs.app_configs import APP_API_PREFIX
|
||||
from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import CACHE_BACKEND
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
@@ -117,6 +115,7 @@ from onyx.server.manage.opensearch_migration.api import (
|
||||
)
|
||||
from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.teams_bot.api import router as teams_bot_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
@@ -257,20 +256,6 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_cache_backend_settings() -> None:
|
||||
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
|
||||
|
||||
The Postgres cache backend eliminates the Redis dependency, but only works
|
||||
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
|
||||
"""
|
||||
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
|
||||
raise RuntimeError(
|
||||
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
|
||||
"The Postgres cache backend is only supported in no-vector-DB "
|
||||
"deployments where Celery is replaced by the in-process task runner."
|
||||
)
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
|
||||
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
|
||||
|
||||
@@ -302,7 +287,6 @@ def validate_no_vector_db_settings() -> None:
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
@@ -466,6 +450,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
application, slack_bot_management_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, discord_bot_router)
|
||||
include_router_with_global_prefix_prepended(application, teams_bot_router)
|
||||
include_router_with_global_prefix_prepended(application, persona_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, agents_router)
|
||||
|
||||
0
backend/onyx/onyxbot/__init__.py
Normal file
0
backend/onyx/onyxbot/__init__.py
Normal file
@@ -1,12 +1,12 @@
|
||||
"""Async HTTP client for communicating with Onyx API pods."""
|
||||
"""Shared async HTTP client for communicating with Onyx API pods."""
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.discord.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.discord.exceptions import APIResponseError
|
||||
from onyx.onyxbot.discord.exceptions import APITimeoutError
|
||||
from onyx.onyxbot.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.exceptions import APIResponseError
|
||||
from onyx.onyxbot.exceptions import APITimeoutError
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
@@ -19,36 +19,17 @@ logger = setup_logger()
|
||||
class OnyxAPIClient:
|
||||
"""Async HTTP client for sending chat requests to Onyx API pods.
|
||||
|
||||
This client manages an aiohttp session for making non-blocking HTTP
|
||||
requests to the Onyx API server. It handles authentication with per-tenant
|
||||
API keys and multi-tenant routing.
|
||||
|
||||
Usage:
|
||||
client = OnyxAPIClient()
|
||||
await client.initialize()
|
||||
try:
|
||||
response = await client.send_chat_message(
|
||||
message="What is our deployment process?",
|
||||
tenant_id="tenant_123",
|
||||
api_key="dn_xxx...",
|
||||
persona_id=1,
|
||||
)
|
||||
print(response.answer)
|
||||
finally:
|
||||
await client.close()
|
||||
Used by both Discord and Teams bots. The ``origin`` parameter controls
|
||||
which ``MessageOrigin`` value is attached to outgoing requests for
|
||||
telemetry tracking.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
origin: MessageOrigin,
|
||||
timeout: int = API_REQUEST_TIMEOUT,
|
||||
) -> None:
|
||||
"""Initialize the API client.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
# Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL
|
||||
# TODO: Ideally, this override is only used when someone is launching an Onyx service independently
|
||||
self._origin = origin
|
||||
self._base_url = build_api_server_url_for_http_requests(
|
||||
respect_env_override_if_set=True
|
||||
).rstrip("/")
|
||||
@@ -56,28 +37,20 @@ class OnyxAPIClient:
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Create the aiohttp session.
|
||||
|
||||
Must be called before making any requests. The session is created
|
||||
with a total timeout and connection timeout.
|
||||
"""
|
||||
"""Create the aiohttp session."""
|
||||
if self._session is not None:
|
||||
logger.warning("API client session already initialized")
|
||||
return
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=self._timeout,
|
||||
connect=30, # 30 seconds to establish connection
|
||||
connect=30,
|
||||
)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
logger.info(f"API client initialized with base URL: {self._base_url}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the aiohttp session.
|
||||
|
||||
Should be called when shutting down the bot to properly release
|
||||
resources.
|
||||
"""
|
||||
"""Close the aiohttp session."""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
@@ -85,7 +58,6 @@ class OnyxAPIClient:
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the session is initialized."""
|
||||
return self._session is not None
|
||||
|
||||
async def send_chat_message(
|
||||
@@ -94,24 +66,7 @@ class OnyxAPIClient:
|
||||
api_key: str,
|
||||
persona_id: int | None = None,
|
||||
) -> ChatFullResponse:
|
||||
"""Send a chat message to the Onyx API server and get a response.
|
||||
|
||||
This method sends a non-streaming chat request to the API server. The response
|
||||
contains the complete answer with any citations and metadata.
|
||||
|
||||
Args:
|
||||
message: The user's message to process.
|
||||
api_key: The API key for authentication.
|
||||
persona_id: Optional persona ID to use for the response.
|
||||
|
||||
Returns:
|
||||
ChatFullResponse containing the answer, citations, and metadata.
|
||||
|
||||
Raises:
|
||||
APIConnectionError: If unable to connect to the API.
|
||||
APITimeoutError: If the request times out.
|
||||
APIResponseError: If the API returns an error response.
|
||||
"""
|
||||
"""Send a chat message to the Onyx API server and get a response."""
|
||||
if self._session is None:
|
||||
raise APIConnectionError(
|
||||
"API client not initialized. Call initialize() first."
|
||||
@@ -119,17 +74,15 @@ class OnyxAPIClient:
|
||||
|
||||
url = f"{self._base_url}/chat/send-chat-message"
|
||||
|
||||
# Build request payload
|
||||
request = SendMessageRequest(
|
||||
message=message,
|
||||
stream=False,
|
||||
origin=MessageOrigin.DISCORDBOT,
|
||||
origin=self._origin,
|
||||
chat_session_info=ChatSessionCreationRequest(
|
||||
persona_id=persona_id if persona_id is not None else 0,
|
||||
),
|
||||
)
|
||||
|
||||
# Build headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
@@ -169,7 +122,6 @@ class OnyxAPIClient:
|
||||
status_code=response.status,
|
||||
)
|
||||
|
||||
# Parse successful response
|
||||
data = await response.json()
|
||||
response_obj = ChatFullResponse.model_validate(data)
|
||||
|
||||
@@ -195,11 +147,7 @@ class OnyxAPIClient:
|
||||
raise APIConnectionError(f"HTTP client error: {e}") from e
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if the API server is healthy.
|
||||
|
||||
Returns:
|
||||
True if the API server is reachable and healthy, False otherwise.
|
||||
"""
|
||||
"""Check if the API server is healthy."""
|
||||
if self._session is None:
|
||||
logger.warning("API client not initialized. Call initialize() first.")
|
||||
return False
|
||||
195
backend/onyx/onyxbot/cache.py
Normal file
195
backend/onyx/onyxbot/cache.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Shared multi-tenant cache for bot entity-tenant mappings and API keys.
|
||||
|
||||
Subclass ``BotCacheManager`` and implement the three abstract helpers to
|
||||
create a platform-specific cache (e.g. Discord guilds, Teams teams).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.onyxbot.exceptions import CacheError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
EntityIdT = TypeVar("EntityIdT")
|
||||
|
||||
|
||||
class BotCacheManager(ABC, Generic[EntityIdT]):
|
||||
"""Caches entity->tenant mappings and tenant->API key mappings.
|
||||
|
||||
``EntityIdT`` is ``int`` for Discord guilds, ``str`` for Teams teams.
|
||||
"""
|
||||
|
||||
def __init__(self, entity_name: str) -> None:
|
||||
self._entity_name = entity_name
|
||||
self._entity_tenants: dict[EntityIdT, str] = {}
|
||||
self._api_keys: dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Abstract hooks — platform-specific DB access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@abstractmethod
|
||||
def _get_entity_ids(self, db: Session) -> list[EntityIdT]:
|
||||
"""Return active entity IDs from DB configs."""
|
||||
|
||||
@abstractmethod
|
||||
def _get_or_create_api_key(self, db: Session, tenant_id: str) -> str:
|
||||
"""Provision (or retrieve) a service API key for *tenant_id*."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
async def refresh_all(self) -> None:
|
||||
"""Full cache refresh from all tenants.
|
||||
|
||||
Data is loaded outside the lock; the lock is only held for the
|
||||
atomic swap of the cache dicts so that ``refresh_entity`` and
|
||||
read operations are not blocked during I/O.
|
||||
"""
|
||||
logger.info(f"Starting {self._entity_name} cache refresh")
|
||||
|
||||
new_entity_tenants: dict[EntityIdT, str] = {}
|
||||
new_api_keys: dict[str, str] = {}
|
||||
|
||||
try:
|
||||
gated = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated:
|
||||
continue
|
||||
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
entity_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
if not entity_ids:
|
||||
logger.debug(
|
||||
f"No {self._entity_name} found for tenant " f"{tenant_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
f"Service API key missing for tenant that has "
|
||||
f"registered {self._entity_name}. {tenant_id} "
|
||||
f"will not be handled in this refresh cycle."
|
||||
)
|
||||
continue
|
||||
|
||||
for entity_id in entity_ids:
|
||||
new_entity_tenants[entity_id] = tenant_id
|
||||
|
||||
new_api_keys[tenant_id] = api_key
|
||||
except (OperationalError, ConnectionError, OSError) as e:
|
||||
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
|
||||
# Atomic swap under lock
|
||||
async with self._lock:
|
||||
self._entity_tenants = new_entity_tenants
|
||||
self._api_keys = new_api_keys
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
f"Cache refresh complete: "
|
||||
f"{len(new_entity_tenants)} {self._entity_name}, "
|
||||
f"{len(new_api_keys)} tenants"
|
||||
)
|
||||
|
||||
except (OperationalError, ConnectionError, OSError) as e:
|
||||
logger.error(f"Cache refresh failed: {e}")
|
||||
raise CacheError(f"Failed to refresh cache: {e}") from e
|
||||
|
||||
async def refresh_entity(self, entity_id: EntityIdT, tenant_id: str) -> None:
|
||||
"""Add a single entity to cache after registration."""
|
||||
logger.info(
|
||||
f"Refreshing cache for {self._entity_name} entity "
|
||||
f"{entity_id} (tenant: {tenant_id})"
|
||||
)
|
||||
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
entity_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
|
||||
async with self._lock:
|
||||
if entity_id in entity_ids:
|
||||
self._entity_tenants[entity_id] = tenant_id
|
||||
if api_key:
|
||||
self._api_keys[tenant_id] = api_key
|
||||
logger.info(f"Cache updated for entity {entity_id}")
|
||||
else:
|
||||
logger.warning(f"Entity {entity_id} not found or disabled")
|
||||
|
||||
def get_tenant(self, entity_id: EntityIdT) -> str | None:
|
||||
"""Get tenant ID for an entity."""
|
||||
return self._entity_tenants.get(entity_id)
|
||||
|
||||
def get_api_key(self, tenant_id: str) -> str | None:
|
||||
"""Get API key for a tenant."""
|
||||
return self._api_keys.get(tenant_id)
|
||||
|
||||
def remove_entity(self, entity_id: EntityIdT) -> None:
|
||||
"""Remove an entity from cache."""
|
||||
self._entity_tenants.pop(entity_id, None)
|
||||
|
||||
def get_all_entity_ids(self) -> list[EntityIdT]:
|
||||
"""Get all cached entity IDs."""
|
||||
return list(self._entity_tenants.keys())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all caches."""
|
||||
self._entity_tenants.clear()
|
||||
self._api_keys.clear()
|
||||
self._initialized = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internal
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _load_tenant_data(
|
||||
self, tenant_id: str
|
||||
) -> tuple[list[EntityIdT], str | None]:
|
||||
"""Load entity IDs and provision API key if needed."""
|
||||
cached_key = self._api_keys.get(tenant_id)
|
||||
|
||||
def _sync() -> tuple[list[EntityIdT], str | None]:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
entity_ids = self._get_entity_ids(db)
|
||||
|
||||
if not entity_ids:
|
||||
return [], None
|
||||
|
||||
if not cached_key:
|
||||
new_key = self._get_or_create_api_key(db, tenant_id)
|
||||
db.commit()
|
||||
return entity_ids, new_key
|
||||
|
||||
return entity_ids, cached_key
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
10
backend/onyx/onyxbot/constants.py
Normal file
10
backend/onyx/onyxbot/constants.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Shared constants for Onyx bot integrations (Discord, Teams, etc.)."""
|
||||
|
||||
# API settings
|
||||
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
|
||||
|
||||
# Cache settings
|
||||
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
|
||||
|
||||
# Registration
|
||||
REGISTER_COMMAND: str = "register"
|
||||
@@ -1,154 +1,35 @@
|
||||
"""Multi-tenant cache for Discord bot guild-tenant mappings and API keys."""
|
||||
|
||||
import asyncio
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.discord_bot import get_guild_configs
|
||||
from onyx.db.discord_bot import get_or_create_discord_service_api_key
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.onyxbot.discord.exceptions import CacheError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
from onyx.onyxbot.cache import BotCacheManager
|
||||
|
||||
|
||||
class DiscordCacheManager:
|
||||
"""Caches guild->tenant mappings and tenant->API key mappings.
|
||||
|
||||
Refreshed on startup, periodically (every 60s), and when guilds register.
|
||||
"""
|
||||
class DiscordCacheManager(BotCacheManager[int]):
|
||||
"""Caches guild->tenant mappings and tenant->API key mappings."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._guild_tenants: dict[int, str] = {} # guild_id -> tenant_id
|
||||
self._api_keys: dict[str, str] = {} # tenant_id -> api_key
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
super().__init__(entity_name="guilds")
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
def _get_entity_ids(self, db: Session) -> list[int]:
|
||||
configs = get_guild_configs(db)
|
||||
return [
|
||||
config.guild_id
|
||||
for config in configs
|
||||
if config.enabled and config.guild_id is not None
|
||||
]
|
||||
|
||||
async def refresh_all(self) -> None:
|
||||
"""Full cache refresh from all tenants."""
|
||||
async with self._lock:
|
||||
logger.info("Starting Discord cache refresh")
|
||||
|
||||
new_guild_tenants: dict[int, str] = {}
|
||||
new_api_keys: dict[str, str] = {}
|
||||
|
||||
try:
|
||||
gated = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated:
|
||||
continue
|
||||
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
if not guild_ids:
|
||||
logger.debug(f"No guilds found for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"Discord service API key missing for tenant that has registered guilds. "
|
||||
f"{tenant_id} will not be handled in this refresh cycle."
|
||||
)
|
||||
continue
|
||||
|
||||
for guild_id in guild_ids:
|
||||
new_guild_tenants[guild_id] = tenant_id
|
||||
|
||||
new_api_keys[tenant_id] = api_key
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
|
||||
self._guild_tenants = new_guild_tenants
|
||||
self._api_keys = new_api_keys
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
f"Cache refresh complete: {len(new_guild_tenants)} guilds, "
|
||||
f"{len(new_api_keys)} tenants"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache refresh failed: {e}")
|
||||
raise CacheError(f"Failed to refresh cache: {e}") from e
|
||||
def _get_or_create_api_key(self, db: Session, tenant_id: str) -> str:
|
||||
return get_or_create_discord_service_api_key(db, tenant_id)
|
||||
|
||||
# Convenience aliases for backward compatibility with callers
|
||||
async def refresh_guild(self, guild_id: int, tenant_id: str) -> None:
|
||||
"""Add a single guild to cache after registration."""
|
||||
async with self._lock:
|
||||
logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})")
|
||||
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
|
||||
if guild_id in guild_ids:
|
||||
self._guild_tenants[guild_id] = tenant_id
|
||||
if api_key:
|
||||
self._api_keys[tenant_id] = api_key
|
||||
logger.info(f"Cache updated for guild {guild_id}")
|
||||
else:
|
||||
logger.warning(f"Guild {guild_id} not found or disabled")
|
||||
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
"""Load guild IDs and provision API key if needed.
|
||||
|
||||
Returns:
|
||||
(active_guild_ids, api_key) - api_key is the cached key if available,
|
||||
otherwise a newly created key. Returns None if no guilds found.
|
||||
"""
|
||||
cached_key = self._api_keys.get(tenant_id)
|
||||
|
||||
def _sync() -> tuple[list[int], str | None]:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
configs = get_guild_configs(db)
|
||||
guild_ids = [
|
||||
config.guild_id
|
||||
for config in configs
|
||||
if config.enabled and config.guild_id is not None
|
||||
]
|
||||
|
||||
if not guild_ids:
|
||||
return [], None
|
||||
|
||||
if not cached_key:
|
||||
new_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
db.commit()
|
||||
return guild_ids, new_key
|
||||
|
||||
return guild_ids, cached_key
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
|
||||
def get_tenant(self, guild_id: int) -> str | None:
|
||||
"""Get tenant ID for a guild."""
|
||||
return self._guild_tenants.get(guild_id)
|
||||
|
||||
def get_api_key(self, tenant_id: str) -> str | None:
|
||||
"""Get API key for a tenant."""
|
||||
return self._api_keys.get(tenant_id)
|
||||
await self.refresh_entity(guild_id, tenant_id)
|
||||
|
||||
def remove_guild(self, guild_id: int) -> None:
|
||||
"""Remove a guild from cache."""
|
||||
self._guild_tenants.pop(guild_id, None)
|
||||
self.remove_entity(guild_id)
|
||||
|
||||
def get_all_guild_ids(self) -> list[int]:
|
||||
"""Get all cached guild IDs."""
|
||||
return list(self._guild_tenants.keys())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all caches."""
|
||||
self._guild_tenants.clear()
|
||||
self._api_keys.clear()
|
||||
self._initialized = False
|
||||
return self.get_all_entity_ids()
|
||||
|
||||
@@ -7,15 +7,16 @@ import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.constants import CACHE_REFRESH_INTERVAL
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL
|
||||
from onyx.onyxbot.discord.handle_commands import handle_dm
|
||||
from onyx.onyxbot.discord.handle_commands import handle_registration_command
|
||||
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
|
||||
from onyx.onyxbot.discord.handle_message import process_chat_message
|
||||
from onyx.onyxbot.discord.handle_message import should_respond
|
||||
from onyx.onyxbot.discord.utils import get_bot_token
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -40,7 +41,7 @@ class OnyxDiscordClient(commands.Bot):
|
||||
|
||||
self.ready = False
|
||||
self.cache = DiscordCacheManager()
|
||||
self.api_client = OnyxAPIClient()
|
||||
self.api_client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
self._cache_refresh_task: asyncio.Task | None = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
"""Discord bot constants."""
|
||||
"""Discord-specific bot constants.
|
||||
|
||||
# API settings
|
||||
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
|
||||
|
||||
# Cache settings
|
||||
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
|
||||
Shared constants (API_REQUEST_TIMEOUT, CACHE_REFRESH_INTERVAL,
|
||||
REGISTER_COMMAND) live in ``onyx.onyxbot.constants``.
|
||||
"""
|
||||
|
||||
# Message settings
|
||||
MAX_MESSAGE_LENGTH: int = 2000 # Discord's character limit
|
||||
MAX_CONTEXT_MESSAGES: int = 10 # Max messages to include in conversation context
|
||||
# Note: Discord.py's add_reaction() requires unicode emoji, not :name: format
|
||||
THINKING_EMOJI: str = "🤔" # U+1F914 - Thinking Face
|
||||
SUCCESS_EMOJI: str = "✅" # U+2705 - White Heavy Check Mark
|
||||
ERROR_EMOJI: str = "❌" # U+274C - Cross Mark
|
||||
THINKING_EMOJI: str = "\U0001f914" # U+1F914 - Thinking Face
|
||||
SUCCESS_EMOJI: str = "\u2705" # U+2705 - White Heavy Check Mark
|
||||
ERROR_EMOJI: str = "\u274c" # U+274C - Cross Mark
|
||||
|
||||
# Command prefix
|
||||
REGISTER_COMMAND: str = "register"
|
||||
# Discord-specific commands
|
||||
SYNC_CHANNELS_COMMAND: str = "sync-channels"
|
||||
|
||||
@@ -1,37 +1,7 @@
|
||||
"""Custom exception classes for Discord bot."""
|
||||
"""Discord-specific exception classes."""
|
||||
|
||||
from onyx.onyxbot.exceptions import OnyxBotError
|
||||
|
||||
|
||||
class DiscordBotError(Exception):
|
||||
"""Base exception for Discord bot errors."""
|
||||
|
||||
|
||||
class RegistrationError(DiscordBotError):
|
||||
"""Error during guild registration."""
|
||||
|
||||
|
||||
class SyncChannelsError(DiscordBotError):
|
||||
class SyncChannelsError(OnyxBotError):
|
||||
"""Error during channel sync."""
|
||||
|
||||
|
||||
class APIError(DiscordBotError):
|
||||
"""Base API error."""
|
||||
|
||||
|
||||
class CacheError(DiscordBotError):
|
||||
"""Error during cache operations."""
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
"""Failed to connect to API."""
|
||||
|
||||
|
||||
class APITimeoutError(APIError):
|
||||
"""Request timed out."""
|
||||
|
||||
|
||||
class APIResponseError(APIError):
|
||||
"""API returned an error response."""
|
||||
|
||||
def __init__(self, message: str, status_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
@@ -15,11 +15,11 @@ from onyx.db.discord_bot import get_guild_config_by_registration_key
|
||||
from onyx.db.discord_bot import sync_channel_configs
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.onyxbot.constants import REGISTER_COMMAND
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.onyxbot.discord.constants import REGISTER_COMMAND
|
||||
from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND
|
||||
from onyx.onyxbot.discord.exceptions import RegistrationError
|
||||
from onyx.onyxbot.discord.exceptions import SyncChannelsError
|
||||
from onyx.onyxbot.exceptions import RegistrationError
|
||||
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -11,11 +11,11 @@ from onyx.db.discord_bot import get_guild_config_by_discord_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import DiscordChannelConfig
|
||||
from onyx.db.models import DiscordGuildConfig
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
|
||||
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
|
||||
from onyx.onyxbot.discord.constants import THINKING_EMOJI
|
||||
from onyx.onyxbot.discord.exceptions import APIError
|
||||
from onyx.onyxbot.exceptions import APIError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
33
backend/onyx/onyxbot/exceptions.py
Normal file
33
backend/onyx/onyxbot/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Shared exception classes for Onyx bot integrations (Discord, Teams, etc.)."""
|
||||
|
||||
|
||||
class OnyxBotError(Exception):
|
||||
"""Base exception for all Onyx bot errors."""
|
||||
|
||||
|
||||
class RegistrationError(OnyxBotError):
|
||||
"""Error during bot registration."""
|
||||
|
||||
|
||||
class APIError(OnyxBotError):
|
||||
"""Base API error."""
|
||||
|
||||
|
||||
class CacheError(OnyxBotError):
|
||||
"""Error during cache operations."""
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
"""Failed to connect to API."""
|
||||
|
||||
|
||||
class APITimeoutError(APIError):
|
||||
"""Request timed out."""
|
||||
|
||||
|
||||
class APIResponseError(APIError):
|
||||
"""API returned an error response."""
|
||||
|
||||
def __init__(self, message: str, status_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
42
backend/onyx/onyxbot/registration.py
Normal file
42
backend/onyx/onyxbot/registration.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Shared registration key generation and parsing for bot integrations."""
|
||||
|
||||
import secrets
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import unquote
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def generate_registration_key(prefix: str, tenant_id: str) -> str:
|
||||
"""Generate a one-time registration key with embedded tenant_id.
|
||||
|
||||
Format: <prefix>_<url_encoded_tenant_id>.<random_token>
|
||||
"""
|
||||
encoded_tenant = quote(tenant_id)
|
||||
random_token = secrets.token_urlsafe(16)
|
||||
|
||||
logger.info(f"Generated {prefix} registration key for tenant {tenant_id}")
|
||||
return f"{prefix}_{encoded_tenant}.{random_token}"
|
||||
|
||||
|
||||
def parse_registration_key(prefix: str, key: str) -> str | None:
|
||||
"""Parse registration key to extract tenant_id.
|
||||
|
||||
Returns tenant_id or None if invalid format.
|
||||
"""
|
||||
full_prefix = f"{prefix}_"
|
||||
if not key.startswith(full_prefix):
|
||||
return None
|
||||
|
||||
try:
|
||||
key_body = key.removeprefix(full_prefix)
|
||||
parts = key_body.split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
encoded_tenant = parts[0]
|
||||
return unquote(encoded_tenant)
|
||||
except Exception:
|
||||
return None
|
||||
0
backend/onyx/onyxbot/teams/__init__.py
Normal file
0
backend/onyx/onyxbot/teams/__init__.py
Normal file
186
backend/onyx/onyxbot/teams/bot.py
Normal file
186
backend/onyx/onyxbot/teams/bot.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""Teams bot Activity handler using Bot Framework SDK."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from botbuilder.core import ActivityHandler # type: ignore[import-untyped]
|
||||
from botbuilder.core import TurnContext
|
||||
from botbuilder.schema import Activity # type: ignore[import-untyped]
|
||||
from botbuilder.schema import ActivityTypes
|
||||
from botbuilder.schema import Attachment
|
||||
from botbuilder.schema import ChannelAccount
|
||||
|
||||
from onyx.onyxbot.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.constants import CACHE_REFRESH_INTERVAL
|
||||
from onyx.onyxbot.exceptions import RegistrationError
|
||||
from onyx.onyxbot.teams.cache import TeamsCacheManager
|
||||
from onyx.onyxbot.teams.handle_commands import handle_registration_command
|
||||
from onyx.onyxbot.teams.handle_commands import is_registration_command
|
||||
from onyx.onyxbot.teams.handle_message import process_chat_message
|
||||
from onyx.onyxbot.teams.handle_message import should_respond
|
||||
from onyx.onyxbot.teams.utils import extract_channel_id
|
||||
from onyx.onyxbot.teams.utils import extract_team_id
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxTeamsBot(ActivityHandler):
|
||||
"""Activity handler for Teams bot.
|
||||
|
||||
Handles incoming message activities, member additions, and routes
|
||||
messages to the appropriate handler (registration, chat).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.cache = TeamsCacheManager()
|
||||
self.api_client = OnyxAPIClient(origin=MessageOrigin.TEAMSBOT)
|
||||
self._cache_refresh_task: asyncio.Task[None] | None = None
|
||||
self._bot_id: str | None = None
|
||||
self._bot_name: str = "Onyx"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the bot: API client, cache, and background tasks."""
|
||||
await self.api_client.initialize()
|
||||
await self.cache.refresh_all()
|
||||
self._cache_refresh_task = asyncio.create_task(self._periodic_cache_refresh())
|
||||
logger.info("Teams bot initialized")
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Gracefully shut down the bot."""
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
try:
|
||||
await self._cache_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.api_client.close()
|
||||
self.cache.clear()
|
||||
logger.info("Teams bot shut down")
|
||||
|
||||
async def _periodic_cache_refresh(self) -> None:
|
||||
"""Background task to refresh cache periodically."""
|
||||
while True:
|
||||
await asyncio.sleep(CACHE_REFRESH_INTERVAL)
|
||||
try:
|
||||
await self.cache.refresh_all()
|
||||
except Exception as e:
|
||||
logger.error(f"Periodic cache refresh failed: {e}")
|
||||
|
||||
async def on_message_activity(self, turn_context: TurnContext) -> None:
|
||||
"""Handle incoming message activities."""
|
||||
activity = turn_context.activity
|
||||
if not activity.text:
|
||||
return
|
||||
|
||||
# Capture bot identity on first message
|
||||
if not self._bot_id and activity.recipient:
|
||||
self._bot_id = activity.recipient.id
|
||||
self._bot_name = activity.recipient.name or "Onyx"
|
||||
|
||||
activity_dict = activity.as_dict() if hasattr(activity, "as_dict") else {}
|
||||
team_id = extract_team_id(activity_dict)
|
||||
channel_id = extract_channel_id(activity_dict)
|
||||
|
||||
# Check for registration command
|
||||
if is_registration_command(activity.text, self._bot_name):
|
||||
await self._handle_registration(turn_context, activity_dict)
|
||||
return
|
||||
|
||||
# Resolve tenant from team cache
|
||||
tenant_id: str | None = None
|
||||
if team_id:
|
||||
tenant_id = self.cache.get_tenant(team_id)
|
||||
if not tenant_id:
|
||||
logger.debug(f"No tenant found for team {team_id}")
|
||||
return
|
||||
else:
|
||||
# DM — not in a team context, so we can't determine tenant.
|
||||
# TODO(nik): support DM registration or default tenant lookup
|
||||
logger.debug("Ignoring DM (no team context to resolve tenant)")
|
||||
return
|
||||
|
||||
# Check if bot should respond
|
||||
context = await asyncio.to_thread(
|
||||
should_respond,
|
||||
activity_dict,
|
||||
team_id,
|
||||
channel_id,
|
||||
tenant_id,
|
||||
self._bot_id or "",
|
||||
)
|
||||
|
||||
if not context.should_respond:
|
||||
return
|
||||
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Send typing indicator
|
||||
await turn_context.send_activity(Activity(type=ActivityTypes.typing))
|
||||
|
||||
# Process message and send response
|
||||
card = await process_chat_message(
|
||||
text=activity.text,
|
||||
api_key=api_key,
|
||||
persona_id=context.persona_id,
|
||||
api_client=self.api_client,
|
||||
bot_name=self._bot_name,
|
||||
)
|
||||
|
||||
# Send as Adaptive Card
|
||||
attachment = Attachment(
|
||||
content_type="application/vnd.microsoft.card.adaptive",
|
||||
content=card,
|
||||
)
|
||||
response = Activity(
|
||||
type=ActivityTypes.message,
|
||||
attachments=[attachment],
|
||||
)
|
||||
await turn_context.send_activity(response)
|
||||
|
||||
async def _handle_registration(
|
||||
self,
|
||||
turn_context: TurnContext,
|
||||
activity_dict: dict,
|
||||
) -> None:
|
||||
"""Handle registration command."""
|
||||
try:
|
||||
result = await handle_registration_command(
|
||||
text=turn_context.activity.text or "",
|
||||
activity_dict=activity_dict,
|
||||
bot_name=self._bot_name,
|
||||
cache=self.cache,
|
||||
)
|
||||
await turn_context.send_activity(result)
|
||||
except RegistrationError as e:
|
||||
await turn_context.send_activity(f"Registration failed: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Registration error: {e}")
|
||||
await turn_context.send_activity(
|
||||
"An unexpected error occurred during registration."
|
||||
)
|
||||
|
||||
async def on_members_added_activity(
|
||||
self,
|
||||
members_added: list[ChannelAccount],
|
||||
turn_context: TurnContext,
|
||||
) -> None:
|
||||
"""Handle when the bot is added to a team or conversation."""
|
||||
for member in members_added:
|
||||
# Only send welcome when the bot itself is added
|
||||
if member.id == turn_context.activity.recipient.id:
|
||||
from onyx.onyxbot.teams.cards import build_welcome_card
|
||||
|
||||
attachment = Attachment(
|
||||
content_type="application/vnd.microsoft.card.adaptive",
|
||||
content=build_welcome_card(),
|
||||
)
|
||||
response = Activity(
|
||||
type=ActivityTypes.message,
|
||||
attachments=[attachment],
|
||||
)
|
||||
await turn_context.send_activity(response)
|
||||
35
backend/onyx/onyxbot/teams/cache.py
Normal file
35
backend/onyx/onyxbot/teams/cache.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Multi-tenant cache for Teams bot team-tenant mappings and API keys."""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.teams_bot import get_team_configs
|
||||
from onyx.db.teams_bot import provision_teams_service_api_key
|
||||
from onyx.onyxbot.cache import BotCacheManager
|
||||
|
||||
|
||||
class TeamsCacheManager(BotCacheManager[str]):
|
||||
"""Caches team->tenant mappings and tenant->API key mappings."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(entity_name="teams")
|
||||
|
||||
def _get_entity_ids(self, db: Session) -> list[str]:
|
||||
configs = get_team_configs(db)
|
||||
return [
|
||||
config.team_id
|
||||
for config in configs
|
||||
if config.enabled and config.team_id is not None
|
||||
]
|
||||
|
||||
def _get_or_create_api_key(self, db: Session, tenant_id: str) -> str:
|
||||
return provision_teams_service_api_key(db, tenant_id)
|
||||
|
||||
# Convenience aliases for caller clarity
|
||||
async def refresh_team(self, team_id: str, tenant_id: str) -> None:
|
||||
await self.refresh_entity(team_id, tenant_id)
|
||||
|
||||
def remove_team(self, team_id: str) -> None:
|
||||
self.remove_entity(team_id)
|
||||
|
||||
def get_all_team_ids(self) -> list[str]:
|
||||
return self.get_all_entity_ids()
|
||||
136
backend/onyx/onyxbot/teams/cards.py
Normal file
136
backend/onyx/onyxbot/teams/cards.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Adaptive Card builders for Teams bot responses."""
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.teams.constants import ADAPTIVE_CARD_SCHEMA
|
||||
from onyx.onyxbot.teams.constants import ADAPTIVE_CARD_VERSION
|
||||
from onyx.onyxbot.teams.constants import MAX_CITATIONS
|
||||
|
||||
|
||||
def build_answer_card(
|
||||
answer: str,
|
||||
response: ChatFullResponse | None = None,
|
||||
) -> dict:
|
||||
"""Build an Adaptive Card for a chat answer with optional citations.
|
||||
|
||||
Target Adaptive Card schema version 1.3 for mobile compatibility.
|
||||
"""
|
||||
body: list[dict] = [
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": answer,
|
||||
"wrap": True,
|
||||
}
|
||||
]
|
||||
|
||||
# Add citations if present
|
||||
citations = _extract_citations(response) if response else []
|
||||
if citations:
|
||||
body.append(
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": "**Sources:**",
|
||||
"wrap": True,
|
||||
"spacing": "Medium",
|
||||
}
|
||||
)
|
||||
for num, name, link in citations:
|
||||
if link:
|
||||
body.append(
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": f"{num}. [{name}]({link})",
|
||||
"wrap": True,
|
||||
"spacing": "None",
|
||||
}
|
||||
)
|
||||
else:
|
||||
body.append(
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": f"{num}. {name}",
|
||||
"wrap": True,
|
||||
"spacing": "None",
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"$schema": ADAPTIVE_CARD_SCHEMA,
|
||||
"type": "AdaptiveCard",
|
||||
"version": ADAPTIVE_CARD_VERSION,
|
||||
"body": body,
|
||||
}
|
||||
|
||||
|
||||
def build_error_card(message: str) -> dict:
|
||||
"""Build an Adaptive Card for error messages."""
|
||||
return {
|
||||
"$schema": ADAPTIVE_CARD_SCHEMA,
|
||||
"type": "AdaptiveCard",
|
||||
"version": ADAPTIVE_CARD_VERSION,
|
||||
"body": [
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": message,
|
||||
"wrap": True,
|
||||
"color": "Attention",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def build_welcome_card() -> dict:
|
||||
"""Build an Adaptive Card for the welcome message when bot is added."""
|
||||
return {
|
||||
"$schema": ADAPTIVE_CARD_SCHEMA,
|
||||
"type": "AdaptiveCard",
|
||||
"version": ADAPTIVE_CARD_VERSION,
|
||||
"body": [
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": "Welcome to Onyx!",
|
||||
"weight": "Bolder",
|
||||
"size": "Medium",
|
||||
},
|
||||
{
|
||||
"type": "TextBlock",
|
||||
"text": (
|
||||
"I'm the Onyx bot. I can help you search your company's knowledge base "
|
||||
"and answer questions.\n\n"
|
||||
"To get started, an admin needs to register this team. "
|
||||
"Send me a direct message with:\n\n"
|
||||
"`@Onyx register <registration_key>`"
|
||||
),
|
||||
"wrap": True,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _extract_citations(
|
||||
response: ChatFullResponse,
|
||||
) -> list[tuple[int, str, str | None]]:
|
||||
"""Extract citation information from a chat response."""
|
||||
if not response.citation_info or not response.top_documents:
|
||||
return []
|
||||
|
||||
cited_docs: list[tuple[int, str, str | None]] = []
|
||||
for citation in response.citation_info:
|
||||
doc = next(
|
||||
(
|
||||
d
|
||||
for d in response.top_documents
|
||||
if d.document_id == citation.document_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if doc:
|
||||
cited_docs.append(
|
||||
(
|
||||
citation.citation_number,
|
||||
doc.semantic_identifier or "Source",
|
||||
doc.link,
|
||||
)
|
||||
)
|
||||
|
||||
cited_docs.sort(key=lambda x: x[0])
|
||||
return cited_docs[:MAX_CITATIONS]
|
||||
14
backend/onyx/onyxbot/teams/constants.py
Normal file
14
backend/onyx/onyxbot/teams/constants.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Teams-specific bot constants.
|
||||
|
||||
Shared constants (API_REQUEST_TIMEOUT, CACHE_REFRESH_INTERVAL,
|
||||
REGISTER_COMMAND) live in ``onyx.onyxbot.constants``.
|
||||
"""
|
||||
|
||||
# Bot Framework settings
|
||||
BOT_MESSAGES_ENDPOINT: str = "/api/messages"
|
||||
BOT_HEALTH_ENDPOINT: str = "/health"
|
||||
|
||||
# Adaptive Card settings
|
||||
ADAPTIVE_CARD_SCHEMA: str = "http://adaptivecards.io/schemas/adaptive-card.json"
|
||||
ADAPTIVE_CARD_VERSION: str = "1.3" # Compatible with mobile clients
|
||||
MAX_CITATIONS: int = 5
|
||||
81
backend/onyx/onyxbot/teams/handle_commands.py
Normal file
81
backend/onyx/onyxbot/teams/handle_commands.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""Teams bot command handlers (e.g., registration)."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.teams_bot import get_team_config_by_registration_key
|
||||
from onyx.db.teams_bot import register_team
|
||||
from onyx.onyxbot.constants import REGISTER_COMMAND
|
||||
from onyx.onyxbot.exceptions import RegistrationError
|
||||
from onyx.onyxbot.teams.cache import TeamsCacheManager
|
||||
from onyx.onyxbot.teams.utils import extract_team_id
|
||||
from onyx.onyxbot.teams.utils import extract_team_name
|
||||
from onyx.onyxbot.teams.utils import strip_bot_mention
|
||||
from onyx.server.manage.teams_bot.utils import parse_teams_registration_key
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
async def handle_registration_command(
|
||||
text: str,
|
||||
activity_dict: dict,
|
||||
bot_name: str,
|
||||
cache: TeamsCacheManager,
|
||||
) -> str:
|
||||
"""Handle the 'register <key>' command.
|
||||
|
||||
Returns a human-readable response message.
|
||||
"""
|
||||
clean_text = strip_bot_mention(text, bot_name).strip()
|
||||
|
||||
# Parse "register <key>"
|
||||
parts = clean_text.split(None, 1)
|
||||
if len(parts) != 2 or parts[0].lower() != REGISTER_COMMAND:
|
||||
raise RegistrationError(
|
||||
f"Invalid registration command. Usage: @{bot_name} register <registration_key>"
|
||||
)
|
||||
|
||||
registration_key = parts[1].strip()
|
||||
|
||||
# Parse tenant_id from registration key
|
||||
tenant_id = parse_teams_registration_key(registration_key)
|
||||
if not tenant_id:
|
||||
raise RegistrationError("Invalid registration key format.")
|
||||
|
||||
team_id = extract_team_id(activity_dict)
|
||||
team_name = extract_team_name(activity_dict) or "Unknown Team"
|
||||
|
||||
if not team_id:
|
||||
raise RegistrationError(
|
||||
"Registration must be done from a Teams channel, not a DM."
|
||||
)
|
||||
|
||||
def _register() -> str:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
# Lock the row to prevent concurrent registration with the same key
|
||||
config = get_team_config_by_registration_key(
|
||||
db, registration_key, for_update=True
|
||||
)
|
||||
if not config:
|
||||
raise RegistrationError("Registration key not found or already used.")
|
||||
|
||||
if config.team_id is not None:
|
||||
raise RegistrationError("This registration key has already been used.")
|
||||
|
||||
register_team(db, config, team_id, team_name)
|
||||
db.commit()
|
||||
return tenant_id
|
||||
|
||||
registered_tenant_id = await asyncio.to_thread(_register)
|
||||
await cache.refresh_team(team_id, registered_tenant_id)
|
||||
|
||||
logger.info(f"Team {team_id} ({team_name}) registered for tenant {tenant_id}")
|
||||
return f"Team **{team_name}** has been registered with Onyx. You can now configure channels in the admin panel."
|
||||
|
||||
|
||||
def is_registration_command(text: str, bot_name: str) -> bool:
|
||||
"""Check if a message is a registration command."""
|
||||
clean_text = strip_bot_mention(text, bot_name).strip()
|
||||
parts = clean_text.split(None, 1)
|
||||
return len(parts) >= 1 and parts[0].lower() == REGISTER_COMMAND
|
||||
111
backend/onyx/onyxbot/teams/handle_message.py
Normal file
111
backend/onyx/onyxbot/teams/handle_message.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Teams bot message handling and response logic."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import TeamsChannelConfig
|
||||
from onyx.db.models import TeamsTeamConfig
|
||||
from onyx.db.teams_bot import get_channel_config_by_teams_ids
|
||||
from onyx.db.teams_bot import get_team_config_by_teams_id
|
||||
from onyx.onyxbot.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.exceptions import APIError
|
||||
from onyx.onyxbot.teams.cards import build_answer_card
|
||||
from onyx.onyxbot.teams.cards import build_error_card
|
||||
from onyx.onyxbot.teams.utils import is_bot_mentioned
|
||||
from onyx.onyxbot.teams.utils import strip_bot_mention
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShouldRespondContext:
|
||||
"""Context for whether the bot should respond to a message."""
|
||||
|
||||
should_respond: bool
|
||||
persona_id: int | None
|
||||
tenant_id: str | None = field(default=None)
|
||||
api_key: str | None = field(default=None)
|
||||
|
||||
|
||||
def should_respond(
|
||||
activity_dict: dict,
|
||||
team_id: str | None,
|
||||
channel_id: str | None,
|
||||
tenant_id: str,
|
||||
bot_id: str,
|
||||
) -> ShouldRespondContext:
|
||||
"""Determine if bot should respond and which persona to use.
|
||||
|
||||
This is a synchronous function that performs DB lookups.
|
||||
"""
|
||||
no_response = ShouldRespondContext(should_respond=False, persona_id=None)
|
||||
|
||||
if not team_id or not channel_id:
|
||||
# DM or group chat — respond if we have a tenant
|
||||
return ShouldRespondContext(should_respond=True, persona_id=None)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
team_config: TeamsTeamConfig | None = get_team_config_by_teams_id(db, team_id)
|
||||
if not team_config or not team_config.enabled:
|
||||
return no_response
|
||||
|
||||
channel_config: TeamsChannelConfig | None = get_channel_config_by_teams_ids(
|
||||
db, team_id, channel_id
|
||||
)
|
||||
if not channel_config or not channel_config.enabled:
|
||||
return no_response
|
||||
|
||||
# Determine persona (channel override or team default)
|
||||
persona_id = (
|
||||
channel_config.persona_override_id or team_config.default_persona_id
|
||||
)
|
||||
|
||||
# Check mention requirement
|
||||
if channel_config.require_bot_mention:
|
||||
if not is_bot_mentioned(activity_dict, bot_id):
|
||||
return no_response
|
||||
|
||||
return ShouldRespondContext(should_respond=True, persona_id=persona_id)
|
||||
|
||||
|
||||
async def process_chat_message(
|
||||
text: str,
|
||||
api_key: str,
|
||||
persona_id: int | None,
|
||||
api_client: OnyxAPIClient,
|
||||
bot_name: str,
|
||||
) -> dict:
|
||||
"""Process a message and return an Adaptive Card response.
|
||||
|
||||
Returns:
|
||||
Adaptive Card dict for the response.
|
||||
"""
|
||||
try:
|
||||
# Strip bot mention from the message
|
||||
clean_text = strip_bot_mention(text, bot_name)
|
||||
if not clean_text:
|
||||
return build_error_card("Please include a message after the @mention.")
|
||||
|
||||
# Send to Onyx API
|
||||
response = await api_client.send_chat_message(
|
||||
message=clean_text,
|
||||
api_key=api_key,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
answer = response.answer or "I couldn't generate a response."
|
||||
return build_answer_card(answer, response)
|
||||
|
||||
except APIError as e:
|
||||
logger.error(f"API error processing Teams message: {e}")
|
||||
return build_error_card(
|
||||
"Sorry, I encountered an error processing your message. "
|
||||
"Please try again later."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing Teams chat message: {e}")
|
||||
return build_error_card(
|
||||
"Sorry, an unexpected error occurred. Please try again later."
|
||||
)
|
||||
155
backend/onyx/onyxbot/teams/server.py
Normal file
155
backend/onyx/onyxbot/teams/server.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""HTTP server for Teams bot using aiohttp + Bot Framework adapter."""
|
||||
|
||||
import sys
|
||||
|
||||
from aiohttp import web
|
||||
from botbuilder.core import BotFrameworkAdapter # type: ignore[import-untyped]
|
||||
from botbuilder.core import BotFrameworkAdapterSettings
|
||||
from botbuilder.core import TurnContext
|
||||
from botbuilder.schema import Activity # type: ignore[import-untyped]
|
||||
|
||||
from onyx.configs.app_configs import TEAMS_BOT_PORT
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.teams_bot import get_teams_bot_config
|
||||
from onyx.onyxbot.teams.bot import OnyxTeamsBot
|
||||
from onyx.onyxbot.teams.constants import BOT_HEALTH_ENDPOINT
|
||||
from onyx.onyxbot.teams.constants import BOT_MESSAGES_ENDPOINT
|
||||
from onyx.onyxbot.teams.utils import get_bot_credentials_from_env
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_credentials() -> tuple[str, str, str | None] | None:
|
||||
"""Get bot credentials from env vars or database.
|
||||
|
||||
Env vars take priority. Falls back to DB config for self-hosted
|
||||
deployments that configure via admin UI.
|
||||
"""
|
||||
env_creds = get_bot_credentials_from_env()
|
||||
if env_creds:
|
||||
return env_creds
|
||||
|
||||
# Try database (for self-hosted deployments)
|
||||
try:
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
config = get_teams_bot_config(db)
|
||||
if config:
|
||||
# Access the decrypted value
|
||||
app_secret = config.app_secret
|
||||
if isinstance(app_secret, str):
|
||||
return config.app_id, app_secret, config.azure_tenant_id
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load Teams bot config from DB: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _handle_messages(request: web.Request) -> web.Response:
|
||||
"""Handle incoming Bot Framework Activities at POST /api/messages."""
|
||||
bot: OnyxTeamsBot = request.app["bot"]
|
||||
adapter: BotFrameworkAdapter = request.app["adapter"]
|
||||
|
||||
if request.content_type != "application/json":
|
||||
return web.Response(status=415, text="Unsupported media type")
|
||||
|
||||
body = await request.json()
|
||||
activity = Activity().deserialize(body)
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
|
||||
async def _turn_callback(turn_context: TurnContext) -> None:
|
||||
await bot.on_turn(turn_context)
|
||||
|
||||
try:
|
||||
invoke_response = await adapter.process_activity(
|
||||
activity, auth_header, _turn_callback
|
||||
)
|
||||
# For invoke activities (messaging extensions, task modules),
|
||||
# process_activity returns an InvokeResponse with status/body
|
||||
# that must be forwarded to the Bot Framework.
|
||||
if invoke_response:
|
||||
return web.Response(
|
||||
status=invoke_response.status,
|
||||
body=invoke_response.body,
|
||||
content_type="application/json",
|
||||
)
|
||||
return web.Response(status=200)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing activity: {e}")
|
||||
return web.Response(status=500, text="Internal server error")
|
||||
|
||||
|
||||
async def _handle_health(request: web.Request) -> web.Response:
|
||||
"""Health check endpoint."""
|
||||
bot: OnyxTeamsBot = request.app["bot"]
|
||||
healthy = bot.api_client.is_initialized and bot.cache.is_initialized
|
||||
if healthy:
|
||||
return web.Response(status=200, text="OK")
|
||||
return web.Response(status=503, text="Not ready")
|
||||
|
||||
|
||||
async def _on_startup(app: web.Application) -> None:
|
||||
"""Initialize bot on server startup."""
|
||||
bot: OnyxTeamsBot = app["bot"]
|
||||
await bot.initialize()
|
||||
logger.info("Teams bot server started")
|
||||
|
||||
|
||||
async def _on_shutdown(app: web.Application) -> None:
|
||||
"""Shut down bot on server shutdown."""
|
||||
bot: OnyxTeamsBot = app["bot"]
|
||||
await bot.shutdown()
|
||||
logger.info("Teams bot server stopped")
|
||||
|
||||
|
||||
def create_app(
|
||||
app_id: str,
|
||||
app_secret: str,
|
||||
) -> web.Application:
|
||||
"""Create the aiohttp web application for the Teams bot."""
|
||||
settings = BotFrameworkAdapterSettings(
|
||||
app_id=app_id,
|
||||
app_password=app_secret,
|
||||
)
|
||||
adapter = BotFrameworkAdapter(settings)
|
||||
|
||||
bot = OnyxTeamsBot()
|
||||
|
||||
app = web.Application()
|
||||
app["bot"] = bot
|
||||
app["adapter"] = adapter
|
||||
app.router.add_post(BOT_MESSAGES_ENDPOINT, _handle_messages)
|
||||
app.router.add_get(BOT_HEALTH_ENDPOINT, _handle_health)
|
||||
app.on_startup.append(_on_startup)
|
||||
app.on_shutdown.append(_on_shutdown)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point for the Teams bot process."""
|
||||
logger.info("Starting Teams bot...")
|
||||
|
||||
credentials = _get_credentials()
|
||||
if not credentials:
|
||||
logger.error(
|
||||
"Teams bot credentials not configured. "
|
||||
"Set TEAMS_BOT_APP_ID and TEAMS_BOT_APP_SECRET environment variables, "
|
||||
"or configure via the admin panel."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
app_id, app_secret, _azure_tenant_id = credentials
|
||||
logger.info(f"Teams bot starting with App ID: {app_id}")
|
||||
|
||||
app = create_app(app_id, app_secret)
|
||||
web.run_app(app, host="0.0.0.0", port=TEAMS_BOT_PORT)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
83
backend/onyx/onyxbot/teams/utils.py
Normal file
83
backend/onyx/onyxbot/teams/utils.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Utility functions for Teams bot."""
|
||||
|
||||
from onyx.configs.app_configs import TEAMS_BOT_APP_ID
|
||||
from onyx.configs.app_configs import TEAMS_BOT_APP_SECRET
|
||||
from onyx.configs.app_configs import TEAMS_BOT_AZURE_TENANT_ID
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_bot_credentials_from_env() -> tuple[str, str, str | None] | None:
|
||||
"""Get bot credentials from environment variables.
|
||||
|
||||
Returns:
|
||||
(app_id, app_secret, azure_tenant_id) or None if not configured.
|
||||
"""
|
||||
if not TEAMS_BOT_APP_ID or not TEAMS_BOT_APP_SECRET:
|
||||
return None
|
||||
return TEAMS_BOT_APP_ID, TEAMS_BOT_APP_SECRET, TEAMS_BOT_AZURE_TENANT_ID
|
||||
|
||||
|
||||
def extract_team_id(activity: dict) -> str | None:
|
||||
"""Extract the Teams team ID from an Activity's channelData.
|
||||
|
||||
Teams Activities include channelData.team.id for messages in team channels.
|
||||
For 1:1 or group chats, this will be None.
|
||||
"""
|
||||
channel_data = activity.get("channelData", {})
|
||||
team = channel_data.get("team")
|
||||
if team:
|
||||
return team.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def extract_channel_id(activity: dict) -> str | None:
|
||||
"""Extract the Teams channel ID from an Activity's channelData."""
|
||||
channel_data = activity.get("channelData", {})
|
||||
channel = channel_data.get("channel")
|
||||
if channel:
|
||||
return channel.get("id")
|
||||
return None
|
||||
|
||||
|
||||
def extract_team_name(activity: dict) -> str | None:
|
||||
"""Extract the Teams team name from an Activity's channelData."""
|
||||
channel_data = activity.get("channelData", {})
|
||||
team = channel_data.get("team")
|
||||
if team:
|
||||
return team.get("name")
|
||||
return None
|
||||
|
||||
|
||||
def strip_bot_mention(text: str, bot_name: str) -> str:
|
||||
"""Remove the bot @mention from the message text.
|
||||
|
||||
Teams includes the @mention in the message text as <at>BotName</at>.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Remove <at>BotName</at> tags
|
||||
cleaned = re.sub(
|
||||
rf"<at>{re.escape(bot_name)}</at>",
|
||||
"",
|
||||
text,
|
||||
flags=re.IGNORECASE,
|
||||
)
|
||||
# Also try without the specific name (some clients send generic)
|
||||
cleaned = re.sub(r"<at>[^<]*</at>", "", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def is_bot_mentioned(activity: dict, bot_id: str) -> bool:
|
||||
"""Check if the bot is mentioned in the activity.
|
||||
|
||||
Teams includes mentions in the activity entities array.
|
||||
"""
|
||||
entities = activity.get("entities", [])
|
||||
for entity in entities:
|
||||
if entity.get("type") == "mention":
|
||||
mentioned = entity.get("mentioned", {})
|
||||
if mentioned.get("id") == bot_id:
|
||||
return True
|
||||
return False
|
||||
@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.file_processing.file_types import OnyxFileExtensions
|
||||
from onyx.file_processing.password_validation import is_file_password_protected
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
|
||||
pass
|
||||
|
||||
|
||||
def categorize_uploaded_files(
|
||||
files: list[UploadFile], db_session: Session
|
||||
) -> CategorizedFiles:
|
||||
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
|
||||
"""
|
||||
Categorize uploaded files based on text extractability and tokenized length.
|
||||
|
||||
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
|
||||
"""
|
||||
|
||||
results = CategorizedFiles()
|
||||
default_model = fetch_default_llm_model(db_session)
|
||||
llm = get_default_llm()
|
||||
|
||||
model_name = default_model.name if default_model else None
|
||||
provider_type = default_model.llm_provider.provider if default_model else None
|
||||
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name, provider_type=llm.config.model_provider
|
||||
)
|
||||
|
||||
# Check if threshold checks should be skipped
|
||||
skip_threshold = False
|
||||
|
||||
@@ -1,46 +1,16 @@
|
||||
"""Discord registration key generation and parsing."""
|
||||
|
||||
import secrets
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import unquote
|
||||
from onyx.onyxbot.registration import generate_registration_key
|
||||
from onyx.onyxbot.registration import parse_registration_key
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
REGISTRATION_KEY_PREFIX: str = "discord_"
|
||||
REGISTRATION_KEY_PREFIX: str = "discord"
|
||||
|
||||
|
||||
def generate_discord_registration_key(tenant_id: str) -> str:
|
||||
"""Generate a one-time registration key with embedded tenant_id.
|
||||
|
||||
Format: discord_<url_encoded_tenant_id>.<random_token>
|
||||
|
||||
Follows the same pattern as API keys for consistency.
|
||||
"""
|
||||
encoded_tenant = quote(tenant_id)
|
||||
random_token = secrets.token_urlsafe(16)
|
||||
|
||||
logger.info(f"Generated Discord registration key for tenant {tenant_id}")
|
||||
return f"{REGISTRATION_KEY_PREFIX}{encoded_tenant}.{random_token}"
|
||||
"""Generate a one-time registration key with embedded tenant_id."""
|
||||
return generate_registration_key(REGISTRATION_KEY_PREFIX, tenant_id)
|
||||
|
||||
|
||||
def parse_discord_registration_key(key: str) -> str | None:
|
||||
"""Parse registration key to extract tenant_id.
|
||||
|
||||
Returns tenant_id or None if invalid format.
|
||||
"""
|
||||
if not key.startswith(REGISTRATION_KEY_PREFIX):
|
||||
return None
|
||||
|
||||
try:
|
||||
key_body = key.removeprefix(REGISTRATION_KEY_PREFIX)
|
||||
parts = key_body.split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
encoded_tenant = parts[0]
|
||||
tenant_id = unquote(encoded_tenant)
|
||||
return tenant_id
|
||||
except Exception:
|
||||
return None
|
||||
"""Parse registration key to extract tenant_id."""
|
||||
return parse_registration_key(REGISTRATION_KEY_PREFIX, key)
|
||||
|
||||
0
backend/onyx/server/manage/teams_bot/__init__.py
Normal file
0
backend/onyx/server/manage/teams_bot/__init__.py
Normal file
279
backend/onyx/server/manage/teams_bot/api.py
Normal file
279
backend/onyx/server/manage/teams_bot/api.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Teams bot admin API endpoints."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import TEAMS_BOT_APP_ID
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.teams_bot import create_team_config
|
||||
from onyx.db.teams_bot import create_teams_bot_config
|
||||
from onyx.db.teams_bot import delete_team_config
|
||||
from onyx.db.teams_bot import delete_teams_bot_config
|
||||
from onyx.db.teams_bot import delete_teams_service_api_key
|
||||
from onyx.db.teams_bot import get_channel_config_by_internal_ids
|
||||
from onyx.db.teams_bot import get_channel_configs
|
||||
from onyx.db.teams_bot import get_team_config_by_internal_id
|
||||
from onyx.db.teams_bot import get_team_configs
|
||||
from onyx.db.teams_bot import get_teams_bot_config
|
||||
from onyx.db.teams_bot import update_team_config
|
||||
from onyx.db.teams_bot import update_teams_channel_config
|
||||
from onyx.server.manage.teams_bot.models import TeamsBotConfigCreateRequest
|
||||
from onyx.server.manage.teams_bot.models import TeamsBotConfigResponse
|
||||
from onyx.server.manage.teams_bot.models import TeamsChannelConfigResponse
|
||||
from onyx.server.manage.teams_bot.models import TeamsChannelConfigUpdateRequest
|
||||
from onyx.server.manage.teams_bot.models import TeamsTeamConfigCreateResponse
|
||||
from onyx.server.manage.teams_bot.models import TeamsTeamConfigResponse
|
||||
from onyx.server.manage.teams_bot.models import TeamsTeamConfigUpdateRequest
|
||||
from onyx.server.manage.teams_bot.utils import generate_teams_registration_key
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
router = APIRouter(prefix="/manage/admin/teams-bot")
|
||||
|
||||
|
||||
def _check_bot_config_api_access() -> None:
|
||||
"""Raise 403 if bot config cannot be managed via API.
|
||||
|
||||
Bot config endpoints are disabled:
|
||||
- On Cloud (managed by Onyx)
|
||||
- When TEAMS_BOT_APP_ID env var is set (managed via env)
|
||||
"""
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Teams bot configuration is managed by Onyx on Cloud.",
|
||||
)
|
||||
if TEAMS_BOT_APP_ID:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Teams bot is configured via environment variables. API access disabled.",
|
||||
)
|
||||
|
||||
|
||||
# === Bot Config ===
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
def get_bot_config(
|
||||
_: None = Depends(_check_bot_config_api_access),
|
||||
__: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TeamsBotConfigResponse:
|
||||
"""Get Teams bot config. Returns 403 on Cloud or if env vars set."""
|
||||
config = get_teams_bot_config(db_session)
|
||||
if not config:
|
||||
return TeamsBotConfigResponse(configured=False)
|
||||
|
||||
return TeamsBotConfigResponse(
|
||||
configured=True,
|
||||
created_at=config.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def create_bot_request(
|
||||
request: TeamsBotConfigCreateRequest,
|
||||
_: None = Depends(_check_bot_config_api_access),
|
||||
__: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TeamsBotConfigResponse:
|
||||
"""Create Teams bot config. Returns 403 on Cloud or if env vars set."""
|
||||
try:
|
||||
config = create_teams_bot_config(
|
||||
db_session,
|
||||
app_id=request.app_id,
|
||||
app_secret=request.app_secret,
|
||||
azure_tenant_id=request.azure_tenant_id,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Teams bot config already exists. Delete it first to create a new one.",
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return TeamsBotConfigResponse(
|
||||
configured=True,
|
||||
created_at=config.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/config")
|
||||
def delete_bot_config_endpoint(
|
||||
_: None = Depends(_check_bot_config_api_access),
|
||||
__: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Delete Teams bot config.
|
||||
|
||||
Also deletes the Teams service API key since the bot is being removed.
|
||||
"""
|
||||
deleted = delete_teams_bot_config(db_session)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Bot config not found")
|
||||
|
||||
delete_teams_service_api_key(db_session)
|
||||
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# === Service API Key ===
|
||||
|
||||
|
||||
@router.delete("/service-api-key")
|
||||
def delete_service_api_key_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Delete the Teams service API key."""
|
||||
deleted = delete_teams_service_api_key(db_session)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Service API key not found")
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# === Team Config ===
|
||||
|
||||
|
||||
@router.get("/teams")
|
||||
def list_team_configs(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TeamsTeamConfigResponse]:
|
||||
"""List all team configs (pending and registered)."""
|
||||
configs = get_team_configs(db_session)
|
||||
return [TeamsTeamConfigResponse.model_validate(c) for c in configs]
|
||||
|
||||
|
||||
@router.post("/teams")
|
||||
def create_team_request(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TeamsTeamConfigCreateResponse:
|
||||
"""Create new team config with registration key. Key shown once."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
registration_key = generate_teams_registration_key(tenant_id)
|
||||
|
||||
config = create_team_config(db_session, registration_key)
|
||||
db_session.commit()
|
||||
|
||||
return TeamsTeamConfigCreateResponse(
|
||||
id=config.id,
|
||||
registration_key=registration_key,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/teams/{config_id}")
|
||||
def get_team_config(
|
||||
config_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TeamsTeamConfigResponse:
|
||||
"""Get specific team config."""
|
||||
config = get_team_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Team config not found")
|
||||
return TeamsTeamConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@router.patch("/teams/{config_id}")
|
||||
def update_team_request(
|
||||
config_id: int,
|
||||
request: TeamsTeamConfigUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TeamsTeamConfigResponse:
|
||||
"""Update team config."""
|
||||
config = get_team_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Team config not found")
|
||||
|
||||
config = update_team_config(
|
||||
db_session,
|
||||
config,
|
||||
enabled=request.enabled,
|
||||
default_persona_id=request.default_persona_id,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return TeamsTeamConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@router.delete("/teams/{config_id}")
|
||||
def delete_team_request(
|
||||
config_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Delete team config (invalidates registration key).
|
||||
|
||||
On Cloud, if this was the last team config, also deletes the service API key.
|
||||
"""
|
||||
deleted = delete_team_config(db_session, config_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Team config not found")
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
remaining_teams = get_team_configs(db_session)
|
||||
if not remaining_teams:
|
||||
delete_teams_service_api_key(db_session)
|
||||
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# === Channel Config ===
|
||||
|
||||
|
||||
@router.get("/teams/{config_id}/channels")
|
||||
def list_channel_configs(
|
||||
config_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[TeamsChannelConfigResponse]:
|
||||
"""List whitelisted channels for a team."""
|
||||
team_config = get_team_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not team_config:
|
||||
raise HTTPException(status_code=404, detail="Team config not found")
|
||||
if not team_config.team_id:
|
||||
raise HTTPException(status_code=400, detail="Team not yet registered")
|
||||
|
||||
configs = get_channel_configs(db_session, config_id)
|
||||
return [TeamsChannelConfigResponse.model_validate(c) for c in configs]
|
||||
|
||||
|
||||
@router.patch("/teams/{team_config_id}/channels/{channel_config_id}")
|
||||
def update_channel_request(
|
||||
team_config_id: int,
|
||||
channel_config_id: int,
|
||||
request: TeamsChannelConfigUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> TeamsChannelConfigResponse:
|
||||
"""Update channel config."""
|
||||
config = get_channel_config_by_internal_ids(
|
||||
db_session, team_config_id, channel_config_id
|
||||
)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Channel config not found")
|
||||
|
||||
config = update_teams_channel_config(
|
||||
db_session,
|
||||
config,
|
||||
channel_name=config.channel_name, # Keep existing name
|
||||
require_bot_mention=request.require_bot_mention,
|
||||
persona_override_id=request.persona_override_id,
|
||||
enabled=request.enabled,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return TeamsChannelConfigResponse.model_validate(config)
|
||||
69
backend/onyx/server/manage/teams_bot/models.py
Normal file
69
backend/onyx/server/manage/teams_bot/models.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Pydantic models for Teams bot API."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# === Bot Config ===
|
||||
|
||||
|
||||
class TeamsBotConfigResponse(BaseModel):
|
||||
configured: bool
|
||||
created_at: datetime | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TeamsBotConfigCreateRequest(BaseModel):
|
||||
app_id: str
|
||||
app_secret: str
|
||||
azure_tenant_id: str | None = None
|
||||
|
||||
|
||||
# === Team Config ===
|
||||
|
||||
|
||||
class TeamsTeamConfigResponse(BaseModel):
|
||||
id: int
|
||||
team_id: str | None
|
||||
team_name: str | None
|
||||
registered_at: datetime | None
|
||||
default_persona_id: int | None
|
||||
enabled: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TeamsTeamConfigCreateResponse(BaseModel):
|
||||
id: int
|
||||
registration_key: str # Shown once!
|
||||
|
||||
|
||||
class TeamsTeamConfigUpdateRequest(BaseModel):
|
||||
enabled: bool
|
||||
default_persona_id: int | None
|
||||
|
||||
|
||||
# === Channel Config ===
|
||||
|
||||
|
||||
class TeamsChannelConfigResponse(BaseModel):
|
||||
id: int
|
||||
team_config_id: int
|
||||
channel_id: str
|
||||
channel_name: str
|
||||
require_bot_mention: bool
|
||||
persona_override_id: int | None
|
||||
enabled: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TeamsChannelConfigUpdateRequest(BaseModel):
|
||||
require_bot_mention: bool
|
||||
persona_override_id: int | None
|
||||
enabled: bool
|
||||
16
backend/onyx/server/manage/teams_bot/utils.py
Normal file
16
backend/onyx/server/manage/teams_bot/utils.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Teams registration key generation and parsing."""
|
||||
|
||||
from onyx.onyxbot.registration import generate_registration_key
|
||||
from onyx.onyxbot.registration import parse_registration_key
|
||||
|
||||
REGISTRATION_KEY_PREFIX: str = "teams"
|
||||
|
||||
|
||||
def generate_teams_registration_key(tenant_id: str) -> str:
|
||||
"""Generate a one-time registration key with embedded tenant_id."""
|
||||
return generate_registration_key(REGISTRATION_KEY_PREFIX, tenant_id)
|
||||
|
||||
|
||||
def parse_teams_registration_key(key: str) -> str | None:
|
||||
"""Parse registration key to extract tenant_id."""
|
||||
return parse_registration_key(REGISTRATION_KEY_PREFIX, key)
|
||||
@@ -32,6 +32,7 @@ class MessageOrigin(str, Enum):
|
||||
SLACKBOT = "slackbot"
|
||||
WIDGET = "widget"
|
||||
DISCORDBOT = "discordbot"
|
||||
TEAMSBOT = "teamsbot"
|
||||
UNKNOWN = "unknown"
|
||||
UNSET = "unset"
|
||||
|
||||
|
||||
@@ -8,3 +8,37 @@ dependencies = [
|
||||
|
||||
[tool.uv.sources]
|
||||
onyx = { workspace = true }
|
||||
|
||||
[tool.mypy]
|
||||
plugins = "sqlalchemy.ext.mypy.plugin"
|
||||
mypy_path = "backend"
|
||||
explicit_package_bases = true
|
||||
disallow_untyped_defs = true
|
||||
warn_unused_ignores = true
|
||||
enable_error_code = ["possibly-undefined"]
|
||||
strict_equality = true
|
||||
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
|
||||
exclude = [
|
||||
"(?:^|/)generated/",
|
||||
"(?:^|/)\\.venv/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
|
||||
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
|
||||
]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "alembic_tenants.versions.*"
|
||||
disable_error_code = ["var-annotated"]
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "generated.*"
|
||||
follow_imports = "silent"
|
||||
ignore_errors = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
follow_imports = "skip"
|
||||
ignore_errors = true
|
||||
|
||||
@@ -17,6 +17,7 @@ aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# botbuilder-integration-aiohttp
|
||||
# discord-py
|
||||
# litellm
|
||||
# onyx
|
||||
@@ -67,6 +68,8 @@ attrs==25.4.0
|
||||
# zeep
|
||||
authlib==1.6.6
|
||||
# via fastmcp
|
||||
azure-core==1.38.2
|
||||
# via msrest
|
||||
babel==2.17.0
|
||||
# via courlan
|
||||
backoff==2.2.1
|
||||
@@ -88,6 +91,25 @@ beautifulsoup4==4.12.3
|
||||
# unstructured
|
||||
billiard==4.2.3
|
||||
# via celery
|
||||
botbuilder-core==4.17.1
|
||||
# via
|
||||
# botbuilder-integration-aiohttp
|
||||
# onyx
|
||||
botbuilder-integration-aiohttp==4.17.1
|
||||
# via onyx
|
||||
botbuilder-schema==4.17.1
|
||||
# via
|
||||
# botbuilder-core
|
||||
# botbuilder-integration-aiohttp
|
||||
# botframework-connector
|
||||
# botframework-streaming
|
||||
botframework-connector==4.17.1
|
||||
# via
|
||||
# botbuilder-core
|
||||
# botbuilder-integration-aiohttp
|
||||
# botframework-streaming
|
||||
botframework-streaming==4.17.1
|
||||
# via botbuilder-core
|
||||
boto3==1.39.11
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -123,6 +145,7 @@ certifi==2025.11.12
|
||||
# httpx
|
||||
# hubspot-api-client
|
||||
# kubernetes
|
||||
# msrest
|
||||
# opensearch-py
|
||||
# requests
|
||||
# sentry-sdk
|
||||
@@ -444,6 +467,7 @@ iniconfig==2.3.0
|
||||
# via pytest
|
||||
isodate==0.7.2
|
||||
# via
|
||||
# msrest
|
||||
# python3-saml
|
||||
# zeep
|
||||
jaraco-classes==3.4.0
|
||||
@@ -474,6 +498,8 @@ joblib==1.5.2
|
||||
# via nltk
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpickle==3.4.2
|
||||
# via botbuilder-core
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
@@ -528,7 +554,7 @@ lxml==5.3.0
|
||||
# unstructured
|
||||
# xmlsec
|
||||
# zeep
|
||||
lxml-html-clean==0.4.4
|
||||
lxml-html-clean==0.4.3
|
||||
# via lxml
|
||||
magika==0.6.3
|
||||
# via markitdown
|
||||
@@ -573,12 +599,17 @@ mpmath==1.3.0
|
||||
# via sympy
|
||||
msal==1.34.0
|
||||
# via
|
||||
# botframework-connector
|
||||
# office365-rest-python-client
|
||||
# onyx
|
||||
msgpack==1.1.2
|
||||
# via distributed
|
||||
msoffcrypto-tool==5.4.2
|
||||
# via onyx
|
||||
msrest==0.7.1
|
||||
# via
|
||||
# botbuilder-schema
|
||||
# botframework-connector
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiobotocore
|
||||
@@ -796,6 +827,7 @@ pygments==2.19.2
|
||||
# via rich
|
||||
pyjwt==2.11.0
|
||||
# via
|
||||
# botframework-connector
|
||||
# fastapi-users
|
||||
# mcp
|
||||
# msal
|
||||
@@ -809,7 +841,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.7.5
|
||||
pypdf==6.7.3
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -922,6 +954,7 @@ regex==2025.11.3
|
||||
requests==2.32.5
|
||||
# via
|
||||
# atlassian-python-api
|
||||
# azure-core
|
||||
# braintrust
|
||||
# cohere
|
||||
# dropbox
|
||||
@@ -940,6 +973,7 @@ requests==2.32.5
|
||||
# markitdown
|
||||
# matrix-client
|
||||
# msal
|
||||
# msrest
|
||||
# office365-rest-python-client
|
||||
# onyx
|
||||
# opensearch-py
|
||||
@@ -967,6 +1001,7 @@ requests-oauthlib==1.3.1
|
||||
# google-auth-oauthlib
|
||||
# jira
|
||||
# kubernetes
|
||||
# msrest
|
||||
# onyx
|
||||
requests-toolbelt==1.0.0
|
||||
# via
|
||||
@@ -1111,6 +1146,7 @@ typing-extensions==4.15.0
|
||||
# aiosignal
|
||||
# alembic
|
||||
# anyio
|
||||
# azure-core
|
||||
# boto3-stubs
|
||||
# braintrust
|
||||
# cohere
|
||||
@@ -1177,6 +1213,7 @@ uritemplate==4.2.0
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
# asana
|
||||
# botbuilder-schema
|
||||
# botocore
|
||||
# courlan
|
||||
# distributed
|
||||
@@ -1239,7 +1276,9 @@ xmlsec==1.3.14
|
||||
xmltodict==1.0.2
|
||||
# via ddtrace
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
# via
|
||||
# aiohttp
|
||||
# botbuilder-integration-aiohttp
|
||||
zeep==4.3.2
|
||||
# via simple-salesforce
|
||||
zict==3.0.0
|
||||
|
||||
@@ -46,10 +46,10 @@ def _make_task(
|
||||
run_fn: MagicMock | None = None,
|
||||
) -> _PeriodicTaskDef:
|
||||
return _PeriodicTaskDef(
|
||||
name=name if name is not None else f"test-{uuid4().hex[:8]}",
|
||||
name=name or f"test-{uuid4().hex[:8]}",
|
||||
interval_seconds=interval,
|
||||
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
|
||||
run_fn=run_fn if run_fn is not None else MagicMock(),
|
||||
lock_id=lock_id or _TEST_LOCK_BASE,
|
||||
run_fn=run_fn or MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
|
||||
|
||||
class ScimClient:
|
||||
"""HTTP client for making authenticated SCIM v2 requests."""
|
||||
|
||||
@staticmethod
|
||||
def _headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def post(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def put(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.put(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def patch(path: str, raw_token: str, json: dict) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
json=json,
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def delete(path: str, raw_token: str) -> requests.Response:
|
||||
return requests.delete(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimClient._headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
@@ -1,6 +1,7 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestScimToken
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -50,3 +51,29 @@ class ScimTokenManager:
|
||||
created_at=data["created_at"],
|
||||
last_used_at=data.get("last_used_at"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_scim_headers(raw_token: str) -> dict[str, str]:
|
||||
return {
|
||||
**GENERAL_HEADERS,
|
||||
"Authorization": f"Bearer {raw_token}",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def scim_get(
|
||||
path: str,
|
||||
raw_token: str,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=ScimTokenManager.get_scim_headers(raw_token),
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def scim_get_no_auth(path: str) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/scim/v2{path}",
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
"""Integration test for the full user-file lifecycle in no-vector-DB mode.
|
||||
|
||||
Covers: upload → COMPLETED → unlink from project → delete → gone.
|
||||
|
||||
The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers
|
||||
needed). The conftest-level ``pytestmark`` ensures these tests are skipped
|
||||
when the server is running with vector DB enabled.
|
||||
"""
|
||||
|
||||
import time
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.project import ProjectManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
POLL_INTERVAL_SECONDS = 1
|
||||
POLL_TIMEOUT_SECONDS = 30
|
||||
|
||||
|
||||
def _poll_file_status(
|
||||
file_id: UUID,
|
||||
user: DATestUser,
|
||||
target_status: UserFileStatus,
|
||||
timeout: int = POLL_TIMEOUT_SECONDS,
|
||||
) -> None:
|
||||
"""Poll GET /user/projects/file/{file_id} until the file reaches *target_status*."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.ok:
|
||||
status = resp.json().get("status")
|
||||
if status == target_status.value:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} did not reach {target_status.value} within {timeout}s"
|
||||
)
|
||||
|
||||
|
||||
def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None:
|
||||
"""Poll until GET /user/projects/file/{file_id} returns 404."""
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=user.headers,
|
||||
)
|
||||
if resp.status_code == 404:
|
||||
return
|
||||
time.sleep(POLL_INTERVAL_SECONDS)
|
||||
raise TimeoutError(
|
||||
f"File {file_id} still accessible after {timeout}s (expected 404)"
|
||||
)
|
||||
|
||||
|
||||
def test_file_upload_process_delete_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Full lifecycle: upload → COMPLETED → unlink → delete → 404.
|
||||
|
||||
Validates that the API server handles all background processing
|
||||
(via FastAPI BackgroundTasks) without any Celery workers running.
|
||||
"""
|
||||
project = ProjectManager.create(
|
||||
name="lifecycle-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
file_content = b"Integration test file content for lifecycle verification."
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("lifecycle.txt", file_content)],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert upload_result.user_files, "Expected at least one file in upload response"
|
||||
|
||||
user_file = upload_result.user_files[0]
|
||||
file_id = user_file.id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
project_files = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert any(
|
||||
f.id == file_id for f in project_files
|
||||
), "File should be listed in project files after processing"
|
||||
|
||||
# Unlink the file from the project so the delete endpoint will proceed
|
||||
unlink_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
unlink_resp.status_code == 204
|
||||
), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}"
|
||||
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert (
|
||||
delete_resp.ok
|
||||
), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}"
|
||||
body = delete_resp.json()
|
||||
assert (
|
||||
body["has_associations"] is False
|
||||
), f"File still has associations after unlink: {body}"
|
||||
|
||||
_file_is_gone(file_id, admin_user)
|
||||
|
||||
project_files_after = ProjectManager.get_project_files(project.id, admin_user)
|
||||
assert not any(
|
||||
f.id == file_id for f in project_files_after
|
||||
), "Deleted file should not appear in project files"
|
||||
|
||||
|
||||
def test_delete_blocked_while_associated(
|
||||
reset: None, # noqa: ARG001
|
||||
admin_user: DATestUser,
|
||||
llm_provider: DATestLLMProvider, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deleting a file that still belongs to a project should return
|
||||
has_associations=True without actually deleting the file."""
|
||||
project = ProjectManager.create(
|
||||
name="assoc-test", user_performing_action=admin_user
|
||||
)
|
||||
|
||||
upload_result = ProjectManager.upload_files(
|
||||
project_id=project.id,
|
||||
files=[("assoc.txt", b"associated file content")],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
file_id = upload_result.user_files[0].id
|
||||
|
||||
_poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
|
||||
|
||||
# Attempt to delete while still linked
|
||||
delete_resp = requests.delete(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert delete_resp.ok
|
||||
body = delete_resp.json()
|
||||
assert body["has_associations"] is True, "Should report existing associations"
|
||||
assert project.name in body["project_names"]
|
||||
|
||||
# File should still be accessible
|
||||
get_resp = requests.get(
|
||||
f"{API_SERVER_URL}/user/projects/file/{file_id}",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert get_resp.status_code == 200, "File should still exist after blocked delete"
|
||||
@@ -1,552 +0,0 @@
|
||||
"""Integration tests for SCIM group provisioning endpoints.
|
||||
|
||||
Covers the full group lifecycle as driven by an IdP (Okta / Azure AD):
|
||||
1. Create a group via POST /Groups
|
||||
2. Retrieve a group via GET /Groups/{id}
|
||||
3. List, filter, and paginate groups via GET /Groups
|
||||
4. Replace a group via PUT /Groups/{id}
|
||||
5. Patch a group (add/remove members, rename) via PATCH /Groups/{id}
|
||||
6. Delete a group via DELETE /Groups/{id}
|
||||
7. Error cases: duplicate name, not-found, invalid member IDs
|
||||
|
||||
All tests are parameterized across IdP request styles (Okta sends lowercase
|
||||
PATCH ops; Entra sends capitalized ops like ``"Replace"``). The server
|
||||
normalizes both — these tests verify that.
|
||||
|
||||
Auth tests live in test_scim_tokens.py.
|
||||
User lifecycle tests live in test_scim_users.py.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
|
||||
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["okta", "entra"])
|
||||
def idp_style(request: pytest.FixtureRequest) -> str:
|
||||
"""Parameterized IdP style — runs every test with both Okta and Entra request formats."""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def scim_token(idp_style: str) -> str:
|
||||
"""Create a single SCIM token shared across all tests in this module.
|
||||
|
||||
Creating a new token revokes the previous one, so we create exactly once
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
token = ScimTokenManager.create(
|
||||
name=f"scim-group-tests-{idp_style}",
|
||||
user_performing_action=admin,
|
||||
).raw_token
|
||||
assert token is not None
|
||||
return token
|
||||
|
||||
|
||||
def _make_group_resource(
|
||||
display_name: str,
|
||||
external_id: str | None = None,
|
||||
members: list[dict] | None = None,
|
||||
) -> dict:
|
||||
"""Build a minimal SCIM GroupResource payload."""
|
||||
resource: dict = {
|
||||
"schemas": [SCIM_GROUP_SCHEMA],
|
||||
"displayName": display_name,
|
||||
}
|
||||
if external_id is not None:
|
||||
resource["externalId"] = external_id
|
||||
if members is not None:
|
||||
resource["members"] = members
|
||||
return resource
|
||||
|
||||
|
||||
def _make_user_resource(email: str, external_id: str) -> dict:
|
||||
"""Build a minimal SCIM UserResource payload for member creation."""
|
||||
return {
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"externalId": external_id,
|
||||
"name": {"givenName": "Test", "familyName": "User"},
|
||||
"active": True,
|
||||
}
|
||||
|
||||
|
||||
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
|
||||
"""Build a SCIM PatchOp payload, applying IdP-specific operation casing.
|
||||
|
||||
Entra sends capitalized operations (e.g. ``"Replace"`` instead of
|
||||
``"replace"``). The server's ``normalize_operation`` validator lowercases
|
||||
them — these tests verify that both casings are accepted.
|
||||
"""
|
||||
cased_operations = []
|
||||
for operation in operations:
|
||||
cased = dict(operation)
|
||||
if idp_style == "entra":
|
||||
cased["op"] = operation["op"].capitalize()
|
||||
cased_operations.append(cased)
|
||||
return {
|
||||
"schemas": [SCIM_PATCH_SCHEMA],
|
||||
"Operations": cased_operations,
|
||||
}
|
||||
|
||||
|
||||
def _create_scim_user(token: str, email: str, external_id: str) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Users", token, json=_make_user_resource(email, external_id)
|
||||
)
|
||||
|
||||
|
||||
def _create_scim_group(
|
||||
token: str,
|
||||
display_name: str,
|
||||
external_id: str | None = None,
|
||||
members: list[dict] | None = None,
|
||||
) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Groups",
|
||||
token,
|
||||
json=_make_group_resource(display_name, external_id, members),
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle: create → get → list → replace → patch → delete
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_group(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups creates a group and returns 201."""
|
||||
name = f"Engineering {idp_style}"
|
||||
resp = _create_scim_group(scim_token, name, external_id=f"ext-eng-{idp_style}")
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
assert body["displayName"] == name
|
||||
assert body["externalId"] == f"ext-eng-{idp_style}"
|
||||
assert body["id"] # integer ID assigned by server
|
||||
assert body["meta"]["resourceType"] == "Group"
|
||||
|
||||
|
||||
def test_create_group_with_members(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with members populates the member list."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_member1_{idp_style}@example.com", f"ext-gm-{idp_style}"
|
||||
).json()
|
||||
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Backend Team {idp_style}",
|
||||
external_id=f"ext-backend-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_get_group(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups/{id} returns the group resource including members."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_get_m_{idp_style}@example.com", f"ext-ggm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Frontend Team {idp_style}",
|
||||
external_id=f"ext-fe-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["id"] == created["id"]
|
||||
assert body["displayName"] == f"Frontend Team {idp_style}"
|
||||
assert body["externalId"] == f"ext-fe-{idp_style}"
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_list_groups(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups returns a ListResponse containing provisioned groups."""
|
||||
name = f"DevOps Team {idp_style}"
|
||||
_create_scim_group(scim_token, name, external_id=f"ext-devops-{idp_style}")
|
||||
|
||||
resp = ScimClient.get("/Groups", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] >= 1
|
||||
names = [r["displayName"] for r in body["Resources"]]
|
||||
assert name in names
|
||||
|
||||
|
||||
def test_list_groups_pagination(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups with startIndex and count returns correct pagination."""
|
||||
_create_scim_group(
|
||||
scim_token, f"Page Group A {idp_style}", external_id=f"ext-page-a-{idp_style}"
|
||||
)
|
||||
_create_scim_group(
|
||||
scim_token, f"Page Group B {idp_style}", external_id=f"ext-page-b-{idp_style}"
|
||||
)
|
||||
|
||||
resp = ScimClient.get("/Groups?startIndex=1&count=1", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["startIndex"] == 1
|
||||
assert body["itemsPerPage"] == 1
|
||||
assert body["totalResults"] >= 2
|
||||
assert len(body["Resources"]) == 1
|
||||
|
||||
|
||||
def test_filter_groups_by_display_name(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups?filter=displayName eq '...' returns only matching groups."""
|
||||
name = f"Unique QA Team {idp_style}"
|
||||
_create_scim_group(scim_token, name, external_id=f"ext-qa-filter-{idp_style}")
|
||||
|
||||
resp = ScimClient.get(f'/Groups?filter=displayName eq "{name}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["displayName"] == name
|
||||
|
||||
|
||||
def test_filter_groups_by_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Groups?filter=externalId eq '...' returns the matching group."""
|
||||
ext_id = f"ext-unique-group-id-{idp_style}"
|
||||
_create_scim_group(
|
||||
scim_token, f"ExtId Filter Group {idp_style}", external_id=ext_id
|
||||
)
|
||||
|
||||
resp = ScimClient.get(f'/Groups?filter=externalId eq "{ext_id}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["externalId"] == ext_id
|
||||
|
||||
|
||||
def test_replace_group(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Groups/{id} replaces the group resource."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Original Name {idp_style}",
|
||||
external_id=f"ext-replace-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_replace_m_{idp_style}@example.com", f"ext-grm-{idp_style}"
|
||||
).json()
|
||||
|
||||
updated_resource = _make_group_resource(
|
||||
display_name=f"Renamed Group {idp_style}",
|
||||
external_id=f"ext-replace-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
)
|
||||
resp = ScimClient.put(f"/Groups/{created['id']}", scim_token, json=updated_resource)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["displayName"] == f"Renamed Group {idp_style}"
|
||||
member_ids = [m["value"] for m in body["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_replace_group_clears_members(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Groups/{id} with empty members removes all memberships."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_clear_m_{idp_style}@example.com", f"ext-gcm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Clear Members Group {idp_style}",
|
||||
external_id=f"ext-clear-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
f"Clear Members Group {idp_style}", f"ext-clear-g-{idp_style}", members=[]
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["members"] == []
|
||||
|
||||
|
||||
def test_patch_add_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=add adds a member."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Add Group {idp_style}",
|
||||
external_id=f"ext-patch-add-{idp_style}",
|
||||
).json()
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_patch_add_{idp_style}@example.com", f"ext-gpa-{idp_style}"
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
member_ids = [m["value"] for m in resp.json()["members"]]
|
||||
assert user["id"] in member_ids
|
||||
|
||||
|
||||
def test_patch_remove_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=remove removes a specific member."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_patch_rm_{idp_style}@example.com", f"ext-gpr-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Remove Group {idp_style}",
|
||||
external_id=f"ext-patch-rm-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "remove",
|
||||
"path": f'members[value eq "{user["id"]}"]',
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["members"] == []
|
||||
|
||||
|
||||
def test_patch_replace_members(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=replace on members swaps the entire list."""
|
||||
user_a = _create_scim_user(
|
||||
scim_token, f"grp_repl_a_{idp_style}@example.com", f"ext-gra-{idp_style}"
|
||||
).json()
|
||||
user_b = _create_scim_user(
|
||||
scim_token, f"grp_repl_b_{idp_style}@example.com", f"ext-grb-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Replace Group {idp_style}",
|
||||
external_id=f"ext-patch-repl-{idp_style}",
|
||||
members=[{"value": user_a["id"]}],
|
||||
).json()
|
||||
|
||||
# Replace member list: swap A for B
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "replace",
|
||||
"path": "members",
|
||||
"value": [{"value": user_b["id"]}],
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
member_ids = [m["value"] for m in resp.json()["members"]]
|
||||
assert user_b["id"] in member_ids
|
||||
assert user_a["id"] not in member_ids
|
||||
|
||||
|
||||
def test_patch_rename_group(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} with op=replace on displayName renames the group."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Old Group Name {idp_style}",
|
||||
external_id=f"ext-rename-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
new_name = f"New Group Name {idp_style}"
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": new_name}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["displayName"] == new_name
|
||||
|
||||
# Confirm via GET
|
||||
get_resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
|
||||
assert get_resp.json()["displayName"] == new_name
|
||||
|
||||
|
||||
def test_delete_group(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Groups/{id} removes the group."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Delete Me Group {idp_style}",
|
||||
external_id=f"ext-del-g-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Second DELETE returns 404 (group hard-deleted)
|
||||
resp2 = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp2.status_code == 404
|
||||
|
||||
|
||||
def test_delete_group_preserves_members(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Groups/{id} removes memberships but does not deactivate users."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_del_member_{idp_style}@example.com", f"ext-gdm-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Delete With Members {idp_style}",
|
||||
external_id=f"ext-del-wm-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# User should still be active and retrievable
|
||||
user_resp = ScimClient.get(f"/Users/{user['id']}", scim_token)
|
||||
assert user_resp.status_code == 200
|
||||
assert user_resp.json()["active"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_group_duplicate_name(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with an already-taken displayName returns 409."""
|
||||
name = f"Dup Name Group {idp_style}"
|
||||
resp1 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g1-{idp_style}")
|
||||
assert resp1.status_code == 201
|
||||
|
||||
resp2 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g2-{idp_style}")
|
||||
assert resp2.status_code == 409
|
||||
|
||||
|
||||
def test_get_nonexistent_group(scim_token: str) -> None:
|
||||
"""GET /Groups/{bad-id} returns 404."""
|
||||
resp = ScimClient.get("/Groups/999999999", scim_token)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_create_group_with_invalid_member(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Groups with a non-existent member UUID returns 400."""
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Bad Member Group {idp_style}",
|
||||
external_id=f"ext-bad-m-{idp_style}",
|
||||
members=[{"value": "00000000-0000-0000-0000-000000000000"}],
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "not found" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_add_nonexistent_member(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} adding a non-existent member returns 400."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Bad Member Group {idp_style}",
|
||||
external_id=f"ext-pbm-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[
|
||||
{
|
||||
"op": "add",
|
||||
"path": "members",
|
||||
"value": [{"value": "00000000-0000-0000-0000-000000000000"}],
|
||||
}
|
||||
],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "not found" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_add_duplicate_member_is_idempotent(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""PATCH /Groups/{id} adding an already-present member succeeds silently."""
|
||||
user = _create_scim_user(
|
||||
scim_token, f"grp_dup_add_{idp_style}@example.com", f"ext-gda-{idp_style}"
|
||||
).json()
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Idempotent Add Group {idp_style}",
|
||||
external_id=f"ext-idem-g-{idp_style}",
|
||||
members=[{"value": user["id"]}],
|
||||
).json()
|
||||
assert len(created["members"]) == 1
|
||||
|
||||
# Add same member again
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["members"]) == 1 # still just one member
|
||||
@@ -15,7 +15,6 @@ import time
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
@@ -40,7 +39,7 @@ def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
|
||||
assert active == token.model_copy(update={"raw_token": None})
|
||||
|
||||
# Token works for SCIM requests
|
||||
response = ScimClient.get("/Users", token.raw_token)
|
||||
response = ScimTokenManager.scim_get("/Users", token.raw_token)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert "Resources" in body
|
||||
@@ -55,7 +54,7 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
)
|
||||
assert first.raw_token is not None
|
||||
|
||||
response = ScimClient.get("/Users", first.raw_token)
|
||||
response = ScimTokenManager.scim_get("/Users", first.raw_token)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Create second token — should revoke first
|
||||
@@ -70,22 +69,25 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
|
||||
assert active == second.model_copy(update={"raw_token": None})
|
||||
|
||||
# First token rejected, second works
|
||||
assert ScimClient.get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimClient.get("/Users", second.raw_token).status_code == 200
|
||||
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
|
||||
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
|
||||
|
||||
|
||||
def test_scim_request_without_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with no Authorization header."""
|
||||
assert ScimClient.get_no_auth("/Users").status_code == 401
|
||||
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
|
||||
|
||||
|
||||
def test_scim_request_with_bad_token_rejected(
|
||||
admin_user: DATestUser, # noqa: ARG001
|
||||
) -> None:
|
||||
"""SCIM endpoints reject requests with an invalid token."""
|
||||
assert ScimClient.get("/Users", "onyx_scim_bogus_token_value").status_code == 401
|
||||
assert (
|
||||
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
|
||||
== 401
|
||||
)
|
||||
|
||||
|
||||
def test_non_admin_cannot_create_token(
|
||||
@@ -137,7 +139,7 @@ def test_service_discovery_no_auth_required(
|
||||
) -> None:
|
||||
"""Service discovery endpoints work without any authentication."""
|
||||
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
|
||||
response = ScimClient.get_no_auth(path)
|
||||
response = ScimTokenManager.scim_get_no_auth(path)
|
||||
assert response.status_code == 200, f"{path} returned {response.status_code}"
|
||||
|
||||
|
||||
@@ -156,7 +158,7 @@ def test_last_used_at_updated_after_scim_request(
|
||||
assert active.last_used_at is None
|
||||
|
||||
# Make a SCIM request, then verify last_used_at is set
|
||||
assert ScimClient.get("/Users", token.raw_token).status_code == 200
|
||||
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
|
||||
time.sleep(0.5)
|
||||
|
||||
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)
|
||||
|
||||
@@ -1,520 +0,0 @@
|
||||
"""Integration tests for SCIM user provisioning endpoints.
|
||||
|
||||
Covers the full user lifecycle as driven by an IdP (Okta / Azure AD):
|
||||
1. Create a user via POST /Users
|
||||
2. Retrieve a user via GET /Users/{id}
|
||||
3. List, filter, and paginate users via GET /Users
|
||||
4. Replace a user via PUT /Users/{id}
|
||||
5. Patch a user (deactivate/reactivate) via PATCH /Users/{id}
|
||||
6. Delete a user via DELETE /Users/{id}
|
||||
7. Error cases: missing externalId, duplicate email, not-found, seat limit
|
||||
|
||||
All tests are parameterized across IdP request styles:
|
||||
- **Okta**: lowercase PATCH ops, minimal payloads (core schema only).
|
||||
- **Entra**: capitalized ops (``"Replace"``), enterprise extension data
|
||||
(department, manager), and structured email arrays.
|
||||
|
||||
The server normalizes both — these tests verify that all IdP-specific fields
|
||||
are accepted and round-tripped correctly.
|
||||
|
||||
Auth, revoked-token, and service-discovery tests live in test_scim_tokens.py.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
import redis
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
SCIM_ENTERPRISE_USER_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
|
||||
)
|
||||
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
|
||||
|
||||
_LICENSE_REDIS_KEY = "public:license:metadata"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["okta", "entra"])
|
||||
def idp_style(request: pytest.FixtureRequest) -> str:
|
||||
"""Parameterized IdP style — runs every test with both Okta and Entra request formats."""
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def scim_token(idp_style: str) -> str:
|
||||
"""Create a single SCIM token shared across all tests in this module.
|
||||
|
||||
Creating a new token revokes the previous one, so we create exactly once
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
token = ScimTokenManager.create(
|
||||
name=f"scim-user-tests-{idp_style}",
|
||||
user_performing_action=admin,
|
||||
).raw_token
|
||||
assert token is not None
|
||||
return token
|
||||
|
||||
|
||||
def _make_user_resource(
|
||||
email: str,
|
||||
external_id: str,
|
||||
given_name: str = "Test",
|
||||
family_name: str = "User",
|
||||
active: bool = True,
|
||||
idp_style: str = "okta",
|
||||
department: str | None = None,
|
||||
manager_id: str | None = None,
|
||||
) -> dict:
|
||||
"""Build a SCIM UserResource payload appropriate for the IdP style.
|
||||
|
||||
Entra sends richer payloads including enterprise extension data (department,
|
||||
manager), structured email arrays, and the enterprise schema URN. Okta sends
|
||||
minimal payloads with just core user fields.
|
||||
"""
|
||||
resource: dict = {
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"externalId": external_id,
|
||||
"name": {
|
||||
"givenName": given_name,
|
||||
"familyName": family_name,
|
||||
},
|
||||
"active": active,
|
||||
}
|
||||
if idp_style == "entra":
|
||||
dept = department or "Engineering"
|
||||
mgr = manager_id or "mgr-ext-001"
|
||||
resource["schemas"].append(SCIM_ENTERPRISE_USER_SCHEMA)
|
||||
resource[SCIM_ENTERPRISE_USER_SCHEMA] = {
|
||||
"department": dept,
|
||||
"manager": {"value": mgr},
|
||||
}
|
||||
resource["emails"] = [
|
||||
{"value": email, "type": "work", "primary": True},
|
||||
]
|
||||
return resource
|
||||
|
||||
|
||||
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
|
||||
"""Build a SCIM PatchOp payload, applying IdP-specific operation casing.
|
||||
|
||||
Entra sends capitalized operations (e.g. ``"Replace"`` instead of
|
||||
``"replace"``). The server's ``normalize_operation`` validator lowercases
|
||||
them — these tests verify that both casings are accepted.
|
||||
"""
|
||||
cased_operations = []
|
||||
for operation in operations:
|
||||
cased = dict(operation)
|
||||
if idp_style == "entra":
|
||||
cased["op"] = operation["op"].capitalize()
|
||||
cased_operations.append(cased)
|
||||
return {
|
||||
"schemas": [SCIM_PATCH_SCHEMA],
|
||||
"Operations": cased_operations,
|
||||
}
|
||||
|
||||
|
||||
def _create_scim_user(
|
||||
token: str,
|
||||
email: str,
|
||||
external_id: str,
|
||||
idp_style: str = "okta",
|
||||
) -> requests.Response:
|
||||
return ScimClient.post(
|
||||
"/Users",
|
||||
token,
|
||||
json=_make_user_resource(email, external_id, idp_style=idp_style),
|
||||
)
|
||||
|
||||
|
||||
def _assert_entra_extension(
|
||||
body: dict,
|
||||
expected_department: str = "Engineering",
|
||||
expected_manager: str = "mgr-ext-001",
|
||||
) -> None:
|
||||
"""Assert that Entra enterprise extension fields round-tripped correctly."""
|
||||
assert SCIM_ENTERPRISE_USER_SCHEMA in body["schemas"]
|
||||
ext = body[SCIM_ENTERPRISE_USER_SCHEMA]
|
||||
assert ext["department"] == expected_department
|
||||
assert ext["manager"]["value"] == expected_manager
|
||||
|
||||
|
||||
def _assert_entra_emails(body: dict, expected_email: str) -> None:
|
||||
"""Assert that structured email metadata round-tripped correctly."""
|
||||
emails = body["emails"]
|
||||
assert len(emails) >= 1
|
||||
work_email = next(e for e in emails if e.get("type") == "work")
|
||||
assert work_email["value"] == expected_email
|
||||
assert work_email["primary"] is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle: create -> get -> list -> replace -> patch -> delete
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users creates a provisioned user and returns 201."""
|
||||
email = f"scim_create_{idp_style}@example.com"
|
||||
ext_id = f"ext-create-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
|
||||
body = resp.json()
|
||||
assert body["userName"] == email
|
||||
assert body["externalId"] == ext_id
|
||||
assert body["active"] is True
|
||||
assert body["id"] # UUID assigned by server
|
||||
assert body["meta"]["resourceType"] == "User"
|
||||
assert body["name"]["givenName"] == "Test"
|
||||
assert body["name"]["familyName"] == "User"
|
||||
|
||||
if idp_style == "entra":
|
||||
_assert_entra_extension(body)
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
ext_id = f"ext-get-{idp_style}"
|
||||
created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
|
||||
|
||||
resp = ScimClient.get(f"/Users/{created['id']}", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["id"] == created["id"]
|
||||
assert body["userName"] == email
|
||||
assert body["externalId"] == ext_id
|
||||
assert body["name"]["givenName"] == "Test"
|
||||
assert body["name"]["familyName"] == "User"
|
||||
|
||||
if idp_style == "entra":
|
||||
_assert_entra_extension(body)
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_list_users(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users returns a ListResponse containing provisioned users."""
|
||||
email = f"scim_list_{idp_style}@example.com"
|
||||
_create_scim_user(scim_token, email, f"ext-list-{idp_style}", idp_style)
|
||||
|
||||
resp = ScimClient.get("/Users", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] >= 1
|
||||
emails = [r["userName"] for r in body["Resources"]]
|
||||
assert email in emails
|
||||
|
||||
|
||||
def test_list_users_pagination(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users with startIndex and count returns correct pagination."""
|
||||
_create_scim_user(
|
||||
scim_token,
|
||||
f"scim_page1_{idp_style}@example.com",
|
||||
f"ext-page-1-{idp_style}",
|
||||
idp_style,
|
||||
)
|
||||
_create_scim_user(
|
||||
scim_token,
|
||||
f"scim_page2_{idp_style}@example.com",
|
||||
f"ext-page-2-{idp_style}",
|
||||
idp_style,
|
||||
)
|
||||
|
||||
resp = ScimClient.get("/Users?startIndex=1&count=1", scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["startIndex"] == 1
|
||||
assert body["itemsPerPage"] == 1
|
||||
assert body["totalResults"] >= 2
|
||||
assert len(body["Resources"]) == 1
|
||||
|
||||
|
||||
def test_filter_users_by_username(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users?filter=userName eq '...' returns only matching users."""
|
||||
email = f"scim_filter_{idp_style}@example.com"
|
||||
_create_scim_user(scim_token, email, f"ext-filter-{idp_style}", idp_style)
|
||||
|
||||
resp = ScimClient.get(f'/Users?filter=userName eq "{email}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["userName"] == email
|
||||
|
||||
|
||||
def test_replace_user(scim_token: str, idp_style: str) -> None:
|
||||
"""PUT /Users/{id} replaces the user resource including enterprise fields."""
|
||||
email = f"scim_replace_{idp_style}@example.com"
|
||||
ext_id = f"ext-replace-{idp_style}"
|
||||
created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
|
||||
|
||||
updated_resource = _make_user_resource(
|
||||
email=email,
|
||||
external_id=ext_id,
|
||||
given_name="Updated",
|
||||
family_name="Name",
|
||||
idp_style=idp_style,
|
||||
department="Product",
|
||||
)
|
||||
resp = ScimClient.put(f"/Users/{created['id']}", scim_token, json=updated_resource)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["name"]["givenName"] == "Updated"
|
||||
assert body["name"]["familyName"] == "Name"
|
||||
|
||||
if idp_style == "entra":
|
||||
_assert_entra_extension(body, expected_department="Product")
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_patch_deactivate_user(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Users/{id} with active=false deactivates the user."""
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_deactivate_{idp_style}@example.com",
|
||||
f"ext-deactivate-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
assert created["active"] is True
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": False}], idp_style
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["active"] is False
|
||||
|
||||
# Confirm via GET
|
||||
get_resp = ScimClient.get(f"/Users/{created['id']}", scim_token)
|
||||
assert get_resp.json()["active"] is False
|
||||
|
||||
|
||||
def test_patch_reactivate_user(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH active=true reactivates a previously deactivated user."""
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_reactivate_{idp_style}@example.com",
|
||||
f"ext-reactivate-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
|
||||
# Deactivate
|
||||
deactivate_resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": False}], idp_style
|
||||
),
|
||||
)
|
||||
assert deactivate_resp.status_code == 200
|
||||
assert deactivate_resp.json()["active"] is False
|
||||
|
||||
# Reactivate
|
||||
resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": True}], idp_style
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["active"] is True
|
||||
|
||||
|
||||
def test_delete_user(scim_token: str, idp_style: str) -> None:
|
||||
"""DELETE /Users/{id} deactivates and removes the SCIM mapping."""
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_delete_{idp_style}@example.com",
|
||||
f"ext-delete-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
|
||||
resp = ScimClient.delete(f"/Users/{created['id']}", scim_token)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Second DELETE returns 404 per RFC 7644 §3.6 (mapping removed)
|
||||
resp2 = ScimClient.delete(f"/Users/{created['id']}", scim_token)
|
||||
assert resp2.status_code == 404
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Error cases
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_create_user_missing_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users without externalId succeeds (RFC 7643: externalId is optional)."""
|
||||
email = f"scim_no_extid_{idp_style}@example.com"
|
||||
resp = ScimClient.post(
|
||||
"/Users",
|
||||
scim_token,
|
||||
json={
|
||||
"schemas": [SCIM_USER_SCHEMA],
|
||||
"userName": email,
|
||||
"active": True,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
body = resp.json()
|
||||
assert body["userName"] == email
|
||||
assert body.get("externalId") is None
|
||||
|
||||
|
||||
def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users with an already-taken email returns 409."""
|
||||
email = f"scim_dup_{idp_style}@example.com"
|
||||
resp1 = _create_scim_user(scim_token, email, f"ext-dup-1-{idp_style}", idp_style)
|
||||
assert resp1.status_code == 201
|
||||
|
||||
resp2 = _create_scim_user(scim_token, email, f"ext-dup-2-{idp_style}", idp_style)
|
||||
assert resp2.status_code == 409
|
||||
|
||||
|
||||
def test_get_nonexistent_user(scim_token: str) -> None:
|
||||
"""GET /Users/{bad-id} returns 404."""
|
||||
resp = ScimClient.get("/Users/00000000-0000-0000-0000-000000000000", scim_token)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_filter_users_by_external_id(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users?filter=externalId eq '...' returns the matching user."""
|
||||
ext_id = f"ext-unique-filter-id-{idp_style}"
|
||||
_create_scim_user(
|
||||
scim_token, f"scim_extfilter_{idp_style}@example.com", ext_id, idp_style
|
||||
)
|
||||
|
||||
resp = ScimClient.get(f'/Users?filter=externalId eq "{ext_id}"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
|
||||
body = resp.json()
|
||||
assert body["totalResults"] == 1
|
||||
assert body["Resources"][0]["externalId"] == ext_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Seat-limit enforcement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _seed_license(r: redis.Redis, seats: int) -> None:
|
||||
"""Write a LicenseMetadata entry into Redis with the given seat cap."""
|
||||
now = datetime.now(timezone.utc)
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id="public",
|
||||
organization_name="Test Org",
|
||||
seats=seats,
|
||||
used_seats=0, # check_seat_availability recalculates from DB
|
||||
plan_type=PlanType.ANNUAL,
|
||||
issued_at=now,
|
||||
expires_at=now + timedelta(days=365),
|
||||
status=ApplicationStatus.ACTIVE,
|
||||
source=LicenseSource.MANUAL_UPLOAD,
|
||||
)
|
||||
r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
|
||||
|
||||
|
||||
def test_create_user_seat_limit(scim_token: str, idp_style: str) -> None:
|
||||
"""POST /Users returns 403 when the seat limit is reached."""
|
||||
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
|
||||
|
||||
# admin_user already occupies 1 seat; cap at 1 -> full
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
resp = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_blocked_{idp_style}@example.com",
|
||||
f"ext-blocked-{idp_style}",
|
||||
idp_style,
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "seat" in resp.json()["detail"].lower()
|
||||
finally:
|
||||
r.delete(_LICENSE_REDIS_KEY)
|
||||
|
||||
|
||||
def test_reactivate_user_seat_limit(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH active=true returns 403 when the seat limit is reached."""
|
||||
# Create and deactivate a user (before license is seeded)
|
||||
created = _create_scim_user(
|
||||
scim_token,
|
||||
f"scim_reactivate_blocked_{idp_style}@example.com",
|
||||
f"ext-reactivate-blocked-{idp_style}",
|
||||
idp_style,
|
||||
).json()
|
||||
assert created["active"] is True
|
||||
|
||||
deactivate_resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": False}], idp_style
|
||||
),
|
||||
)
|
||||
assert deactivate_resp.status_code == 200
|
||||
assert deactivate_resp.json()["active"] is False
|
||||
|
||||
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
|
||||
|
||||
# Seed license capped at current active users -> reactivation should fail
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
resp = ScimClient.patch(
|
||||
f"/Users/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "active", "value": True}], idp_style
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "seat" in resp.json()["detail"].lower()
|
||||
finally:
|
||||
r.delete(_LICENSE_REDIS_KEY)
|
||||
@@ -1,20 +1,11 @@
|
||||
"""Tests for license database CRUD operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from ee.onyx.db.license import delete_license
|
||||
from ee.onyx.db.license import get_license
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.db.models import License
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class TestGetLicense:
|
||||
@@ -109,108 +100,3 @@ class TestDeleteLicense:
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
def _make_license_metadata(seats: int = 10) -> LicenseMetadata:
|
||||
now = datetime.now(timezone.utc)
|
||||
return LicenseMetadata(
|
||||
tenant_id="public",
|
||||
seats=seats,
|
||||
used_seats=0,
|
||||
plan_type=PlanType.ANNUAL,
|
||||
issued_at=now,
|
||||
expires_at=now + timedelta(days=365),
|
||||
status=ApplicationStatus.ACTIVE,
|
||||
source=LicenseSource.MANUAL_UPLOAD,
|
||||
)
|
||||
|
||||
|
||||
class TestCheckSeatAvailabilitySelfHosted:
|
||||
"""Seat checks for self-hosted (MULTI_TENANT=False)."""
|
||||
|
||||
@patch("ee.onyx.db.license.get_license_metadata", return_value=None)
|
||||
def test_no_license_means_unlimited(self, _mock_meta: MagicMock) -> None:
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is True
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=5)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_available(self, mock_meta: MagicMock, _mock_used: MagicMock) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is True
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=10)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_full_blocks_creation(
|
||||
self, mock_meta: MagicMock, _mock_used: MagicMock
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
assert "10 of 10" in result.error_message
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=10)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_exactly_at_capacity_allows_no_more(
|
||||
self, mock_meta: MagicMock, _mock_used: MagicMock
|
||||
) -> None:
|
||||
"""Filling to 100% is allowed; exceeding is not."""
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is False
|
||||
|
||||
@patch("ee.onyx.db.license.get_used_seats", return_value=9)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_filling_to_capacity_is_allowed(
|
||||
self, mock_meta: MagicMock, _mock_used: MagicMock
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(MagicMock(), seats_needed=1)
|
||||
assert result.available is True
|
||||
|
||||
|
||||
class TestCheckSeatAvailabilityMultiTenant:
|
||||
"""Seat checks for multi-tenant cloud (MULTI_TENANT=True).
|
||||
|
||||
Verifies that get_used_seats takes the MULTI_TENANT branch
|
||||
and delegates to get_tenant_count.
|
||||
"""
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", True)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.user_mapping.get_tenant_count",
|
||||
return_value=5,
|
||||
)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_available_multi_tenant(
|
||||
self,
|
||||
mock_meta: MagicMock,
|
||||
mock_tenant_count: MagicMock,
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(
|
||||
MagicMock(), seats_needed=1, tenant_id="tenant-abc"
|
||||
)
|
||||
assert result.available is True
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
@patch("ee.onyx.db.license.MULTI_TENANT", True)
|
||||
@patch(
|
||||
"ee.onyx.server.tenants.user_mapping.get_tenant_count",
|
||||
return_value=10,
|
||||
)
|
||||
@patch("ee.onyx.db.license.get_license_metadata")
|
||||
def test_seats_full_multi_tenant(
|
||||
self,
|
||||
mock_meta: MagicMock,
|
||||
mock_tenant_count: MagicMock,
|
||||
) -> None:
|
||||
mock_meta.return_value = _make_license_metadata(seats=10)
|
||||
result = check_seat_availability(
|
||||
MagicMock(), seats_needed=1, tenant_id="tenant-abc"
|
||||
)
|
||||
assert result.available is False
|
||||
assert result.error_message is not None
|
||||
mock_tenant_count.assert_called_once_with("tenant-abc")
|
||||
|
||||
@@ -12,11 +12,12 @@ import aiohttp
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.discord.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.discord.exceptions import APIResponseError
|
||||
from onyx.onyxbot.discord.exceptions import APITimeoutError
|
||||
from onyx.onyxbot.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.exceptions import APIResponseError
|
||||
from onyx.onyxbot.exceptions import APITimeoutError
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
|
||||
|
||||
class MockAsyncContextManager:
|
||||
@@ -43,7 +44,7 @@ class TestClientLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_creates_session(self) -> None:
|
||||
"""initialize() creates aiohttp session."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
assert client._session is None
|
||||
|
||||
with patch("aiohttp.ClientSession") as mock_session_class:
|
||||
@@ -57,13 +58,13 @@ class TestClientLifecycle:
|
||||
|
||||
def test_is_initialized_before_init(self) -> None:
|
||||
"""is_initialized returns False before initialize()."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
assert client.is_initialized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_initialized_after_init(self) -> None:
|
||||
"""is_initialized returns True after initialize()."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
with patch("aiohttp.ClientSession"):
|
||||
await client.initialize()
|
||||
@@ -73,7 +74,7 @@ class TestClientLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_closes_session(self) -> None:
|
||||
"""close() closes session and resets is_initialized."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
@@ -88,7 +89,7 @@ class TestClientLifecycle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_not_initialized(self) -> None:
|
||||
"""send_chat_message() before initialize() raises APIConnectionError."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
with pytest.raises(APIConnectionError) as exc_info:
|
||||
await client.send_chat_message("test", "api_key")
|
||||
@@ -102,7 +103,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_success(self) -> None:
|
||||
"""Valid request returns ChatFullResponse."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
response_data = {
|
||||
"answer": "Test response",
|
||||
@@ -133,7 +134,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_with_persona(self) -> None:
|
||||
"""persona_id is passed to API."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
response_data = {"answer": "Response", "citations": [], "error_msg": None}
|
||||
|
||||
@@ -164,7 +165,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_401_error(self) -> None:
|
||||
"""Invalid API key returns APIResponseError with 401."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 401
|
||||
@@ -184,7 +185,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_403_error(self) -> None:
|
||||
"""Persona not accessible returns APIResponseError with 403."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 403
|
||||
@@ -204,7 +205,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_timeout(self) -> None:
|
||||
"""Request timeout raises APITimeoutError."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
@@ -221,7 +222,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_connection_error(self) -> None:
|
||||
"""Network failure raises APIConnectionError."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
@@ -240,7 +241,7 @@ class TestSendChatMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_server_error(self) -> None:
|
||||
"""500 response raises APIResponseError with 500."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
@@ -265,7 +266,7 @@ class TestHealthCheck:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self) -> None:
|
||||
"""Server healthy returns True."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
@@ -283,7 +284,7 @@ class TestHealthCheck:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_failure(self) -> None:
|
||||
"""Server unhealthy returns False."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 503
|
||||
@@ -301,7 +302,7 @@ class TestHealthCheck:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_timeout(self) -> None:
|
||||
"""Request times out returns False."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(
|
||||
@@ -318,7 +319,7 @@ class TestHealthCheck:
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_not_initialized(self) -> None:
|
||||
"""Health check before initialize returns False."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
result = await client.health_check()
|
||||
assert result is False
|
||||
@@ -330,7 +331,7 @@ class TestResponseParsing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_malformed_json(self) -> None:
|
||||
"""API returns invalid JSON raises exception."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
@@ -349,7 +350,7 @@ class TestResponseParsing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_with_error_msg(self) -> None:
|
||||
"""200 status but error_msg present - warning logged, response returned."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
response_data = {
|
||||
"answer": "Partial response",
|
||||
@@ -381,7 +382,7 @@ class TestResponseParsing:
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_empty_answer(self) -> None:
|
||||
"""answer field is empty string - handled gracefully."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
response_data = {
|
||||
"answer": "",
|
||||
@@ -416,18 +417,18 @@ class TestClientConfiguration:
|
||||
|
||||
def test_default_timeout(self) -> None:
|
||||
"""Client uses API_REQUEST_TIMEOUT by default."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
assert client._timeout == API_REQUEST_TIMEOUT
|
||||
|
||||
def test_custom_timeout(self) -> None:
|
||||
"""Client accepts custom timeout."""
|
||||
client = OnyxAPIClient(timeout=60)
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT, timeout=60)
|
||||
assert client._timeout == 60
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_initialize_warning(self) -> None:
|
||||
"""Calling initialize() twice logs warning but doesn't error."""
|
||||
client = OnyxAPIClient()
|
||||
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
|
||||
|
||||
with patch("aiohttp.ClientSession") as mock_session_class:
|
||||
mock_session = MagicMock()
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestCacheInitialization:
|
||||
def test_cache_starts_empty(self) -> None:
|
||||
"""New cache manager has empty caches."""
|
||||
cache = DiscordCacheManager()
|
||||
assert cache._guild_tenants == {}
|
||||
assert cache._entity_tenants == {}
|
||||
assert cache._api_keys == {}
|
||||
assert cache.is_initialized is False
|
||||
|
||||
@@ -37,14 +37,14 @@ class TestCacheInitialization:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config1, mock_config2],
|
||||
@@ -61,10 +61,10 @@ class TestCacheInitialization:
|
||||
await cache.refresh_all()
|
||||
|
||||
assert cache.is_initialized is True
|
||||
assert 111111 in cache._guild_tenants
|
||||
assert 222222 in cache._guild_tenants
|
||||
assert cache._guild_tenants[111111] == "tenant1"
|
||||
assert cache._guild_tenants[222222] == "tenant1"
|
||||
assert 111111 in cache._entity_tenants
|
||||
assert 222222 in cache._entity_tenants
|
||||
assert cache._entity_tenants[111111] == "tenant1"
|
||||
assert cache._entity_tenants[222222] == "tenant1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refresh_provisions_api_key(self) -> None:
|
||||
@@ -77,14 +77,14 @@ class TestCacheInitialization:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
@@ -110,7 +110,7 @@ class TestCacheLookups:
|
||||
def test_get_tenant_returns_correct(self) -> None:
|
||||
"""Lookup registered guild returns correct tenant ID."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[123456] = "tenant1"
|
||||
cache._entity_tenants[123456] = "tenant1"
|
||||
|
||||
result = cache.get_tenant(123456)
|
||||
assert result == "tenant1"
|
||||
@@ -140,7 +140,7 @@ class TestCacheLookups:
|
||||
def test_get_all_guild_ids(self) -> None:
|
||||
"""After loading returns all cached guild IDs."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants = {111: "t1", 222: "t2", 333: "t1"}
|
||||
cache._entity_tenants = {111: "t1", 222: "t2", 333: "t1"}
|
||||
|
||||
result = cache.get_all_guild_ids()
|
||||
assert set(result) == {111, 222, 333}
|
||||
@@ -159,7 +159,7 @@ class TestCacheUpdates:
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
@@ -187,7 +187,7 @@ class TestCacheUpdates:
|
||||
mock_config.enabled = False # Disabled!
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
@@ -205,7 +205,7 @@ class TestCacheUpdates:
|
||||
def test_remove_guild(self) -> None:
|
||||
"""remove_guild() removes guild from cache."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[111111] = "tenant1"
|
||||
cache._entity_tenants[111111] = "tenant1"
|
||||
|
||||
cache.remove_guild(111111)
|
||||
|
||||
@@ -214,13 +214,13 @@ class TestCacheUpdates:
|
||||
def test_clear_removes_all(self) -> None:
|
||||
"""clear() empties all caches."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants = {111: "t1", 222: "t2"}
|
||||
cache._entity_tenants = {111: "t1", 222: "t2"}
|
||||
cache._api_keys = {"t1": "key1", "t2": "key2"}
|
||||
cache._initialized = True
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache._guild_tenants == {}
|
||||
assert cache._entity_tenants == {}
|
||||
assert cache._api_keys == {}
|
||||
assert cache.is_initialized is False
|
||||
|
||||
@@ -239,7 +239,7 @@ class TestThreadSafety:
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def slow_refresh() -> tuple[list[int], str]:
|
||||
async def slow_refresh(_tenant_id: str) -> tuple[list[int], str]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Simulate slow operation
|
||||
@@ -248,11 +248,11 @@ class TestThreadSafety:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch.object(cache, "_load_tenant_data", side_effect=slow_refresh),
|
||||
@@ -271,7 +271,7 @@ class TestThreadSafety:
|
||||
async def test_concurrent_read_write(self) -> None:
|
||||
"""Read during refresh doesn't cause exceptions."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[111111] = "tenant1"
|
||||
cache._entity_tenants[111111] = "tenant1"
|
||||
|
||||
async def read_loop() -> None:
|
||||
for _ in range(10):
|
||||
@@ -280,7 +280,7 @@ class TestThreadSafety:
|
||||
|
||||
async def write_loop() -> None:
|
||||
for i in range(10):
|
||||
cache._guild_tenants[200000 + i] = f"tenant{i}"
|
||||
cache._entity_tenants[200000 + i] = f"tenant{i}"
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
# Should not raise any exceptions
|
||||
@@ -301,14 +301,14 @@ class TestAPIKeyProvisioning:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
@@ -339,14 +339,14 @@ class TestAPIKeyProvisioning:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
@@ -392,14 +392,14 @@ class TestGatedTenantHandling:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1", "tenant2"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: gated_tenants,
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
side_effect=mock_get_configs,
|
||||
@@ -416,9 +416,9 @@ class TestGatedTenantHandling:
|
||||
await cache.refresh_all()
|
||||
|
||||
# Only tenant1 should be loaded (tenant2 is gated)
|
||||
assert "tenant1" in cache._api_keys and 111111 in cache._guild_tenants
|
||||
assert "tenant1" in cache._api_keys and 111111 in cache._entity_tenants
|
||||
# tenant2's guilds should NOT be in cache
|
||||
assert "tenant2" not in cache._api_keys and 222222 not in cache._guild_tenants
|
||||
assert "tenant2" not in cache._api_keys and 222222 not in cache._entity_tenants
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gated_check_calls_ee_function(self) -> None:
|
||||
@@ -427,14 +427,14 @@ class TestGatedTenantHandling:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
) as mock_ee,
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[],
|
||||
@@ -459,14 +459,14 @@ class TestGatedTenantHandling:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(), # No gated tenants
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
@@ -499,16 +499,16 @@ class TestCacheErrorHandling:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if tenant_id == "tenant1":
|
||||
raise Exception("Tenant 1 error")
|
||||
raise ConnectionError("Tenant 1 connection failed")
|
||||
return ([222222], "api_key")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1", "tenant2"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch.object(cache, "_load_tenant_data", side_effect=mock_load),
|
||||
|
||||
0
backend/tests/unit/onyx/onyxbot/teams/__init__.py
Normal file
0
backend/tests/unit/onyx/onyxbot/teams/__init__.py
Normal file
105
backend/tests/unit/onyx/onyxbot/teams/conftest.py
Normal file
105
backend/tests/unit/onyx/onyxbot/teams/conftest.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Fixtures for Teams bot unit tests."""
|
||||
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_team_config_enabled() -> MagicMock:
|
||||
"""Team config that is enabled."""
|
||||
config = MagicMock()
|
||||
config.id = 1
|
||||
config.team_id = "team-abc-123"
|
||||
config.enabled = True
|
||||
config.default_persona_id = 1
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_team_config_disabled() -> MagicMock:
|
||||
"""Team config that is disabled."""
|
||||
config = MagicMock()
|
||||
config.id = 2
|
||||
config.team_id = "team-abc-123"
|
||||
config.enabled = False
|
||||
config.default_persona_id = None
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_channel_config_factory() -> Callable[..., MagicMock]:
|
||||
"""Factory fixture for creating channel configs with various settings."""
|
||||
|
||||
def _make_config(
|
||||
enabled: bool = True,
|
||||
require_bot_mention: bool = True,
|
||||
persona_override_id: int | None = None,
|
||||
) -> MagicMock:
|
||||
config = MagicMock()
|
||||
config.id = random.randint(1, 1000)
|
||||
config.channel_id = "19:channel-xyz@thread.tacv2"
|
||||
config.enabled = enabled
|
||||
config.require_bot_mention = require_bot_mention
|
||||
config.persona_override_id = persona_override_id
|
||||
return config
|
||||
|
||||
return _make_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_activity_dict() -> dict:
|
||||
"""Sample Teams Activity as a dict."""
|
||||
return {
|
||||
"type": "message",
|
||||
"text": "<at>Onyx</at> What is our deployment process?",
|
||||
"from": {
|
||||
"id": "29:user-id-123",
|
||||
"name": "Test User",
|
||||
},
|
||||
"recipient": {
|
||||
"id": "28:bot-id-456",
|
||||
"name": "Onyx",
|
||||
},
|
||||
"channelData": {
|
||||
"team": {
|
||||
"id": "team-abc-123",
|
||||
"name": "Engineering",
|
||||
},
|
||||
"channel": {
|
||||
"id": "19:channel-xyz@thread.tacv2",
|
||||
"name": "general",
|
||||
},
|
||||
},
|
||||
"entities": [
|
||||
{
|
||||
"type": "mention",
|
||||
"mentioned": {
|
||||
"id": "28:bot-id-456",
|
||||
"name": "Onyx",
|
||||
},
|
||||
"text": "<at>Onyx</at>",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dm_activity_dict() -> dict:
|
||||
"""Sample Teams DM Activity (no team context)."""
|
||||
return {
|
||||
"type": "message",
|
||||
"text": "Hello bot",
|
||||
"from": {
|
||||
"id": "29:user-id-123",
|
||||
"name": "Test User",
|
||||
},
|
||||
"recipient": {
|
||||
"id": "28:bot-id-456",
|
||||
"name": "Onyx",
|
||||
},
|
||||
"channelData": {},
|
||||
"entities": [],
|
||||
}
|
||||
214
backend/tests/unit/onyx/onyxbot/teams/test_cache_manager.py
Normal file
214
backend/tests/unit/onyx/onyxbot/teams/test_cache_manager.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Unit tests for Teams bot cache manager."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.onyxbot.teams.cache import TeamsCacheManager
|
||||
|
||||
|
||||
class TestCacheInitialization:
|
||||
"""Tests for cache initialization."""
|
||||
|
||||
def test_cache_starts_empty(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
assert cache._entity_tenants == {}
|
||||
assert cache._api_keys == {}
|
||||
assert cache.is_initialized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refresh_all_loads_teams(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
|
||||
mock_config1 = MagicMock()
|
||||
mock_config1.team_id = "team-111"
|
||||
mock_config1.enabled = True
|
||||
|
||||
mock_config2 = MagicMock()
|
||||
mock_config2.team_id = "team-222"
|
||||
mock_config2.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.get_team_configs",
|
||||
return_value=[mock_config1, mock_config2],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.provision_teams_service_api_key",
|
||||
return_value="test_api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
assert cache.is_initialized is True
|
||||
assert "team-111" in cache._entity_tenants
|
||||
assert "team-222" in cache._entity_tenants
|
||||
assert cache._entity_tenants["team-111"] == "tenant1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refresh_provisions_api_key(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.team_id = "team-111"
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.get_team_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.provision_teams_service_api_key",
|
||||
return_value="new_api_key",
|
||||
) as mock_provision,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
assert cache._api_keys.get("tenant1") == "new_api_key"
|
||||
mock_provision.assert_called()
|
||||
|
||||
|
||||
class TestCacheLookups:
|
||||
"""Tests for cache lookup operations."""
|
||||
|
||||
def test_get_tenant_returns_correct(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
cache._entity_tenants["team-123"] = "tenant1"
|
||||
assert cache.get_tenant("team-123") == "tenant1"
|
||||
|
||||
def test_get_tenant_returns_none_unknown(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
assert cache.get_tenant("unknown-team") is None
|
||||
|
||||
def test_get_api_key_returns_correct(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
cache._api_keys["tenant1"] = "api_key_123"
|
||||
assert cache.get_api_key("tenant1") == "api_key_123"
|
||||
|
||||
def test_get_api_key_returns_none_unknown(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
assert cache.get_api_key("unknown_tenant") is None
|
||||
|
||||
def test_get_all_team_ids(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
cache._entity_tenants = {"t1": "tenant1", "t2": "tenant2", "t3": "tenant1"}
|
||||
result = cache.get_all_team_ids()
|
||||
assert set(result) == {"t1", "t2", "t3"}
|
||||
|
||||
|
||||
class TestCacheUpdates:
|
||||
"""Tests for cache update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_team_adds_new(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.team_id = "team-111"
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.get_team_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.provision_teams_service_api_key",
|
||||
return_value="api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_team("team-111", "tenant1")
|
||||
|
||||
assert cache.get_tenant("team-111") == "tenant1"
|
||||
|
||||
def test_remove_team(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
cache._entity_tenants["team-111"] = "tenant1"
|
||||
cache.remove_team("team-111")
|
||||
assert cache.get_tenant("team-111") is None
|
||||
|
||||
def test_clear_removes_all(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
cache._entity_tenants = {"t1": "tenant1", "t2": "tenant2"}
|
||||
cache._api_keys = {"tenant1": "key1", "tenant2": "key2"}
|
||||
cache._initialized = True
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache._entity_tenants == {}
|
||||
assert cache._api_keys == {}
|
||||
assert cache.is_initialized is False
|
||||
|
||||
|
||||
class TestGatedTenantHandling:
|
||||
"""Tests for gated tenant filtering."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_skips_gated_tenants(self) -> None:
|
||||
cache = TeamsCacheManager()
|
||||
gated_tenants = {"tenant2"}
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.team_id = "team-111"
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1", "tenant2"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: gated_tenants,
|
||||
),
|
||||
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.get_team_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.cache.provision_teams_service_api_key",
|
||||
return_value="api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
assert "tenant1" in cache._api_keys
|
||||
assert "tenant2" not in cache._api_keys
|
||||
86
backend/tests/unit/onyx/onyxbot/teams/test_cards.py
Normal file
86
backend/tests/unit/onyx/onyxbot/teams/test_cards.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Unit tests for Teams bot Adaptive Card builders."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.onyxbot.teams.cards import build_answer_card
|
||||
from onyx.onyxbot.teams.cards import build_error_card
|
||||
from onyx.onyxbot.teams.cards import build_welcome_card
|
||||
|
||||
|
||||
class TestBuildAnswerCard:
|
||||
"""Tests for answer card generation."""
|
||||
|
||||
def test_basic_answer(self) -> None:
|
||||
card = build_answer_card("Hello world")
|
||||
assert card["type"] == "AdaptiveCard"
|
||||
assert card["version"] == "1.3"
|
||||
assert len(card["body"]) == 1
|
||||
assert card["body"][0]["text"] == "Hello world"
|
||||
|
||||
def test_answer_with_citations(self) -> None:
|
||||
mock_response = MagicMock()
|
||||
mock_citation = MagicMock()
|
||||
mock_citation.citation_number = 1
|
||||
mock_citation.document_id = "doc1"
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.document_id = "doc1"
|
||||
mock_doc.semantic_identifier = "Design Doc"
|
||||
mock_doc.link = "https://example.com/doc1"
|
||||
|
||||
mock_response.citation_info = [mock_citation]
|
||||
mock_response.top_documents = [mock_doc]
|
||||
|
||||
card = build_answer_card("Answer text", mock_response)
|
||||
# Body should have: answer + "Sources:" header + citation
|
||||
assert len(card["body"]) == 3
|
||||
assert "Sources" in card["body"][1]["text"]
|
||||
assert "Design Doc" in card["body"][2]["text"]
|
||||
|
||||
def test_answer_no_citations(self) -> None:
|
||||
mock_response = MagicMock()
|
||||
mock_response.citation_info = []
|
||||
mock_response.top_documents = []
|
||||
|
||||
card = build_answer_card("Answer text", mock_response)
|
||||
assert len(card["body"]) == 1
|
||||
|
||||
def test_answer_citation_without_link(self) -> None:
|
||||
mock_response = MagicMock()
|
||||
mock_citation = MagicMock()
|
||||
mock_citation.citation_number = 1
|
||||
mock_citation.document_id = "doc1"
|
||||
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.document_id = "doc1"
|
||||
mock_doc.semantic_identifier = "Internal Doc"
|
||||
mock_doc.link = None
|
||||
|
||||
mock_response.citation_info = [mock_citation]
|
||||
mock_response.top_documents = [mock_doc]
|
||||
|
||||
card = build_answer_card("Answer text", mock_response)
|
||||
assert "Internal Doc" in card["body"][2]["text"]
|
||||
# Should not contain markdown link since link is None
|
||||
assert "http" not in card["body"][2]["text"]
|
||||
|
||||
|
||||
class TestBuildErrorCard:
|
||||
"""Tests for error card generation."""
|
||||
|
||||
def test_error_card(self) -> None:
|
||||
card = build_error_card("Something went wrong")
|
||||
assert card["type"] == "AdaptiveCard"
|
||||
assert card["body"][0]["text"] == "Something went wrong"
|
||||
assert card["body"][0]["color"] == "Attention"
|
||||
|
||||
|
||||
class TestBuildWelcomeCard:
|
||||
"""Tests for welcome card generation."""
|
||||
|
||||
def test_welcome_card(self) -> None:
|
||||
card = build_welcome_card()
|
||||
assert card["type"] == "AdaptiveCard"
|
||||
assert len(card["body"]) == 2
|
||||
assert "Welcome" in card["body"][0]["text"]
|
||||
assert "register" in card["body"][1]["text"]
|
||||
272
backend/tests/unit/onyx/onyxbot/teams/test_should_respond.py
Normal file
272
backend/tests/unit/onyx/onyxbot/teams/test_should_respond.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Unit tests for Teams bot should_respond logic."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.onyxbot.teams.handle_message import should_respond
|
||||
|
||||
|
||||
class TestBasicShouldRespond:
|
||||
"""Tests for basic should_respond decision logic."""
|
||||
|
||||
def test_team_disabled_returns_false(self) -> None:
|
||||
"""Team config enabled=false returns False."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = False
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict={},
|
||||
team_id="team-123",
|
||||
channel_id="channel-456",
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
def test_team_enabled_channel_enabled_no_mention_required(self) -> None:
|
||||
"""Team + channel enabled, require_bot_mention=false returns True."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = True
|
||||
mock_team_config.default_persona_id = 2
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_mention = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict={},
|
||||
team_id="team-123",
|
||||
channel_id="channel-456",
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
assert result.persona_id == 2
|
||||
|
||||
def test_channel_disabled_returns_false(self) -> None:
|
||||
"""Channel config enabled=false returns False."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = True
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = False
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict={},
|
||||
team_id="team-123",
|
||||
channel_id="channel-456",
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
def test_channel_not_found_returns_false(self) -> None:
|
||||
"""No channel config returns False (not whitelisted)."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = True
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict={},
|
||||
team_id="team-123",
|
||||
channel_id="channel-456",
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
def test_require_mention_true_with_mention(
|
||||
self, sample_activity_dict: dict
|
||||
) -> None:
|
||||
"""require_bot_mention=true with @mention returns True."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = True
|
||||
mock_team_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_mention = True
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict=sample_activity_dict,
|
||||
team_id="team-abc-123",
|
||||
channel_id="19:channel-xyz@thread.tacv2",
|
||||
tenant_id="tenant1",
|
||||
bot_id="28:bot-id-456",
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
|
||||
def test_require_mention_true_no_mention(self) -> None:
|
||||
"""require_bot_mention=true without @mention returns False."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = True
|
||||
mock_team_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_mention = True
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
activity_no_mention = {"entities": []}
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict=activity_no_mention,
|
||||
team_id="team-123",
|
||||
channel_id="channel-456",
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
def test_dm_no_team_returns_true(self) -> None:
|
||||
"""DM (no team_id or channel_id) returns True."""
|
||||
result = should_respond(
|
||||
activity_dict={},
|
||||
team_id=None,
|
||||
channel_id=None,
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
assert result.should_respond is True
|
||||
|
||||
def test_persona_override_takes_priority(self) -> None:
|
||||
"""Channel persona override takes priority over team default."""
|
||||
mock_team_config = MagicMock()
|
||||
mock_team_config.enabled = True
|
||||
mock_team_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_mention = False
|
||||
mock_channel_config.persona_override_id = 5
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
|
||||
return_value=mock_team_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = should_respond(
|
||||
activity_dict={},
|
||||
team_id="team-123",
|
||||
channel_id="channel-456",
|
||||
tenant_id="tenant1",
|
||||
bot_id="bot-id",
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
assert result.persona_id == 5
|
||||
104
backend/tests/unit/onyx/onyxbot/teams/test_utils.py
Normal file
104
backend/tests/unit/onyx/onyxbot/teams/test_utils.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Unit tests for Teams bot utility functions."""
|
||||
|
||||
from onyx.onyxbot.teams.utils import extract_channel_id
|
||||
from onyx.onyxbot.teams.utils import extract_team_id
|
||||
from onyx.onyxbot.teams.utils import extract_team_name
|
||||
from onyx.onyxbot.teams.utils import is_bot_mentioned
|
||||
from onyx.onyxbot.teams.utils import strip_bot_mention
|
||||
from onyx.server.manage.teams_bot.utils import generate_teams_registration_key
|
||||
from onyx.server.manage.teams_bot.utils import parse_teams_registration_key
|
||||
|
||||
|
||||
class TestExtractIds:
|
||||
"""Tests for ID extraction from Activity dicts."""
|
||||
|
||||
def test_extract_team_id_present(self, sample_activity_dict: dict) -> None:
|
||||
assert extract_team_id(sample_activity_dict) == "team-abc-123"
|
||||
|
||||
def test_extract_team_id_missing(self, sample_dm_activity_dict: dict) -> None:
|
||||
assert extract_team_id(sample_dm_activity_dict) is None
|
||||
|
||||
def test_extract_channel_id_present(self, sample_activity_dict: dict) -> None:
|
||||
assert extract_channel_id(sample_activity_dict) == "19:channel-xyz@thread.tacv2"
|
||||
|
||||
def test_extract_channel_id_missing(self, sample_dm_activity_dict: dict) -> None:
|
||||
assert extract_channel_id(sample_dm_activity_dict) is None
|
||||
|
||||
def test_extract_team_name(self, sample_activity_dict: dict) -> None:
|
||||
assert extract_team_name(sample_activity_dict) == "Engineering"
|
||||
|
||||
def test_extract_team_name_missing(self, sample_dm_activity_dict: dict) -> None:
|
||||
assert extract_team_name(sample_dm_activity_dict) is None
|
||||
|
||||
|
||||
class TestStripBotMention:
|
||||
"""Tests for bot mention stripping."""
|
||||
|
||||
def test_strip_named_mention(self) -> None:
|
||||
text = "<at>Onyx</at> What is our process?"
|
||||
assert strip_bot_mention(text, "Onyx") == "What is our process?"
|
||||
|
||||
def test_strip_case_insensitive(self) -> None:
|
||||
text = "<at>onyx</at> Hello"
|
||||
assert strip_bot_mention(text, "Onyx") == "Hello"
|
||||
|
||||
def test_strip_no_mention(self) -> None:
|
||||
text = "Just a normal message"
|
||||
assert strip_bot_mention(text, "Onyx") == "Just a normal message"
|
||||
|
||||
def test_strip_multiple_mentions(self) -> None:
|
||||
text = "<at>Onyx</at> hello <at>Onyx</at>"
|
||||
assert strip_bot_mention(text, "Onyx") == "hello"
|
||||
|
||||
def test_strip_empty_result(self) -> None:
|
||||
text = "<at>Onyx</at>"
|
||||
assert strip_bot_mention(text, "Onyx") == ""
|
||||
|
||||
|
||||
class TestIsBotMentioned:
|
||||
"""Tests for bot mention detection."""
|
||||
|
||||
def test_bot_mentioned(self, sample_activity_dict: dict) -> None:
|
||||
assert is_bot_mentioned(sample_activity_dict, "28:bot-id-456") is True
|
||||
|
||||
def test_bot_not_mentioned(self, sample_dm_activity_dict: dict) -> None:
|
||||
assert is_bot_mentioned(sample_dm_activity_dict, "28:bot-id-456") is False
|
||||
|
||||
def test_different_bot_mentioned(self, sample_activity_dict: dict) -> None:
|
||||
assert is_bot_mentioned(sample_activity_dict, "other-bot-id") is False
|
||||
|
||||
def test_no_entities(self) -> None:
|
||||
activity = {"entities": []}
|
||||
assert is_bot_mentioned(activity, "any-id") is False
|
||||
|
||||
|
||||
class TestRegistrationKeys:
|
||||
"""Tests for registration key generation and parsing."""
|
||||
|
||||
def test_generate_and_parse_roundtrip(self) -> None:
|
||||
key = generate_teams_registration_key("tenant1")
|
||||
parsed = parse_teams_registration_key(key)
|
||||
assert parsed == "tenant1"
|
||||
|
||||
def test_generate_has_correct_prefix(self) -> None:
|
||||
key = generate_teams_registration_key("tenant1")
|
||||
assert key.startswith("teams_")
|
||||
|
||||
def test_parse_invalid_prefix(self) -> None:
|
||||
assert parse_teams_registration_key("discord_tenant1.token") is None
|
||||
|
||||
def test_parse_no_separator(self) -> None:
|
||||
assert parse_teams_registration_key("teams_noseparator") is None
|
||||
|
||||
def test_parse_empty_string(self) -> None:
|
||||
assert parse_teams_registration_key("") is None
|
||||
|
||||
def test_generate_url_encodes_tenant(self) -> None:
|
||||
key = generate_teams_registration_key("tenant with spaces")
|
||||
parsed = parse_teams_registration_key(key)
|
||||
assert parsed == "tenant with spaces"
|
||||
|
||||
def test_generate_unique_keys(self) -> None:
|
||||
key1 = generate_teams_registration_key("tenant1")
|
||||
key2 = generate_teams_registration_key("tenant1")
|
||||
assert key1 != key2
|
||||
@@ -19,7 +19,6 @@ from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.entra import EntraProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
@@ -27,10 +26,6 @@ from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
# Every supported SCIM provider must appear here so that all endpoint tests
|
||||
# run against it. When adding a new provider, add its class to this list.
|
||||
SCIM_PROVIDERS: list[type[ScimProvider]] = [OktaProvider, EntraProvider]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session() -> MagicMock:
|
||||
@@ -46,10 +41,10 @@ def mock_token() -> MagicMock:
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture(params=SCIM_PROVIDERS, ids=[p.__name__ for p in SCIM_PROVIDERS])
|
||||
def provider(request: pytest.FixtureRequest) -> ScimProvider:
|
||||
"""Parameterized provider — runs each test with every provider in SCIM_PROVIDERS."""
|
||||
return request.param()
|
||||
@pytest.fixture
|
||||
def provider() -> ScimProvider:
|
||||
"""An OktaProvider instance for endpoint tests."""
|
||||
return OktaProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.scim.auth import _hash_scim_token
|
||||
from ee.onyx.server.scim.auth import generate_scim_token
|
||||
from ee.onyx.server.scim.auth import SCIM_TOKEN_PREFIX
|
||||
from ee.onyx.server.scim.auth import ScimAuthError
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class TestVerifyScimToken:
|
||||
def test_missing_header_raises_401(self) -> None:
|
||||
request = self._make_request(None)
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Missing" in str(exc_info.value.detail)
|
||||
@@ -68,7 +68,7 @@ class TestVerifyScimToken:
|
||||
def test_wrong_prefix_raises_401(self) -> None:
|
||||
request = self._make_request("Bearer on_some_api_key")
|
||||
dal = self._make_dal()
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@@ -76,7 +76,7 @@ class TestVerifyScimToken:
|
||||
raw, _, _ = generate_scim_token()
|
||||
request = self._make_request(f"Bearer {raw}")
|
||||
dal = self._make_dal(token=None)
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid" in str(exc_info.value.detail)
|
||||
@@ -87,7 +87,7 @@ class TestVerifyScimToken:
|
||||
mock_token = MagicMock()
|
||||
mock_token.is_active = False
|
||||
dal = self._make_dal(token=mock_token)
|
||||
with pytest.raises(ScimAuthError) as exc_info:
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
verify_scim_token(request, dal)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail)
|
||||
|
||||
@@ -109,7 +109,7 @@ class TestOktaProvider:
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Madonna", familyName="", formatted="Madonna"
|
||||
givenName="Madonna", familyName=None, formatted="Madonna"
|
||||
)
|
||||
|
||||
def test_build_user_resource_no_name(self) -> None:
|
||||
@@ -117,7 +117,7 @@ class TestOktaProvider:
|
||||
user = _make_mock_user(personal_name=None)
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(givenName="", familyName="", formatted="")
|
||||
assert result.name is None
|
||||
assert result.displayName is None
|
||||
|
||||
def test_build_user_resource_scim_username_preserves_case(self) -> None:
|
||||
|
||||
@@ -214,16 +214,13 @@ class TestCreateUser:
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.commit.assert_called_once()
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_missing_external_id_creates_user_without_mapping(
|
||||
def test_missing_external_id_returns_400(
|
||||
self,
|
||||
mock_seats: MagicMock, # noqa: ARG002
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
result = create_user(
|
||||
@@ -233,11 +230,7 @@ class TestCreateUser:
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
parsed = parse_scim_user(result, status=201)
|
||||
assert parsed.userName is not None
|
||||
mock_dal.add_user.assert_called_once()
|
||||
mock_dal.create_user_mapping.assert_not_called()
|
||||
mock_dal.commit.assert_called_once()
|
||||
assert_scim_error(result, 400)
|
||||
|
||||
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
|
||||
def test_duplicate_email_returns_409(
|
||||
|
||||
@@ -126,9 +126,7 @@ Resources:
|
||||
- Effect: Allow
|
||||
Action:
|
||||
- secretsmanager:GetSecretValue
|
||||
Resource:
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
|
||||
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
|
||||
|
||||
Outputs:
|
||||
OutputEcsCluster:
|
||||
|
||||
@@ -167,12 +167,10 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: basic
|
||||
Value: disabled
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
|
||||
@@ -166,11 +166,9 @@ Resources:
|
||||
- ImportedNamespace: !ImportValue
|
||||
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
|
||||
- Name: AUTH_TYPE
|
||||
Value: basic
|
||||
Value: disabled
|
||||
Secrets:
|
||||
- Name: POSTGRES_PASSWORD
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
|
||||
- Name: USER_AUTH_SECRET
|
||||
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
|
||||
VolumesFrom: []
|
||||
SystemControls: []
|
||||
|
||||
@@ -16,15 +16,12 @@
|
||||
# This overlay:
|
||||
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
|
||||
# so they do not start by default
|
||||
# - Moves the background worker to the "background" profile (the API server
|
||||
# handles all background work via FastAPI BackgroundTasks)
|
||||
# - Makes the depends_on references to removed services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on the api_server
|
||||
# - Makes the depends_on references to those services optional
|
||||
# - Sets DISABLE_VECTOR_DB=true on backend services
|
||||
#
|
||||
# To selectively bring services back:
|
||||
# --profile vectordb Vespa + indexing model server
|
||||
# --profile inference Inference model server
|
||||
# --profile background Background worker (Celery)
|
||||
# --profile code-interpreter Code interpreter
|
||||
# =============================================================================
|
||||
|
||||
@@ -46,20 +43,20 @@ services:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
|
||||
# Move the background worker to a profile so it does not start by default.
|
||||
# The API server handles all background work in NO_VECTOR_DB mode.
|
||||
background:
|
||||
profiles: ["background"]
|
||||
depends_on:
|
||||
index:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
indexing_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
inference_model_server:
|
||||
condition: service_started
|
||||
required: false
|
||||
environment:
|
||||
- DISABLE_VECTOR_DB=true
|
||||
- FILE_STORE_BACKEND=postgres
|
||||
|
||||
# Move Vespa and indexing model server to a profile so they do not start.
|
||||
index:
|
||||
|
||||
@@ -65,7 +65,10 @@ services:
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- TEAMS_BOT_APP_ID=${TEAMS_BOT_APP_ID:-}
|
||||
- TEAMS_BOT_APP_SECRET=${TEAMS_BOT_APP_SECRET:-}
|
||||
- TEAMS_BOT_AZURE_TENANT_ID=${TEAMS_BOT_AZURE_TENANT_ID:-}
|
||||
# API Server connection for Discord/Teams bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
|
||||
@@ -87,7 +87,10 @@ services:
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- TEAMS_BOT_APP_ID=${TEAMS_BOT_APP_ID:-}
|
||||
- TEAMS_BOT_APP_SECRET=${TEAMS_BOT_APP_SECRET:-}
|
||||
- TEAMS_BOT_AZURE_TENANT_ID=${TEAMS_BOT_AZURE_TENANT_ID:-}
|
||||
# API Server connection for Discord/Teams bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
- PERSISTENT_DOCUMENT_STORAGE_PATH=${PERSISTENT_DOCUMENT_STORAGE_PATH:-/app/file-system}
|
||||
|
||||
@@ -161,7 +161,10 @@ services:
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- TEAMS_BOT_APP_ID=${TEAMS_BOT_APP_ID:-}
|
||||
- TEAMS_BOT_APP_SECRET=${TEAMS_BOT_APP_SECRET:-}
|
||||
- TEAMS_BOT_AZURE_TENANT_ID=${TEAMS_BOT_AZURE_TENANT_ID:-}
|
||||
# API Server connection for Discord/Teams bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# Onyx Craft configuration (set up automatically on container startup)
|
||||
|
||||
@@ -103,6 +103,14 @@ MINIO_ROOT_PASSWORD=minioadmin
|
||||
## Command prefix for bot commands (default: "!")
|
||||
# DISCORD_BOT_INVOKE_CHAR=!
|
||||
|
||||
## Teams Bot Configuration
|
||||
## The Teams bot allows users to interact with Onyx from Microsoft Teams
|
||||
## App ID and Secret from Azure Bot Service registration
|
||||
# TEAMS_BOT_APP_ID=
|
||||
# TEAMS_BOT_APP_SECRET=
|
||||
## Azure tenant ID (optional, for single-tenant bots)
|
||||
# TEAMS_BOT_AZURE_TENANT_ID=
|
||||
|
||||
## Celery Configuration
|
||||
# CELERY_BROKER_POOL_LIMIT=
|
||||
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=
|
||||
|
||||
@@ -19,6 +19,6 @@ dependencies:
|
||||
version: 5.4.0
|
||||
- name: code-interpreter
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
version: 0.3.1
|
||||
digest: sha256:4965b6ea3674c37163832a2192cd3bc8004f2228729fca170af0b9f457e8f987
|
||||
generated: "2026-03-02T15:29:39.632344-08:00"
|
||||
version: 0.3.0
|
||||
digest: sha256:cf8f01906d46034962c6ce894770621ee183ac761e6942951118aeb48540eddd
|
||||
generated: "2026-02-24T10:59:38.78318-08:00"
|
||||
|
||||
@@ -45,6 +45,6 @@ dependencies:
|
||||
repository: https://charts.min.io/
|
||||
condition: minio.enabled
|
||||
- name: code-interpreter
|
||||
version: 0.3.1
|
||||
version: 0.3.0
|
||||
repository: https://onyx-dot-app.github.io/python-sandbox/
|
||||
condition: codeInterpreter.enabled
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_beat.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_beat.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_heavy.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_heavy.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_light.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_primary.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_primary.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_user_file_processing.replicaCount) 0) }}
|
||||
{{- if gt (int .Values.celery_worker_user_file_processing.replicaCount) 0 }}
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
|
||||
26
deployment/helm/charts/onyx/templates/teamsbot-service.yaml
Normal file
26
deployment/helm/charts/onyx/templates/teamsbot-service.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
{{- if .Values.teamsbot.enabled }}
|
||||
# Service to expose the Teams bot /api/messages endpoint.
|
||||
# Unlike Discord (outbound WebSocket only), Teams requires an inbound HTTP endpoint
|
||||
# that Azure Bot Service can POST Activities to.
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-teamsbot
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.teamsbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
type: {{ .Values.teamsbot.service.type | default "ClusterIP" }}
|
||||
ports:
|
||||
- port: {{ .Values.teamsbot.service.port | default 80 }}
|
||||
targetPort: http
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
{{- include "onyx.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.teamsbot.deploymentLabels }}
|
||||
{{- toYaml .Values.teamsbot.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
131
deployment/helm/charts/onyx/templates/teamsbot.yaml
Normal file
131
deployment/helm/charts/onyx/templates/teamsbot.yaml
Normal file
@@ -0,0 +1,131 @@
|
||||
{{- if .Values.teamsbot.enabled }}
|
||||
# Teams bot receives webhooks via HTTP POST - supports multiple replicas behind a load balancer.
|
||||
# Unlike Discord (WebSocket, single replica), Teams is horizontally scalable.
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-teamsbot
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.teamsbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
replicas: {{ .Values.teamsbot.replicaCount | default 1 }}
|
||||
strategy:
|
||||
type: RollingUpdate
|
||||
rollingUpdate:
|
||||
maxSurge: 1
|
||||
maxUnavailable: 0
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "onyx.selectorLabels" . | nindent 6 }}
|
||||
{{- if .Values.teamsbot.deploymentLabels }}
|
||||
{{- toYaml .Values.teamsbot.deploymentLabels | nindent 6 }}
|
||||
{{- end }}
|
||||
template:
|
||||
metadata:
|
||||
annotations:
|
||||
checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
|
||||
{{- with .Values.teamsbot.podAnnotations }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 8 }}
|
||||
{{- with .Values.teamsbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.teamsbot.podLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "onyx.serviceAccountName" . }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.teamsbot.podSecurityContext | nindent 8 }}
|
||||
{{- with .Values.teamsbot.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.teamsbot.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.teamsbot.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: teamsbot
|
||||
securityContext:
|
||||
{{- toYaml .Values.teamsbot.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.teamsbot.image.repository }}:{{ .Values.teamsbot.image.tag | default .Values.global.version }}"
|
||||
imagePullPolicy: {{ .Values.global.pullPolicy }}
|
||||
command: ["python", "onyx/onyxbot/teams/server.py"]
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.teamsbot.port | default 3978 }}
|
||||
protocol: TCP
|
||||
livenessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 15
|
||||
periodSeconds: 30
|
||||
readinessProbe:
|
||||
httpGet:
|
||||
path: /health
|
||||
port: http
|
||||
initialDelaySeconds: 10
|
||||
periodSeconds: 10
|
||||
resources:
|
||||
{{- toYaml .Values.teamsbot.resources | nindent 12 }}
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
env:
|
||||
{{- include "onyx.envSecrets" . | nindent 12}}
|
||||
# Teams bot App ID
|
||||
{{- if .Values.teamsbot.appId }}
|
||||
- name: TEAMS_BOT_APP_ID
|
||||
value: {{ .Values.teamsbot.appId | quote }}
|
||||
{{- end }}
|
||||
{{- if .Values.teamsbot.appIdSecretName }}
|
||||
- name: TEAMS_BOT_APP_ID
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ .Values.teamsbot.appIdSecretName }}
|
||||
key: {{ .Values.teamsbot.appIdSecretKey | default "app-id" }}
|
||||
{{- end }}
|
||||
# Teams bot App Secret
|
||||
{{- if .Values.teamsbot.appSecret }}
|
||||
- name: TEAMS_BOT_APP_SECRET
|
||||
value: {{ .Values.teamsbot.appSecret | quote }}
|
||||
{{- end }}
|
||||
{{- if .Values.teamsbot.appSecretSecretName }}
|
||||
- name: TEAMS_BOT_APP_SECRET
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ .Values.teamsbot.appSecretSecretName }}
|
||||
key: {{ .Values.teamsbot.appSecretSecretKey | default "app-secret" }}
|
||||
{{- end }}
|
||||
# Azure tenant ID (optional, for single-tenant bots)
|
||||
{{- if .Values.teamsbot.azureTenantId }}
|
||||
- name: TEAMS_BOT_AZURE_TENANT_ID
|
||||
value: {{ .Values.teamsbot.azureTenantId | quote }}
|
||||
{{- end }}
|
||||
# Bot port
|
||||
- name: TEAMS_BOT_PORT
|
||||
value: {{ .Values.teamsbot.port | default 3978 | quote }}
|
||||
{{- with .Values.teamsbot.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.teamsbot.volumes }}
|
||||
volumes:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user