Compare commits

..

2 Commits

Author SHA1 Message Date
Nik
32766066b4 fix(teams): address PR review feedback from Cubic and Greptile
- server.py: Forward InvokeResponse from process_activity instead of
  discarding it, so Teams invoke activities get valid HTTP responses
- cache.py: Move data-building outside asyncio.Lock, only hold lock
  for the atomic swap of cache dicts (reduces contention)
- cache.py: Set CURRENT_TENANT_ID_CONTEXTVAR in refresh_entity to
  match refresh_all behavior
- cache.py: Narrow bare Exception catch to (OperationalError,
  ConnectionError, OSError) so programming errors propagate
- migration: Add CheckConstraint("id = 'SINGLETON'") to
  teams_bot_config table to enforce singleton
- handle_commands.py: Use SELECT FOR UPDATE on registration key row
  to prevent concurrent registration race condition
- teams_bot.py: Rename get_or_create_teams_service_api_key to
  provision_teams_service_api_key with explicit docstring about
  non-idempotent key regeneration behavior
- bot.py: Add debug log for ignored DMs, fix TODO format with owner
- app_configs.py: Use `or` pattern for TEAMS_BOT_PORT to handle
  empty env values safely
- handle_message.py: Convert ShouldRespondContext from BaseModel to
  dataclass for internal DTO
2026-03-02 21:12:16 -08:00
Nik
6562e63ab8 feat(teams): add Microsoft Teams bot integration
Add a full Microsoft Teams bot using botbuilder-core SDK v4, following
existing Discord bot patterns. Shared utilities (API client, cache base
class, exceptions, registration keys, constants) are extracted into
onyx/onyxbot/ to eliminate duplication between the two bots.

Includes: DB models + migration, HTTP server with Bot Framework adapter,
Adaptive Card responses with citations, multi-tenant cache, admin API,
Helm/Docker deployment, and 48 unit tests.
2026-03-02 12:20:06 -08:00
146 changed files with 3796 additions and 2937 deletions

View File

@@ -426,9 +426,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max
@@ -500,9 +499,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max
@@ -648,8 +646,8 @@ jobs:
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max
@@ -730,8 +728,8 @@ jobs:
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
NODE_OPTIONS=--max-old-space-size=8192
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
@@ -864,9 +862,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
@@ -937,9 +934,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
@@ -1076,8 +1072,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
@@ -1149,8 +1145,8 @@ jobs:
ONYX_VERSION=${{ github.ref_name }}
ENABLE_CRAFT=true
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
@@ -1291,9 +1287,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
@@ -1371,9 +1366,8 @@ jobs:
build-args: |
ONYX_VERSION=${{ github.ref_name }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
cache-to: |
type=inline
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max

View File

@@ -15,9 +15,6 @@ permissions:
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
permissions:
contents: read
id-token: write
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
@@ -28,6 +25,16 @@ jobs:
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
azure_api_key: ${{ secrets.AZURE_API_KEY }}
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]

View File

@@ -8,7 +8,7 @@ on:
pull_request:
branches:
- main
- "release/**"
- 'release/**'
push:
tags:
- "v*.*.*"
@@ -21,13 +21,7 @@ jobs:
# See https://runs-on.com/runners/linux/
# Note: Mypy seems quite optimized for x64 compared to arm64.
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
runs-on:
[
runs-on,
runner=2cpu-linux-x64,
"run-id=${{ github.run_id }}-mypy-check",
"extras=s3-cache",
]
runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
timeout-minutes: 45
steps:
@@ -58,14 +52,21 @@ jobs:
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: .mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy
working-directory: ./backend
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy .
- name: Run MyPy (tools/)
env:
MYPY_FORCE_COLOR: 1
TERM: xterm-256color
run: mypy tools/

View File

@@ -48,10 +48,28 @@ on:
required: false
default: true
type: boolean
secrets:
openai_api_key:
required: false
anthropic_api_key:
required: false
bedrock_api_key:
required: false
vertex_ai_custom_config_json:
required: false
azure_api_key:
required: false
ollama_api_key:
required: false
openrouter_api_key:
required: false
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
required: true
permissions:
contents: read
id-token: write
jobs:
build-backend-image:
@@ -63,7 +81,6 @@ jobs:
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -72,19 +89,6 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build backend image
uses: ./.github/actions/build-backend-image
with:
@@ -93,8 +97,8 @@ jobs:
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
@@ -106,7 +110,6 @@ jobs:
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -115,19 +118,6 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build model server image
uses: ./.github/actions/build-model-server-image
with:
@@ -136,8 +126,8 @@ jobs:
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
runs-on:
@@ -148,7 +138,6 @@ jobs:
"extras=ecr-cache",
]
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -157,19 +146,6 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
- name: Build integration image
uses: ./.github/actions/build-integration-image
with:
@@ -178,8 +154,8 @@ jobs:
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
provider-chat-test:
needs:
@@ -194,56 +170,56 @@ jobs:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_env: OPENAI_API_KEY
custom_config_env: ""
api_key_secret: openai_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_env: ANTHROPIC_API_KEY
custom_config_env: ""
api_key_secret: anthropic_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_env: BEDROCK_API_KEY
custom_config_env: ""
api_key_secret: bedrock_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_env: ""
custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
api_key_secret: ""
custom_config_secret: vertex_ai_custom_config_json
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: azure
models: ${{ inputs.azure_models }}
api_key_env: AZURE_API_KEY
custom_config_env: ""
api_key_secret: azure_api_key
custom_config_secret: ""
api_base: ${{ inputs.azure_api_base }}
api_version: "2025-04-01-preview"
deployment_name: ""
required: false
- provider: ollama_chat
models: ${{ inputs.ollama_models }}
api_key_env: OLLAMA_API_KEY
custom_config_env: ""
api_key_secret: ollama_api_key
custom_config_secret: ""
api_base: "https://ollama.com"
api_version: ""
deployment_name: ""
required: false
- provider: openrouter
models: ${{ inputs.openrouter_models }}
api_key_env: OPENROUTER_API_KEY
custom_config_env: ""
api_key_secret: openrouter_api_key
custom_config_secret: ""
api_base: "https://openrouter.ai/api/v1"
api_version: ""
deployment_name: ""
@@ -254,7 +230,6 @@ jobs:
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
environment: ci-protected
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -263,43 +238,21 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
# Keep JSON values unparsed so vertex custom config is passed as raw JSON.
parse-json-secrets: false
secret-ids: |
DOCKER_USERNAME, test/docker-username
DOCKER_TOKEN, test/docker-token
OPENAI_API_KEY, test/openai-api-key
ANTHROPIC_API_KEY, test/anthropic-api-key
BEDROCK_API_KEY, test/bedrock-api-key
NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
AZURE_API_KEY, test/azure-api-key
OLLAMA_API_KEY, test/ollama-api-key
OPENROUTER_API_KEY, test/openrouter-api-key
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
api-base: ${{ matrix.api_base }}
api-version: ${{ matrix.api_version }}
deployment-name: ${{ matrix.deployment_name }}
custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ env.DOCKER_USERNAME }}
docker-token: ${{ env.DOCKER_TOKEN }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
- name: Dump API server logs
if: always()

View File

@@ -0,0 +1,105 @@
"""add teams bot tables
Revision ID: a1b2c3d4e5f6
Revises: 6b3b4083c5aa
Create Date: 2026-03-02 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "6b3b4083c5aa"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"teams_bot_config",
sa.Column(
"id",
sa.String(),
server_default=sa.text("'SINGLETON'"),
nullable=False,
),
sa.Column("app_id", sa.String(), nullable=False),
sa.Column("app_secret", sa.LargeBinary(), nullable=False),
sa.Column("azure_tenant_id", sa.String(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.PrimaryKeyConstraint("id"),
sa.CheckConstraint("id = 'SINGLETON'", name="ck_teams_bot_config_singleton"),
)
op.create_table(
"teams_team_config",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("team_id", sa.String(), nullable=True),
sa.Column("team_name", sa.String(length=256), nullable=True),
sa.Column("registration_key", sa.String(), nullable=False),
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("default_persona_id", sa.Integer(), nullable=True),
sa.Column(
"enabled",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.ForeignKeyConstraint(
["default_persona_id"],
["persona.id"],
ondelete="SET NULL",
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("team_id"),
sa.UniqueConstraint("registration_key"),
)
op.create_table(
"teams_channel_config",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("team_config_id", sa.Integer(), nullable=False),
sa.Column("channel_id", sa.String(), nullable=False),
sa.Column("channel_name", sa.String(), nullable=False),
sa.Column(
"require_bot_mention",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column("persona_override_id", sa.Integer(), nullable=True),
sa.Column(
"enabled",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.ForeignKeyConstraint(
["team_config_id"],
["teams_team_config.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["persona_override_id"],
["persona.id"],
ondelete="SET NULL",
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"team_config_id", "channel_id", name="uq_teams_channel_team_channel"
),
)
def downgrade() -> None:
op.drop_table("teams_channel_config")
op.drop_table("teams_team_config")
op.drop_table("teams_bot_config")

View File

@@ -31,7 +31,6 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import register_scim_exception_handlers
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
@@ -168,7 +167,6 @@ def get_application() -> FastAPI:
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
register_scim_exception_handlers(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -15,9 +15,7 @@ from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import FastAPI
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
@@ -26,7 +24,6 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import ScimAuthError
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
@@ -80,22 +77,6 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def register_scim_exception_handlers(app: FastAPI) -> None:
"""Register SCIM-specific exception handlers on the FastAPI app.
Call this after ``app.include_router(scim_router)`` so that auth
failures from ``verify_scim_token`` return RFC 7644 §3.12 error
envelopes (with ``schemas`` and ``status`` fields) instead of
FastAPI's default ``{"detail": "..."}`` format.
"""
@app.exception_handler(ScimAuthError)
async def _handle_scim_auth_error(
_request: Request, exc: ScimAuthError
) -> ScimJSONResponse:
return _scim_error_response(exc.status_code, exc.detail)
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
@@ -423,6 +404,12 @@ def create_user(
email = user_resource.userName.strip()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
@@ -449,19 +436,16 @@ def create_user(
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping when externalId is provided — this is how the IdP
# correlates this user on subsequent requests. Per RFC 7643, externalId
# is optional and assigned by the provisioning client.
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
if external_id:
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()

View File

@@ -19,6 +19,7 @@ import hashlib
import secrets
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session
@@ -27,21 +28,6 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
class ScimAuthError(Exception):
"""Raised when SCIM bearer token authentication fails.
Unlike HTTPException, this carries the status and detail so the SCIM
exception handler can wrap them in an RFC 7644 §3.12 error envelope
with ``schemas`` and ``status`` fields.
"""
def __init__(self, status_code: int, detail: str) -> None:
self.status_code = status_code
self.detail = detail
super().__init__(detail)
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
@@ -96,14 +82,23 @@ def verify_scim_token(
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise ScimAuthError(401, "Missing or invalid SCIM bearer token")
raise HTTPException(
status_code=401,
detail="Missing or invalid SCIM bearer token",
)
token = dal.get_token_by_hash(hashed)
if not token:
raise ScimAuthError(401, "Invalid SCIM bearer token")
raise HTTPException(
status_code=401,
detail="Invalid SCIM bearer token",
)
if not token.is_active:
raise ScimAuthError(401, "SCIM token has been revoked")
raise HTTPException(
status_code=401,
detail="SCIM token has been revoked",
)
return token

View File

@@ -153,28 +153,26 @@ class ScimProvider(ABC):
self,
user: User,
fields: ScimMappingFields,
) -> ScimName:
) -> ScimName | None:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Always returns a ScimName — Okta's spec tests expect ``name``
(with ``givenName``/``familyName``) on every user resource.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name or "",
familyName=fields.family_name or "",
formatted=user.personal_name or "",
givenName=fields.given_name,
familyName=fields.family_name,
formatted=user.personal_name,
)
if not user.personal_name:
return ScimName(givenName="", familyName="", formatted="")
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else "",
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)

View File

@@ -32,16 +32,13 @@ PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
# ------------------------------------------------------------------
_NEVER_RAN: float = -1e18
@dataclass
class _PeriodicTaskDef:
name: str
interval_seconds: float
lock_id: int
run_fn: Callable[[], None]
last_run_at: float = field(default=_NEVER_RAN)
last_run_at: float = field(default=0.0)
def _run_auto_llm_update() -> None:

View File

@@ -1,45 +0,0 @@
from collections.abc import Callable
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import CACHE_BACKEND
def _build_redis_backend(tenant_id: str) -> CacheBackend:
from onyx.cache.redis_backend import RedisCacheBackend
from onyx.redis.redis_pool import redis_pool
return RedisCacheBackend(redis_pool.get_client(tenant_id))
_BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = {
CacheBackendType.REDIS: _build_redis_backend,
# CacheBackendType.POSTGRES will be added in a follow-up PR.
}
def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend:
"""Return a tenant-aware ``CacheBackend``.
If *tenant_id* is ``None``, the current tenant is read from the
thread-local context variable (same behaviour as ``get_redis_client``).
"""
if tenant_id is None:
from shared_configs.contextvars import get_current_tenant_id
tenant_id = get_current_tenant_id()
builder = _BACKEND_BUILDERS.get(CACHE_BACKEND)
if builder is None:
raise ValueError(
f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. "
f"Supported values: {[t.value for t in CacheBackendType]}"
)
return builder(tenant_id)
def get_shared_cache_backend() -> CacheBackend:
"""Return a ``CacheBackend`` in the shared (cross-tenant) namespace."""
from shared_configs.configs import DEFAULT_REDIS_PREFIX
return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX)

View File

@@ -1,89 +0,0 @@
import abc
from enum import Enum
class CacheBackendType(str, Enum):
REDIS = "redis"
POSTGRES = "postgres"
class CacheLock(abc.ABC):
"""Abstract distributed lock returned by CacheBackend.lock()."""
@abc.abstractmethod
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
raise NotImplementedError
@abc.abstractmethod
def release(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def owned(self) -> bool:
raise NotImplementedError
class CacheBackend(abc.ABC):
"""Thin abstraction over a key-value cache with TTL, locks, and blocking lists.
Covers the subset of Redis operations used outside of Celery. When
CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead.
"""
# -- basic key/value ---------------------------------------------------
@abc.abstractmethod
def get(self, key: str) -> bytes | None:
raise NotImplementedError
@abc.abstractmethod
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def delete(self, key: str) -> None:
raise NotImplementedError
@abc.abstractmethod
def exists(self, key: str) -> bool:
raise NotImplementedError
# -- TTL ---------------------------------------------------------------
@abc.abstractmethod
def expire(self, key: str, seconds: int) -> None:
raise NotImplementedError
@abc.abstractmethod
def ttl(self, key: str) -> int:
"""Return remaining TTL in seconds. -1 if no expiry, -2 if key missing."""
raise NotImplementedError
# -- distributed lock --------------------------------------------------
@abc.abstractmethod
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
raise NotImplementedError
# -- blocking list (used by MCP OAuth BLPOP pattern) -------------------
@abc.abstractmethod
def rpush(self, key: str, value: str | bytes) -> None:
raise NotImplementedError
@abc.abstractmethod
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
"""Block until a value is available on one of *keys*, or *timeout* expires.
Returns ``(key, value)`` or ``None`` on timeout.
"""
raise NotImplementedError

View File

@@ -1,92 +0,0 @@
from typing import cast
from redis.client import Redis
from redis.lock import Lock as RedisLock
from onyx.cache.interface import CacheBackend
from onyx.cache.interface import CacheLock
class RedisCacheLock(CacheLock):
"""Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface."""
def __init__(self, lock: RedisLock) -> None:
self._lock = lock
def acquire(
self,
blocking: bool = True,
blocking_timeout: float | None = None,
) -> bool:
return bool(
self._lock.acquire(
blocking=blocking,
blocking_timeout=blocking_timeout,
)
)
def release(self) -> None:
self._lock.release()
def owned(self) -> bool:
return bool(self._lock.owned())
class RedisCacheBackend(CacheBackend):
"""``CacheBackend`` implementation that delegates to a ``redis.Redis`` client.
This is a thin pass-through — every method maps 1-to-1 to the underlying
Redis command. ``TenantRedis`` key-prefixing is handled by the client
itself (provided by ``get_redis_client``).
"""
def __init__(self, redis_client: Redis) -> None:
self._r = redis_client
# -- basic key/value ---------------------------------------------------
def get(self, key: str) -> bytes | None:
val = self._r.get(key)
if val is None:
return None
if isinstance(val, bytes):
return val
return str(val).encode()
def set(
self,
key: str,
value: str | bytes | int | float,
ex: int | None = None,
) -> None:
self._r.set(key, value, ex=ex)
def delete(self, key: str) -> None:
self._r.delete(key)
def exists(self, key: str) -> bool:
return bool(self._r.exists(key))
# -- TTL ---------------------------------------------------------------
def expire(self, key: str, seconds: int) -> None:
self._r.expire(key, seconds)
def ttl(self, key: str) -> int:
return cast(int, self._r.ttl(key))
# -- distributed lock --------------------------------------------------
def lock(self, name: str, timeout: float | None = None) -> CacheLock:
return RedisCacheLock(self._r.lock(name, timeout=timeout))
# -- blocking list (MCP OAuth BLPOP pattern) ---------------------------
def rpush(self, key: str, value: str | bytes) -> None:
self._r.rpush(key, value)
def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None:
result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout))
if result is None:
return None
return (result[0], result[1])

View File

@@ -6,7 +6,6 @@ from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.cache.interface import CacheBackendType
from onyx.configs.constants import AuthType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
@@ -55,12 +54,6 @@ DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() ==
# are disabled but core chat, tools, user file uploads, and Projects still work.
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
# Which backend to use for caching, locks, and ephemeral state.
# "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true).
CACHE_BACKEND = CacheBackendType(
os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS)
)
# Maximum token count for a single uploaded file. Files exceeding this are rejected.
# Defaults to 100k tokens (or 10M when vector DB is disabled).
_DEFAULT_FILE_TOKEN_LIMIT = 10_000_000 if DISABLE_VECTOR_DB else 100_000
@@ -1130,6 +1123,13 @@ DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
## Teams Bot Configuration
TEAMS_BOT_APP_ID = os.environ.get("TEAMS_BOT_APP_ID")
TEAMS_BOT_APP_SECRET = os.environ.get("TEAMS_BOT_APP_SECRET")
TEAMS_BOT_AZURE_TENANT_ID = os.environ.get("TEAMS_BOT_AZURE_TENANT_ID")
TEAMS_BOT_PORT = int(os.environ.get("TEAMS_BOT_PORT") or "3978")
## Stripe Configuration
# URL to fetch the Stripe publishable key from a public S3 bucket.
# Publishable keys are safe to expose publicly - they can only initialize

View File

@@ -99,6 +99,7 @@ DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
TEAMS_SERVICE_API_KEY_NAME = "teams-bot-service"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"

View File

@@ -20,6 +20,7 @@ from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
from sqlalchemy import Boolean
from sqlalchemy import CheckConstraint
from sqlalchemy import DateTime
from sqlalchemy import desc
from sqlalchemy import Enum
@@ -3668,6 +3669,115 @@ class DiscordChannelConfig(Base):
)
class TeamsBotConfig(Base):
    """Global Teams bot configuration (one per tenant).

    Stores the Azure Bot Service credentials when not provided via env vars.
    Uses a fixed ID with check constraint to enforce only one row per tenant.
    """

    __tablename__ = "teams_bot_config"
    __table_args__ = (
        # DB-level guarantee that at most one row can ever exist.
        CheckConstraint("id = 'SINGLETON'", name="ck_teams_bot_config_singleton"),
    )

    # Fixed primary key; the check constraint above pins it to 'SINGLETON'.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'SINGLETON'")
    )
    # Azure Bot Service application (client) ID.
    app_id: Mapped[str] = mapped_column(String, nullable=False)
    # Encrypted at rest. The column is NOT NULL, so the annotation omits
    # `| None` (previous annotation contradicted nullable=False).
    app_secret: Mapped[SensitiveValue[str]] = mapped_column(
        EncryptedString(), nullable=False
    )
    azure_tenant_id: Mapped[str | None] = mapped_column(String, nullable=True)
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
class TeamsTeamConfig(Base):
    """Configuration for a Teams team connected to this tenant.

    registration_key is a one-time key used to link a Teams team to this tenant.
    Format: teams_<tenant_id>.<random_token>
    team_id is NULL until the Teams admin runs @bot register with the key.
    """

    __tablename__ = "teams_team_config"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Teams team ID (GUID string) - NULL until registered via command in Teams
    team_id: Mapped[str | None] = mapped_column(String, nullable=True, unique=True)
    team_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
    # One-time registration key: teams_<tenant_id>.<random_token>
    registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    # Timestamp of successful registration; NULL while the key is unused.
    registered_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # Configuration
    # Persona the bot answers with unless a channel overrides it. SET NULL
    # on persona deletion so the team config itself survives.
    default_persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
    )
    enabled: Mapped[bool] = mapped_column(
        Boolean, server_default=text("true"), nullable=False
    )

    # Relationships
    default_persona: Mapped["Persona | None"] = relationship(
        "Persona", foreign_keys=[default_persona_id]
    )
    # Deleting a team config removes its channel configs (delete-orphan).
    channels: Mapped[list["TeamsChannelConfig"]] = relationship(
        back_populates="team_config", cascade="all, delete-orphan"
    )
class TeamsChannelConfig(Base):
    """Per-channel configuration for Teams bot behavior.

    Used to whitelist specific channels and configure per-channel behavior.
    """

    __tablename__ = "teams_channel_config"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Rows are removed automatically when the parent team config is deleted.
    team_config_id: Mapped[int] = mapped_column(
        ForeignKey("teams_team_config.id", ondelete="CASCADE"), nullable=False
    )
    # Teams channel ID (string)
    channel_id: Mapped[str] = mapped_column(String, nullable=False)
    channel_name: Mapped[str] = mapped_column(String(), nullable=False)
    # If true (default), bot only responds when @mentioned
    # If false, bot responds to ALL messages in this channel
    require_bot_mention: Mapped[bool] = mapped_column(
        Boolean, server_default=text("true"), nullable=False
    )
    # Override the team's default persona for this channel
    persona_override_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
    )
    # Channels start disabled and must be explicitly enabled by an admin.
    enabled: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # Relationships
    team_config: Mapped["TeamsTeamConfig"] = relationship(back_populates="channels")
    persona_override: Mapped["Persona | None"] = relationship()

    # Constraints
    # A channel may be configured at most once per team.
    __table_args__ = (
        UniqueConstraint(
            "team_config_id", "channel_id", name="uq_teams_channel_team_channel"
        ),
    )
class Milestone(Base):
# This table is used to track significant events for a deployment towards finding value
# The table is currently not used for features but it may be used in the future to inform

View File

@@ -52,7 +52,7 @@ def create_user_files(
) -> CategorizedFilesResult:
# Categorize the files
categorized_files = categorize_uploaded_files(files, db_session)
categorized_files = categorize_uploaded_files(files)
# NOTE: At the moment, zip metadata is not used for user files.
# Should revisit to decide whether this should be a feature.
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)

View File

@@ -0,0 +1,331 @@
"""CRUD operations for Teams bot models."""
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import TEAMS_SERVICE_API_KEY_NAME
from onyx.db.api_key import insert_api_key
from onyx.db.models import ApiKey
from onyx.db.models import TeamsBotConfig
from onyx.db.models import TeamsChannelConfig
from onyx.db.models import TeamsTeamConfig
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
logger = setup_logger()
# === TeamsBotConfig ===
def get_teams_bot_config(db_session: Session) -> TeamsBotConfig | None:
    """Fetch this tenant's Teams bot config row, if one exists (at most one)."""
    stmt = select(TeamsBotConfig).limit(1)
    return db_session.scalars(stmt).first()
def create_teams_bot_config(
    db_session: Session,
    app_id: str,
    app_secret: str,
    azure_tenant_id: str | None = None,
) -> TeamsBotConfig:
    """Create the Teams bot config. Raises ValueError if already exists.

    The check constraint on id='SINGLETON' ensures only one config per tenant.
    """
    # Fast-path pre-check for a friendly error without hitting the constraint.
    if get_teams_bot_config(db_session) is not None:
        raise ValueError("Teams bot config already exists")

    new_config = TeamsBotConfig(
        app_id=app_id,
        app_secret=app_secret,
        azure_tenant_id=azure_tenant_id,
    )
    db_session.add(new_config)
    try:
        db_session.flush()
    except IntegrityError:
        # Lost the race to a concurrent insert; the singleton check
        # constraint rejected this row.
        db_session.rollback()
        raise ValueError("Teams bot config already exists")
    return new_config
def delete_teams_bot_config(db_session: Session) -> bool:
    """Remove the Teams bot config row. Returns True if a row was removed."""
    outcome = db_session.execute(delete(TeamsBotConfig))
    db_session.flush()
    return outcome.rowcount > 0  # type: ignore[attr-defined]
# === Teams Service API Key ===
def get_teams_service_api_key(db_session: Session) -> ApiKey | None:
    """Look up the Teams service API key row, if present."""
    stmt = select(ApiKey).where(ApiKey.name == TEAMS_SERVICE_API_KEY_NAME)
    return db_session.scalars(stmt).first()
def provision_teams_service_api_key(
    db_session: Session,
    tenant_id: str,
) -> str:
    """Create or regenerate the Teams service API key, returning the raw key.

    Only the hash is persisted, so a cold cache (e.g. after a pod restart)
    cannot recover the raw key; in that case a fresh key is generated and
    the stored hash replaced. Safe because the bot is the key's sole
    consumer.

    NOT idempotent — every call with an existing key rewrites the stored
    hash. Only call it on cache miss.
    """
    key_row = get_teams_service_api_key(db_session)
    if key_row is not None:
        # Existing row: rotate in place rather than creating a duplicate.
        logger.debug(
            f"Regenerating Teams service API key for tenant {tenant_id} "
            "(raw key unrecoverable from hash)"
        )
        raw_key = generate_api_key(tenant_id)
        key_row.hashed_api_key = hash_api_key(raw_key)
        key_row.api_key_display = build_displayable_api_key(raw_key)
        db_session.flush()
        return raw_key

    # First-time provisioning: create the key (and its backing user)
    # via the standard API-key machinery.
    logger.info(f"Creating Teams service API key for tenant {tenant_id}")
    descriptor = insert_api_key(
        db_session=db_session,
        api_key_args=APIKeyArgs(
            name=TEAMS_SERVICE_API_KEY_NAME,
            role=UserRole.LIMITED,
        ),
        user_id=None,
    )
    if not descriptor.api_key:
        raise RuntimeError(
            f"Failed to create Teams service API key for tenant {tenant_id}"
        )
    return descriptor.api_key
def delete_teams_service_api_key(db_session: Session) -> bool:
    """Delete the Teams service API key (and its backing user).

    Called when:
    - Bot config is deleted (self-hosted)
    - All team configs are deleted (Cloud)
    """
    key_row = get_teams_service_api_key(db_session)
    if key_row is None:
        return False
    # API keys are backed by a dummy user; remove it alongside the key.
    owner = db_session.scalar(
        select(User).where(User.id == key_row.user_id)  # type: ignore[arg-type]
    )
    db_session.delete(key_row)
    if owner:
        db_session.delete(owner)
    db_session.flush()
    logger.info("Deleted Teams service API key")
    return True
# === TeamsTeamConfig ===
def get_team_configs(
    db_session: Session,
    include_channels: bool = False,
) -> list[TeamsTeamConfig]:
    """Return every team config for this tenant, optionally eager-loading channels."""
    stmt = select(TeamsTeamConfig)
    if include_channels:
        # joinedload can duplicate parent rows; unique() collapses them.
        stmt = stmt.options(joinedload(TeamsTeamConfig.channels))
    rows = db_session.scalars(stmt).unique()
    return list(rows)
def get_team_config_by_internal_id(
    db_session: Session,
    internal_id: int,
) -> TeamsTeamConfig | None:
    """Look up a team config by its internal primary key."""
    stmt = select(TeamsTeamConfig).where(TeamsTeamConfig.id == internal_id)
    return db_session.scalars(stmt).first()
def get_team_config_by_teams_id(
    db_session: Session,
    team_id: str,
) -> TeamsTeamConfig | None:
    """Look up a team config by its external Teams team ID."""
    stmt = select(TeamsTeamConfig).where(TeamsTeamConfig.team_id == team_id)
    return db_session.scalars(stmt).first()
def get_team_config_by_registration_key(
    db_session: Session,
    registration_key: str,
    for_update: bool = False,
) -> TeamsTeamConfig | None:
    """Look up a team config by its one-time registration key.

    Pass ``for_update=True`` to take a row-level lock (SELECT ... FOR
    UPDATE), which serializes concurrent registration attempts.
    """
    stmt = select(TeamsTeamConfig).where(
        TeamsTeamConfig.registration_key == registration_key
    )
    return db_session.scalar(stmt.with_for_update() if for_update else stmt)
def create_team_config(
    db_session: Session,
    registration_key: str,
) -> TeamsTeamConfig:
    """Create an unregistered team config (team_id stays NULL) for a registration key."""
    pending = TeamsTeamConfig(registration_key=registration_key)
    db_session.add(pending)
    db_session.flush()
    return pending
def register_team(
db_session: Session,
config: TeamsTeamConfig,
team_id: str,
team_name: str,
) -> TeamsTeamConfig:
"""Complete registration by setting team_id and team_name."""
config.team_id = team_id
config.team_name = team_name
config.registered_at = datetime.now(timezone.utc)
db_session.flush()
return config
def update_team_config(
db_session: Session,
config: TeamsTeamConfig,
enabled: bool,
default_persona_id: int | None = None,
) -> TeamsTeamConfig:
"""Update team config fields."""
config.enabled = enabled
config.default_persona_id = default_persona_id
db_session.flush()
return config
def delete_team_config(
    db_session: Session,
    internal_id: int,
) -> bool:
    """Delete team config (cascades to channel configs). Returns True if deleted."""
    outcome = db_session.execute(
        delete(TeamsTeamConfig).where(TeamsTeamConfig.id == internal_id)
    )
    db_session.flush()
    return outcome.rowcount > 0  # type: ignore[attr-defined]
# === TeamsChannelConfig ===
def get_channel_configs(
    db_session: Session,
    team_config_id: int,
) -> list[TeamsChannelConfig]:
    """List every channel config belonging to the given team config."""
    stmt = select(TeamsChannelConfig).where(
        TeamsChannelConfig.team_config_id == team_config_id
    )
    return list(db_session.scalars(stmt))
def get_channel_config_by_teams_ids(
    db_session: Session,
    team_id: str,
    channel_id: str,
) -> TeamsChannelConfig | None:
    """Resolve a channel config from external Teams team + channel IDs."""
    stmt = (
        select(TeamsChannelConfig)
        .join(TeamsTeamConfig)
        .where(TeamsTeamConfig.team_id == team_id)
        .where(TeamsChannelConfig.channel_id == channel_id)
    )
    return db_session.scalar(stmt)
def get_channel_config_by_internal_ids(
    db_session: Session,
    team_config_id: int,
    channel_config_id: int,
) -> TeamsChannelConfig | None:
    """Resolve a channel config from internal team + channel config IDs."""
    stmt = (
        select(TeamsChannelConfig)
        .where(TeamsChannelConfig.team_config_id == team_config_id)
        .where(TeamsChannelConfig.id == channel_config_id)
    )
    return db_session.scalar(stmt)
def update_teams_channel_config(
db_session: Session,
config: TeamsChannelConfig,
channel_name: str,
require_bot_mention: bool,
enabled: bool,
persona_override_id: int | None = None,
) -> TeamsChannelConfig:
"""Update channel config fields."""
config.channel_name = channel_name
config.require_bot_mention = require_bot_mention
config.persona_override_id = persona_override_id
config.enabled = enabled
db_session.flush()
return config
def create_channel_config(
    db_session: Session,
    team_config_id: int,
    channel_id: str,
    channel_name: str,
) -> TeamsChannelConfig:
    """Create a new channel config with default settings (disabled by default)."""
    new_channel = TeamsChannelConfig(
        team_config_id=team_config_id,
        channel_id=channel_id,
        channel_name=channel_name,
    )
    db_session.add(new_channel)
    db_session.flush()
    return new_channel

View File

@@ -6,6 +6,7 @@ import httpx
from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -562,7 +563,12 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
if not self._client.index_exists():
index_settings = DocumentSchema.get_index_settings_based_on_environment()
if USING_AWS_MANAGED_OPENSEARCH:
index_settings = (
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
)
else:
index_settings = DocumentSchema.get_index_settings()
self._client.create_index(
mappings=expected_mappings,
settings=index_settings,

View File

@@ -12,7 +12,6 @@ from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -526,7 +525,7 @@ class DocumentSchema:
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch.
@@ -547,41 +546,3 @@ class DocumentSchema:
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
"""
Settings for AWS-managed OpenSearch in multi-tenant cloud.
324 shards very roughly targets a storage load of ~30Gb per shard, which
according to AWS OpenSearch documentation is within a good target range.
As documented above we need 2 replicas for a total of 3 copies of the
data because the cluster is configured with 3-AZ awareness.
"""
return {
"index": {
"number_of_shards": 324,
"number_of_replicas": 2,
# Required for vector search.
"knn": True,
"knn.algo_param.ef_search": EF_SEARCH,
}
}
@staticmethod
def get_index_settings_based_on_environment() -> dict[str, Any]:
"""
Returns the index settings based on the environment.
"""
if USING_AWS_MANAGED_OPENSEARCH:
if MULTI_TENANT:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
)
else:
return (
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
)
else:
return DocumentSchema.get_index_settings()

View File

@@ -67,18 +67,6 @@ Status checked against LiteLLM v1.81.6-nightly (2026-02-02):
STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set
usage as a dict with chat completion format instead of keeping it as
ResponseAPIUsage. Our patch creates a deep copy before modification.
7. Responses API metadata=None TypeError (_patch_responses_metadata_none):
- LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {})
to check for router calls, but when metadata is explicitly None (key exists with
value None), the default {} is not used
- This causes "argument of type 'NoneType' is not iterable" TypeError which swallows
the real exception (e.g. AuthenticationError for wrong API key)
- Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is
not iterable
STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard
against metadata being explicitly None. Triggered when Responses API bridge
passes **litellm_params containing metadata=None.
"""
import time
@@ -737,44 +725,6 @@ def _patch_logging_assembled_streaming_response() -> None:
LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign]
def _patch_responses_metadata_none() -> None:
"""
Patches litellm.responses to normalize metadata=None to metadata={} in kwargs.
LiteLLM's @client decorator wrapper in utils.py (line 1721) does:
_is_litellm_router_call = "model_group" in kwargs.get("metadata", {})
When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns
None (the key exists, so the default is not used), causing:
TypeError: argument of type 'NoneType' is not iterable
This swallows the real exception (e.g. AuthenticationError) and surfaces as:
APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable
This happens when the Responses API bridge calls litellm.responses() with
**litellm_params which may contain metadata=None.
STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {})
which does not guard against metadata being explicitly None. Same pattern exists
on line 1407 for async path.
"""
import litellm as _litellm
from functools import wraps
original_responses = _litellm.responses
if getattr(original_responses, "_metadata_patched", False):
return
@wraps(original_responses)
def _patched_responses(*args: Any, **kwargs: Any) -> Any:
if kwargs.get("metadata") is None:
kwargs["metadata"] = {}
return original_responses(*args, **kwargs)
_patched_responses._metadata_patched = True # type: ignore[attr-defined]
_litellm.responses = _patched_responses
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -786,7 +736,6 @@ def apply_monkey_patches() -> None:
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
- Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths
- Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response
- Patching litellm.responses to fix metadata=None causing TypeError in error handling
"""
_patch_ollama_chunk_parser()
_patch_openai_responses_parallel_tool_calls()
@@ -794,4 +743,3 @@ def apply_monkey_patches() -> None:
_patch_azure_responses_should_fake_stream()
_patch_responses_api_usage_format()
_patch_logging_assembled_streaming_response()
_patch_responses_metadata_none()

View File

@@ -32,13 +32,11 @@ from onyx.auth.schemas import UserUpdate
from onyx.auth.users import auth_backend
from onyx.auth.users import create_onyx_oauth_router
from onyx.auth.users import fastapi_users
from onyx.cache.interface import CacheBackendType
from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import CACHE_BACKEND
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
from onyx.configs.app_configs import OAUTH_CLIENT_ID
@@ -117,6 +115,7 @@ from onyx.server.manage.opensearch_migration.api import (
)
from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.teams_bot.api import router as teams_bot_router
from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
@@ -257,20 +256,6 @@ def include_auth_router_with_prefix(
)
def validate_cache_backend_settings() -> None:
"""Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB.
The Postgres cache backend eliminates the Redis dependency, but only works
when Celery is not running (which requires DISABLE_VECTOR_DB=true).
"""
if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB:
raise RuntimeError(
"CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. "
"The Postgres cache backend is only supported in no-vector-DB "
"deployments where Celery is replaced by the in-process task runner."
)
def validate_no_vector_db_settings() -> None:
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
@@ -302,7 +287,6 @@ def validate_no_vector_db_settings() -> None:
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
validate_cache_backend_settings()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
@@ -466,6 +450,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
application, slack_bot_management_router
)
include_router_with_global_prefix_prepended(application, discord_bot_router)
include_router_with_global_prefix_prepended(application, teams_bot_router)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, agents_router)

View File

View File

@@ -1,12 +1,12 @@
"""Async HTTP client for communicating with Onyx API pods."""
"""Shared async HTTP client for communicating with Onyx API pods."""
import aiohttp
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError
from onyx.onyxbot.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.exceptions import APIConnectionError
from onyx.onyxbot.exceptions import APIResponseError
from onyx.onyxbot.exceptions import APITimeoutError
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import SendMessageRequest
@@ -19,36 +19,17 @@ logger = setup_logger()
class OnyxAPIClient:
"""Async HTTP client for sending chat requests to Onyx API pods.
This client manages an aiohttp session for making non-blocking HTTP
requests to the Onyx API server. It handles authentication with per-tenant
API keys and multi-tenant routing.
Usage:
client = OnyxAPIClient()
await client.initialize()
try:
response = await client.send_chat_message(
message="What is our deployment process?",
tenant_id="tenant_123",
api_key="dn_xxx...",
persona_id=1,
)
print(response.answer)
finally:
await client.close()
Used by both Discord and Teams bots. The ``origin`` parameter controls
which ``MessageOrigin`` value is attached to outgoing requests for
telemetry tracking.
"""
def __init__(
self,
origin: MessageOrigin,
timeout: int = API_REQUEST_TIMEOUT,
) -> None:
"""Initialize the API client.
Args:
timeout: Request timeout in seconds.
"""
# Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL
# TODO: Ideally, this override is only used when someone is launching an Onyx service independently
self._origin = origin
self._base_url = build_api_server_url_for_http_requests(
respect_env_override_if_set=True
).rstrip("/")
@@ -56,28 +37,20 @@ class OnyxAPIClient:
self._session: aiohttp.ClientSession | None = None
async def initialize(self) -> None:
"""Create the aiohttp session.
Must be called before making any requests. The session is created
with a total timeout and connection timeout.
"""
"""Create the aiohttp session."""
if self._session is not None:
logger.warning("API client session already initialized")
return
timeout = aiohttp.ClientTimeout(
total=self._timeout,
connect=30, # 30 seconds to establish connection
connect=30,
)
self._session = aiohttp.ClientSession(timeout=timeout)
logger.info(f"API client initialized with base URL: {self._base_url}")
async def close(self) -> None:
"""Close the aiohttp session.
Should be called when shutting down the bot to properly release
resources.
"""
"""Close the aiohttp session."""
if self._session is not None:
await self._session.close()
self._session = None
@@ -85,7 +58,6 @@ class OnyxAPIClient:
@property
def is_initialized(self) -> bool:
"""Check if the session is initialized."""
return self._session is not None
async def send_chat_message(
@@ -94,24 +66,7 @@ class OnyxAPIClient:
api_key: str,
persona_id: int | None = None,
) -> ChatFullResponse:
"""Send a chat message to the Onyx API server and get a response.
This method sends a non-streaming chat request to the API server. The response
contains the complete answer with any citations and metadata.
Args:
message: The user's message to process.
api_key: The API key for authentication.
persona_id: Optional persona ID to use for the response.
Returns:
ChatFullResponse containing the answer, citations, and metadata.
Raises:
APIConnectionError: If unable to connect to the API.
APITimeoutError: If the request times out.
APIResponseError: If the API returns an error response.
"""
"""Send a chat message to the Onyx API server and get a response."""
if self._session is None:
raise APIConnectionError(
"API client not initialized. Call initialize() first."
@@ -119,17 +74,15 @@ class OnyxAPIClient:
url = f"{self._base_url}/chat/send-chat-message"
# Build request payload
request = SendMessageRequest(
message=message,
stream=False,
origin=MessageOrigin.DISCORDBOT,
origin=self._origin,
chat_session_info=ChatSessionCreationRequest(
persona_id=persona_id if persona_id is not None else 0,
),
)
# Build headers
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
@@ -169,7 +122,6 @@ class OnyxAPIClient:
status_code=response.status,
)
# Parse successful response
data = await response.json()
response_obj = ChatFullResponse.model_validate(data)
@@ -195,11 +147,7 @@ class OnyxAPIClient:
raise APIConnectionError(f"HTTP client error: {e}") from e
async def health_check(self) -> bool:
"""Check if the API server is healthy.
Returns:
True if the API server is reachable and healthy, False otherwise.
"""
"""Check if the API server is healthy."""
if self._session is None:
logger.warning("API client not initialized. Call initialize() first.")
return False

View File

@@ -0,0 +1,195 @@
"""Shared multi-tenant cache for bot entity-tenant mappings and API keys.
Subclass ``BotCacheManager`` and implement the two abstract helpers to
create a platform-specific cache (e.g. Discord guilds, Teams teams).
"""
import asyncio
from abc import ABC
from abc import abstractmethod
from typing import Generic
from typing import TypeVar
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.onyxbot.exceptions import CacheError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
EntityIdT = TypeVar("EntityIdT")
class BotCacheManager(ABC, Generic[EntityIdT]):
"""Caches entity->tenant mappings and tenant->API key mappings.
``EntityIdT`` is ``int`` for Discord guilds, ``str`` for Teams teams.
"""
def __init__(self, entity_name: str) -> None:
self._entity_name = entity_name
self._entity_tenants: dict[EntityIdT, str] = {}
self._api_keys: dict[str, str] = {}
self._lock = asyncio.Lock()
self._initialized = False
# ------------------------------------------------------------------
# Abstract hooks — platform-specific DB access
# ------------------------------------------------------------------
@abstractmethod
def _get_entity_ids(self, db: Session) -> list[EntityIdT]:
"""Return active entity IDs from DB configs."""
@abstractmethod
def _get_or_create_api_key(self, db: Session, tenant_id: str) -> str:
"""Provision (or retrieve) a service API key for *tenant_id*."""
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
@property
def is_initialized(self) -> bool:
return self._initialized
async def refresh_all(self) -> None:
"""Full cache refresh from all tenants.
Data is loaded outside the lock; the lock is only held for the
atomic swap of the cache dicts so that ``refresh_entity`` and
read operations are not blocked during I/O.
"""
logger.info(f"Starting {self._entity_name} cache refresh")
new_entity_tenants: dict[EntityIdT, str] = {}
new_api_keys: dict[str, str] = {}
try:
gated = fetch_ee_implementation_or_noop(
"onyx.server.tenants.product_gating",
"get_gated_tenants",
set(),
)()
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
for tenant_id in tenant_ids:
if tenant_id in gated:
continue
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
entity_ids, api_key = await self._load_tenant_data(tenant_id)
if not entity_ids:
logger.debug(
f"No {self._entity_name} found for tenant " f"{tenant_id}"
)
continue
if not api_key:
logger.warning(
f"Service API key missing for tenant that has "
f"registered {self._entity_name}. {tenant_id} "
f"will not be handled in this refresh cycle."
)
continue
for entity_id in entity_ids:
new_entity_tenants[entity_id] = tenant_id
new_api_keys[tenant_id] = api_key
except (OperationalError, ConnectionError, OSError) as e:
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
# Atomic swap under lock
async with self._lock:
self._entity_tenants = new_entity_tenants
self._api_keys = new_api_keys
self._initialized = True
logger.info(
f"Cache refresh complete: "
f"{len(new_entity_tenants)} {self._entity_name}, "
f"{len(new_api_keys)} tenants"
)
except (OperationalError, ConnectionError, OSError) as e:
logger.error(f"Cache refresh failed: {e}")
raise CacheError(f"Failed to refresh cache: {e}") from e
async def refresh_entity(self, entity_id: EntityIdT, tenant_id: str) -> None:
    """Add a single entity to cache after registration.

    Loads the tenant's data (entity ids + API key) outside the lock, then
    takes the lock only for the in-place cache mutation. The entity is only
    added if the fresh DB load actually reports it as active; otherwise a
    warning is logged and the cache is left unchanged.
    """
    logger.info(
        f"Refreshing cache for {self._entity_name} entity "
        f"{entity_id} (tenant: {tenant_id})"
    )
    # Set the tenant contextvar for the duration of the DB load so it
    # matches refresh_all's behavior; always reset it afterwards.
    context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        entity_ids, api_key = await self._load_tenant_data(tenant_id)
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
    async with self._lock:
        if entity_id in entity_ids:
            self._entity_tenants[entity_id] = tenant_id
            # api_key may be None when provisioning failed; keep any
            # previously cached key in that case.
            if api_key:
                self._api_keys[tenant_id] = api_key
            logger.info(f"Cache updated for entity {entity_id}")
        else:
            logger.warning(f"Entity {entity_id} not found or disabled")
def get_tenant(self, entity_id: EntityIdT) -> str | None:
    """Return the tenant that owns ``entity_id``, or None when not cached."""
    tenant_id = self._entity_tenants.get(entity_id)
    return tenant_id
def get_api_key(self, tenant_id: str) -> str | None:
"""Get API key for a tenant."""
return self._api_keys.get(tenant_id)
def remove_entity(self, entity_id: EntityIdT) -> None:
    """Drop ``entity_id`` from the cache; a no-op when it is not present."""
    if entity_id in self._entity_tenants:
        del self._entity_tenants[entity_id]
def get_all_entity_ids(self) -> list[EntityIdT]:
    """Return a snapshot list of every entity id currently cached."""
    return [entity_id for entity_id in self._entity_tenants]
def clear(self) -> None:
    """Wipe all cached mappings and mark the cache as uninitialized."""
    self._initialized = False
    self._entity_tenants.clear()
    self._api_keys.clear()
# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
async def _load_tenant_data(
    self, tenant_id: str
) -> tuple[list[EntityIdT], str | None]:
    """Load entity IDs and provision API key if needed.

    The blocking DB work runs in a worker thread via ``asyncio.to_thread``
    so the event loop is never blocked.

    Returns:
        ``(entity_ids, api_key)`` — ``entity_ids`` is empty (and the key is
        None) when the tenant has no active entities. When a key is already
        cached it is reused; otherwise one is obtained via the subclass's
        ``_get_or_create_api_key`` hook and committed.
        NOTE(review): for Teams, that hook may regenerate the key rather
        than fetch an existing one — confirm against the db helper.
    """
    # Snapshot the cached key before entering the worker thread so the
    # closure does not read mutable shared state concurrently.
    cached_key = self._api_keys.get(tenant_id)

    def _sync() -> tuple[list[EntityIdT], str | None]:
        # Runs on a worker thread; keep all DB access inside this scope.
        with get_session_with_tenant(tenant_id=tenant_id) as db:
            entity_ids = self._get_entity_ids(db)
            if not entity_ids:
                return [], None
            if not cached_key:
                new_key = self._get_or_create_api_key(db, tenant_id)
                db.commit()
                return entity_ids, new_key
            return entity_ids, cached_key

    return await asyncio.to_thread(_sync)

View File

@@ -0,0 +1,10 @@
"""Shared constants for Onyx bot integrations (Discord, Teams, etc.)."""
# API settings
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
# Cache settings
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
# Registration
REGISTER_COMMAND: str = "register"

View File

@@ -1,154 +1,35 @@
"""Multi-tenant cache for Discord bot guild-tenant mappings and API keys."""
import asyncio
from sqlalchemy.orm import Session
from onyx.db.discord_bot import get_guild_configs
from onyx.db.discord_bot import get_or_create_discord_service_api_key
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.onyxbot.discord.exceptions import CacheError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
from onyx.onyxbot.cache import BotCacheManager
class DiscordCacheManager:
"""Caches guild->tenant mappings and tenant->API key mappings.
Refreshed on startup, periodically (every 60s), and when guilds register.
"""
class DiscordCacheManager(BotCacheManager[int]):
"""Caches guild->tenant mappings and tenant->API key mappings."""
def __init__(self) -> None:
self._guild_tenants: dict[int, str] = {} # guild_id -> tenant_id
self._api_keys: dict[str, str] = {} # tenant_id -> api_key
self._lock = asyncio.Lock()
self._initialized = False
super().__init__(entity_name="guilds")
@property
def is_initialized(self) -> bool:
return self._initialized
def _get_entity_ids(self, db: Session) -> list[int]:
    """Return guild IDs for all enabled guild configs with a non-null ID."""
    configs = get_guild_configs(db)
    return [
        config.guild_id
        for config in configs
        if config.enabled and config.guild_id is not None
    ]
async def refresh_all(self) -> None:
"""Full cache refresh from all tenants."""
async with self._lock:
logger.info("Starting Discord cache refresh")
new_guild_tenants: dict[int, str] = {}
new_api_keys: dict[str, str] = {}
try:
gated = fetch_ee_implementation_or_noop(
"onyx.server.tenants.product_gating",
"get_gated_tenants",
set(),
)()
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
for tenant_id in tenant_ids:
if tenant_id in gated:
continue
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
guild_ids, api_key = await self._load_tenant_data(tenant_id)
if not guild_ids:
logger.debug(f"No guilds found for tenant {tenant_id}")
continue
if not api_key:
logger.warning(
"Discord service API key missing for tenant that has registered guilds. "
f"{tenant_id} will not be handled in this refresh cycle."
)
continue
for guild_id in guild_ids:
new_guild_tenants[guild_id] = tenant_id
new_api_keys[tenant_id] = api_key
except Exception as e:
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
self._guild_tenants = new_guild_tenants
self._api_keys = new_api_keys
self._initialized = True
logger.info(
f"Cache refresh complete: {len(new_guild_tenants)} guilds, "
f"{len(new_api_keys)} tenants"
)
except Exception as e:
logger.error(f"Cache refresh failed: {e}")
raise CacheError(f"Failed to refresh cache: {e}") from e
def _get_or_create_api_key(self, db: Session, tenant_id: str) -> str:
    # BotCacheManager hook: delegates to the Discord db-layer helper,
    # which (per its name) returns an existing service key or creates one.
    return get_or_create_discord_service_api_key(db, tenant_id)
# Convenience aliases for backward compatibility with callers
async def refresh_guild(self, guild_id: int, tenant_id: str) -> None:
"""Add a single guild to cache after registration."""
async with self._lock:
logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})")
guild_ids, api_key = await self._load_tenant_data(tenant_id)
if guild_id in guild_ids:
self._guild_tenants[guild_id] = tenant_id
if api_key:
self._api_keys[tenant_id] = api_key
logger.info(f"Cache updated for guild {guild_id}")
else:
logger.warning(f"Guild {guild_id} not found or disabled")
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
"""Load guild IDs and provision API key if needed.
Returns:
(active_guild_ids, api_key) - api_key is the cached key if available,
otherwise a newly created key. Returns None if no guilds found.
"""
cached_key = self._api_keys.get(tenant_id)
def _sync() -> tuple[list[int], str | None]:
with get_session_with_tenant(tenant_id=tenant_id) as db:
configs = get_guild_configs(db)
guild_ids = [
config.guild_id
for config in configs
if config.enabled and config.guild_id is not None
]
if not guild_ids:
return [], None
if not cached_key:
new_key = get_or_create_discord_service_api_key(db, tenant_id)
db.commit()
return guild_ids, new_key
return guild_ids, cached_key
return await asyncio.to_thread(_sync)
def get_tenant(self, guild_id: int) -> str | None:
"""Get tenant ID for a guild."""
return self._guild_tenants.get(guild_id)
def get_api_key(self, tenant_id: str) -> str | None:
"""Get API key for a tenant."""
return self._api_keys.get(tenant_id)
await self.refresh_entity(guild_id, tenant_id)
def remove_guild(self, guild_id: int) -> None:
"""Remove a guild from cache."""
self._guild_tenants.pop(guild_id, None)
self.remove_entity(guild_id)
def get_all_guild_ids(self) -> list[int]:
"""Get all cached guild IDs."""
return list(self._guild_tenants.keys())
def clear(self) -> None:
"""Clear all caches."""
self._guild_tenants.clear()
self._api_keys.clear()
self._initialized = False
return self.get_all_entity_ids()

View File

@@ -7,15 +7,16 @@ import discord
from discord.ext import commands
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.api_client import OnyxAPIClient
from onyx.onyxbot.constants import CACHE_REFRESH_INTERVAL
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL
from onyx.onyxbot.discord.handle_commands import handle_dm
from onyx.onyxbot.discord.handle_commands import handle_registration_command
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
from onyx.onyxbot.discord.handle_message import process_chat_message
from onyx.onyxbot.discord.handle_message import should_respond
from onyx.onyxbot.discord.utils import get_bot_token
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -40,7 +41,7 @@ class OnyxDiscordClient(commands.Bot):
self.ready = False
self.cache = DiscordCacheManager()
self.api_client = OnyxAPIClient()
self.api_client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
self._cache_refresh_task: asyncio.Task | None = None
# -------------------------------------------------------------------------

View File

@@ -1,19 +1,16 @@
"""Discord bot constants."""
"""Discord-specific bot constants.
# API settings
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
# Cache settings
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
Shared constants (API_REQUEST_TIMEOUT, CACHE_REFRESH_INTERVAL,
REGISTER_COMMAND) live in ``onyx.onyxbot.constants``.
"""
# Message settings
MAX_MESSAGE_LENGTH: int = 2000 # Discord's character limit
MAX_CONTEXT_MESSAGES: int = 10 # Max messages to include in conversation context
# Note: Discord.py's add_reaction() requires unicode emoji, not :name: format
THINKING_EMOJI: str = "🤔" # U+1F914 - Thinking Face
SUCCESS_EMOJI: str = "" # U+2705 - White Heavy Check Mark
ERROR_EMOJI: str = "" # U+274C - Cross Mark
THINKING_EMOJI: str = "\U0001f914" # U+1F914 - Thinking Face
SUCCESS_EMOJI: str = "\u2705" # U+2705 - White Heavy Check Mark
ERROR_EMOJI: str = "\u274c" # U+274C - Cross Mark
# Command prefix
REGISTER_COMMAND: str = "register"
# Discord-specific commands
SYNC_CHANNELS_COMMAND: str = "sync-channels"

View File

@@ -1,37 +1,7 @@
"""Custom exception classes for Discord bot."""
"""Discord-specific exception classes."""
from onyx.onyxbot.exceptions import OnyxBotError
class DiscordBotError(Exception):
"""Base exception for Discord bot errors."""
class RegistrationError(DiscordBotError):
"""Error during guild registration."""
class SyncChannelsError(DiscordBotError):
class SyncChannelsError(OnyxBotError):
"""Error during channel sync."""
class APIError(DiscordBotError):
"""Base API error."""
class CacheError(DiscordBotError):
"""Error during cache operations."""
class APIConnectionError(APIError):
"""Failed to connect to API."""
class APITimeoutError(APIError):
"""Request timed out."""
class APIResponseError(APIError):
"""API returned an error response."""
def __init__(self, message: str, status_code: int | None = None):
super().__init__(message)
self.status_code = status_code

View File

@@ -15,11 +15,11 @@ from onyx.db.discord_bot import get_guild_config_by_registration_key
from onyx.db.discord_bot import sync_channel_configs
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.utils import DiscordChannelView
from onyx.onyxbot.constants import REGISTER_COMMAND
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.onyxbot.discord.constants import REGISTER_COMMAND
from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND
from onyx.onyxbot.discord.exceptions import RegistrationError
from onyx.onyxbot.discord.exceptions import SyncChannelsError
from onyx.onyxbot.exceptions import RegistrationError
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -11,11 +11,11 @@ from onyx.db.discord_bot import get_guild_config_by_discord_id
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import DiscordChannelConfig
from onyx.db.models import DiscordGuildConfig
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.api_client import OnyxAPIClient
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
from onyx.onyxbot.discord.constants import THINKING_EMOJI
from onyx.onyxbot.discord.exceptions import APIError
from onyx.onyxbot.exceptions import APIError
from onyx.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -0,0 +1,33 @@
"""Shared exception classes for Onyx bot integrations (Discord, Teams, etc.)."""
class OnyxBotError(Exception):
"""Base exception for all Onyx bot errors."""
class RegistrationError(OnyxBotError):
"""Error during bot registration."""
class APIError(OnyxBotError):
"""Base API error."""
class CacheError(OnyxBotError):
"""Error during cache operations."""
class APIConnectionError(APIError):
"""Failed to connect to API."""
class APITimeoutError(APIError):
"""Request timed out."""
class APIResponseError(APIError):
"""API returned an error response."""
def __init__(self, message: str, status_code: int | None = None):
super().__init__(message)
self.status_code = status_code

View File

@@ -0,0 +1,42 @@
"""Shared registration key generation and parsing for bot integrations."""
import secrets
from urllib.parse import quote
from urllib.parse import unquote
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_registration_key(prefix: str, tenant_id: str) -> str:
    """Create a single-use registration key with the tenant id embedded.

    The key has the shape ``<prefix>_<url_encoded_tenant_id>.<random_token>``;
    the tenant id is URL-quoted so arbitrary characters survive round-trips.
    """
    token = secrets.token_urlsafe(16)
    key = f"{prefix}_{quote(tenant_id)}.{token}"
    logger.info(f"Generated {prefix} registration key for tenant {tenant_id}")
    return key
def parse_registration_key(prefix: str, key: str) -> str | None:
"""Parse registration key to extract tenant_id.
Returns tenant_id or None if invalid format.
"""
full_prefix = f"{prefix}_"
if not key.startswith(full_prefix):
return None
try:
key_body = key.removeprefix(full_prefix)
parts = key_body.split(".", 1)
if len(parts) != 2:
return None
encoded_tenant = parts[0]
return unquote(encoded_tenant)
except Exception:
return None

View File

View File

@@ -0,0 +1,186 @@
"""Teams bot Activity handler using Bot Framework SDK."""
import asyncio
from botbuilder.core import ActivityHandler # type: ignore[import-untyped]
from botbuilder.core import TurnContext
from botbuilder.schema import Activity # type: ignore[import-untyped]
from botbuilder.schema import ActivityTypes
from botbuilder.schema import Attachment
from botbuilder.schema import ChannelAccount
from onyx.onyxbot.api_client import OnyxAPIClient
from onyx.onyxbot.constants import CACHE_REFRESH_INTERVAL
from onyx.onyxbot.exceptions import RegistrationError
from onyx.onyxbot.teams.cache import TeamsCacheManager
from onyx.onyxbot.teams.handle_commands import handle_registration_command
from onyx.onyxbot.teams.handle_commands import is_registration_command
from onyx.onyxbot.teams.handle_message import process_chat_message
from onyx.onyxbot.teams.handle_message import should_respond
from onyx.onyxbot.teams.utils import extract_channel_id
from onyx.onyxbot.teams.utils import extract_team_id
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxTeamsBot(ActivityHandler):
    """Activity handler for Teams bot.

    Handles incoming message activities, member additions, and routes
    messages to the appropriate handler (registration, chat).
    """

    def __init__(self) -> None:
        self.cache = TeamsCacheManager()
        self.api_client = OnyxAPIClient(origin=MessageOrigin.TEAMSBOT)
        # Background refresh task; created in initialize(), cancelled in shutdown().
        self._cache_refresh_task: asyncio.Task[None] | None = None
        # Bot identity is lazily captured from the first inbound activity's
        # `recipient` field (see on_message_activity).
        self._bot_id: str | None = None
        self._bot_name: str = "Onyx"

    async def initialize(self) -> None:
        """Initialize the bot: API client, cache, and background tasks."""
        await self.api_client.initialize()
        await self.cache.refresh_all()
        self._cache_refresh_task = asyncio.create_task(self._periodic_cache_refresh())
        logger.info("Teams bot initialized")

    async def shutdown(self) -> None:
        """Gracefully shut down the bot: stop refresh task, close client, clear cache."""
        if self._cache_refresh_task:
            self._cache_refresh_task.cancel()
            try:
                # Await so the CancelledError is consumed and the task is reaped.
                await self._cache_refresh_task
            except asyncio.CancelledError:
                pass
        await self.api_client.close()
        self.cache.clear()
        logger.info("Teams bot shut down")

    async def _periodic_cache_refresh(self) -> None:
        """Background task: refresh the cache every CACHE_REFRESH_INTERVAL seconds.

        Failures are logged and swallowed so one bad cycle does not kill the loop.
        """
        while True:
            await asyncio.sleep(CACHE_REFRESH_INTERVAL)
            try:
                await self.cache.refresh_all()
            except Exception as e:
                logger.error(f"Periodic cache refresh failed: {e}")

    async def on_message_activity(self, turn_context: TurnContext) -> None:
        """Handle incoming message activities.

        Flow: capture bot identity -> registration command? -> resolve tenant
        from team cache -> should_respond checks -> call Onyx API -> reply
        with an Adaptive Card. DMs are ignored (no team context = no tenant).
        """
        activity = turn_context.activity
        if not activity.text:
            return
        # Capture bot identity on first message
        if not self._bot_id and activity.recipient:
            self._bot_id = activity.recipient.id
            self._bot_name = activity.recipient.name or "Onyx"
        # as_dict() availability is probed defensively; helpers below work
        # on the plain-dict representation of the activity.
        activity_dict = activity.as_dict() if hasattr(activity, "as_dict") else {}
        team_id = extract_team_id(activity_dict)
        channel_id = extract_channel_id(activity_dict)
        # Check for registration command
        if is_registration_command(activity.text, self._bot_name):
            await self._handle_registration(turn_context, activity_dict)
            return
        # Resolve tenant from team cache
        tenant_id: str | None = None
        if team_id:
            tenant_id = self.cache.get_tenant(team_id)
            if not tenant_id:
                logger.debug(f"No tenant found for team {team_id}")
                return
        else:
            # DM — not in a team context, so we can't determine tenant.
            # TODO(nik): support DM registration or default tenant lookup
            logger.debug("Ignoring DM (no team context to resolve tenant)")
            return
        # Check if bot should respond (DB lookups — run off the event loop)
        context = await asyncio.to_thread(
            should_respond,
            activity_dict,
            team_id,
            channel_id,
            tenant_id,
            self._bot_id or "",
        )
        if not context.should_respond:
            return
        api_key = self.cache.get_api_key(tenant_id)
        if not api_key:
            logger.warning(f"No API key for tenant {tenant_id}")
            return
        # Send typing indicator
        await turn_context.send_activity(Activity(type=ActivityTypes.typing))
        # Process message and send response
        card = await process_chat_message(
            text=activity.text,
            api_key=api_key,
            persona_id=context.persona_id,
            api_client=self.api_client,
            bot_name=self._bot_name,
        )
        # Send as Adaptive Card
        attachment = Attachment(
            content_type="application/vnd.microsoft.card.adaptive",
            content=card,
        )
        response = Activity(
            type=ActivityTypes.message,
            attachments=[attachment],
        )
        await turn_context.send_activity(response)

    async def _handle_registration(
        self,
        turn_context: TurnContext,
        activity_dict: dict,
    ) -> None:
        """Handle registration command.

        RegistrationError carries a user-safe message and is echoed back;
        any other exception is logged and replaced with a generic reply.
        """
        try:
            result = await handle_registration_command(
                text=turn_context.activity.text or "",
                activity_dict=activity_dict,
                bot_name=self._bot_name,
                cache=self.cache,
            )
            await turn_context.send_activity(result)
        except RegistrationError as e:
            await turn_context.send_activity(f"Registration failed: {e}")
        except Exception as e:
            logger.exception(f"Registration error: {e}")
            await turn_context.send_activity(
                "An unexpected error occurred during registration."
            )

    async def on_members_added_activity(
        self,
        members_added: list[ChannelAccount],
        turn_context: TurnContext,
    ) -> None:
        """Handle when the bot is added to a team or conversation."""
        for member in members_added:
            # Only send welcome when the bot itself is added
            if member.id == turn_context.activity.recipient.id:
                # Imported lazily; presumably to avoid a module-level import
                # cycle with cards — TODO confirm.
                from onyx.onyxbot.teams.cards import build_welcome_card

                attachment = Attachment(
                    content_type="application/vnd.microsoft.card.adaptive",
                    content=build_welcome_card(),
                )
                response = Activity(
                    type=ActivityTypes.message,
                    attachments=[attachment],
                )
                await turn_context.send_activity(response)

View File

@@ -0,0 +1,35 @@
"""Multi-tenant cache for Teams bot team-tenant mappings and API keys."""
from sqlalchemy.orm import Session
from onyx.db.teams_bot import get_team_configs
from onyx.db.teams_bot import provision_teams_service_api_key
from onyx.onyxbot.cache import BotCacheManager
class TeamsCacheManager(BotCacheManager[str]):
    """Caches team->tenant mappings and tenant->API key mappings.

    Specializes the shared BotCacheManager with Teams-specific DB lookups;
    entity ids are Teams team ids (strings).
    """

    def __init__(self) -> None:
        # entity_name only affects log messages in the base class.
        super().__init__(entity_name="teams")

    def _get_entity_ids(self, db: Session) -> list[str]:
        """Return team IDs for all enabled team configs with a non-null ID."""
        configs = get_team_configs(db)
        return [
            config.team_id
            for config in configs
            if config.enabled and config.team_id is not None
        ]

    def _get_or_create_api_key(self, db: Session, tenant_id: str) -> str:
        # NOTE(review): per the helper's name, "provision" may regenerate
        # the key rather than return an existing one — confirm idempotence
        # expectations against the db-layer implementation.
        return provision_teams_service_api_key(db, tenant_id)

    # Convenience aliases for caller clarity
    async def refresh_team(self, team_id: str, tenant_id: str) -> None:
        """Alias for refresh_entity using Teams terminology."""
        await self.refresh_entity(team_id, tenant_id)

    def remove_team(self, team_id: str) -> None:
        """Alias for remove_entity using Teams terminology."""
        self.remove_entity(team_id)

    def get_all_team_ids(self) -> list[str]:
        """Alias for get_all_entity_ids using Teams terminology."""
        return self.get_all_entity_ids()

View File

@@ -0,0 +1,136 @@
"""Adaptive Card builders for Teams bot responses."""
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.teams.constants import ADAPTIVE_CARD_SCHEMA
from onyx.onyxbot.teams.constants import ADAPTIVE_CARD_VERSION
from onyx.onyxbot.teams.constants import MAX_CITATIONS
def build_answer_card(
    answer: str,
    response: ChatFullResponse | None = None,
) -> dict:
    """Build an Adaptive Card for a chat answer with optional citations.

    Targets Adaptive Card schema version 1.3 for mobile compatibility.
    """
    citations = _extract_citations(response) if response else []

    body: list[dict] = [{"type": "TextBlock", "text": answer, "wrap": True}]
    if citations:
        body.append(
            {
                "type": "TextBlock",
                "text": "**Sources:**",
                "wrap": True,
                "spacing": "Medium",
            }
        )
        for num, name, link in citations:
            # Render a markdown link when we have a URL, plain text otherwise.
            label = f"{num}. [{name}]({link})" if link else f"{num}. {name}"
            body.append(
                {
                    "type": "TextBlock",
                    "text": label,
                    "wrap": True,
                    "spacing": "None",
                }
            )

    return {
        "$schema": ADAPTIVE_CARD_SCHEMA,
        "type": "AdaptiveCard",
        "version": ADAPTIVE_CARD_VERSION,
        "body": body,
    }
def build_error_card(message: str) -> dict:
    """Build an Adaptive Card surfacing an error message to the user."""
    # "Attention" colors the text per the host theme (typically red).
    error_block = {
        "type": "TextBlock",
        "text": message,
        "wrap": True,
        "color": "Attention",
    }
    return {
        "$schema": ADAPTIVE_CARD_SCHEMA,
        "type": "AdaptiveCard",
        "version": ADAPTIVE_CARD_VERSION,
        "body": [error_block],
    }
def build_welcome_card() -> dict:
    """Build an Adaptive Card for the welcome message when bot is added.

    Returns a static card (no parameters) explaining how an admin can
    register the team via the ``register`` command.
    """
    return {
        "$schema": ADAPTIVE_CARD_SCHEMA,
        "type": "AdaptiveCard",
        "version": ADAPTIVE_CARD_VERSION,
        "body": [
            {
                "type": "TextBlock",
                "text": "Welcome to Onyx!",
                "weight": "Bolder",
                "size": "Medium",
            },
            {
                "type": "TextBlock",
                "text": (
                    "I'm the Onyx bot. I can help you search your company's knowledge base "
                    "and answer questions.\n\n"
                    "To get started, an admin needs to register this team. "
                    "Send me a direct message with:\n\n"
                    "`@Onyx register <registration_key>`"
                ),
                "wrap": True,
            },
        ],
    }
def _extract_citations(
    response: ChatFullResponse,
) -> list[tuple[int, str, str | None]]:
    """Extract citation information from a chat response.

    Args:
        response: Chat response carrying ``citation_info`` and ``top_documents``.

    Returns:
        At most MAX_CITATIONS ``(citation_number, display_name, link)`` tuples,
        sorted by citation number. Citations whose document is missing from
        ``top_documents`` are skipped; documents without a semantic identifier
        fall back to the label "Source".
    """
    if not response.citation_info or not response.top_documents:
        return []
    # Index documents once instead of scanning top_documents per citation
    # (was O(citations * documents)). setdefault keeps the FIRST document
    # per id, matching the original next(...)-based first-match semantics.
    docs_by_id: dict = {}
    for doc in response.top_documents:
        docs_by_id.setdefault(doc.document_id, doc)
    cited_docs: list[tuple[int, str, str | None]] = []
    for citation in response.citation_info:
        doc = docs_by_id.get(citation.document_id)
        if doc:
            cited_docs.append(
                (
                    citation.citation_number,
                    doc.semantic_identifier or "Source",
                    doc.link,
                )
            )
    cited_docs.sort(key=lambda x: x[0])
    return cited_docs[:MAX_CITATIONS]

View File

@@ -0,0 +1,14 @@
"""Teams-specific bot constants.
Shared constants (API_REQUEST_TIMEOUT, CACHE_REFRESH_INTERVAL,
REGISTER_COMMAND) live in ``onyx.onyxbot.constants``.
"""
# Bot Framework settings
BOT_MESSAGES_ENDPOINT: str = "/api/messages"
BOT_HEALTH_ENDPOINT: str = "/health"
# Adaptive Card settings
ADAPTIVE_CARD_SCHEMA: str = "http://adaptivecards.io/schemas/adaptive-card.json"
ADAPTIVE_CARD_VERSION: str = "1.3" # Compatible with mobile clients
MAX_CITATIONS: int = 5

View File

@@ -0,0 +1,81 @@
"""Teams bot command handlers (e.g., registration)."""
import asyncio
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.teams_bot import get_team_config_by_registration_key
from onyx.db.teams_bot import register_team
from onyx.onyxbot.constants import REGISTER_COMMAND
from onyx.onyxbot.exceptions import RegistrationError
from onyx.onyxbot.teams.cache import TeamsCacheManager
from onyx.onyxbot.teams.utils import extract_team_id
from onyx.onyxbot.teams.utils import extract_team_name
from onyx.onyxbot.teams.utils import strip_bot_mention
from onyx.server.manage.teams_bot.utils import parse_teams_registration_key
from onyx.utils.logger import setup_logger
logger = setup_logger()
async def handle_registration_command(
    text: str,
    activity_dict: dict,
    bot_name: str,
    cache: TeamsCacheManager,
) -> str:
    """Handle the 'register <key>' command.

    Args:
        text: Raw activity text (may include the bot @mention).
        activity_dict: Plain-dict form of the activity, used to extract
            the team id/name.
        bot_name: Display name used to strip the @mention.
        cache: Cache to update once registration succeeds.

    Returns:
        A human-readable confirmation message.

    Raises:
        RegistrationError: On malformed command/key, DM context, or a key
            that is unknown or already used.
    """
    clean_text = strip_bot_mention(text, bot_name).strip()
    # Parse "register <key>"
    parts = clean_text.split(None, 1)
    if len(parts) != 2 or parts[0].lower() != REGISTER_COMMAND:
        raise RegistrationError(
            f"Invalid registration command. Usage: @{bot_name} register <registration_key>"
        )
    registration_key = parts[1].strip()
    # Parse tenant_id from registration key
    tenant_id = parse_teams_registration_key(registration_key)
    if not tenant_id:
        raise RegistrationError("Invalid registration key format.")
    team_id = extract_team_id(activity_dict)
    team_name = extract_team_name(activity_dict) or "Unknown Team"
    if not team_id:
        raise RegistrationError(
            "Registration must be done from a Teams channel, not a DM."
        )

    def _register() -> str:
        # Blocking DB work; executed via asyncio.to_thread below.
        with get_session_with_tenant(tenant_id=tenant_id) as db:
            # Lock the row to prevent concurrent registration with the same key
            config = get_team_config_by_registration_key(
                db, registration_key, for_update=True
            )
            if not config:
                raise RegistrationError("Registration key not found or already used.")
            if config.team_id is not None:
                raise RegistrationError("This registration key has already been used.")
            register_team(db, config, team_id, team_name)
            db.commit()
            return tenant_id

    registered_tenant_id = await asyncio.to_thread(_register)
    # Make the new team resolvable immediately, without waiting for the
    # periodic refresh cycle.
    await cache.refresh_team(team_id, registered_tenant_id)
    logger.info(f"Team {team_id} ({team_name}) registered for tenant {tenant_id}")
    return f"Team **{team_name}** has been registered with Onyx. You can now configure channels in the admin panel."
def is_registration_command(text: str, bot_name: str) -> bool:
    """Return True when the message's first word is the register command."""
    tokens = strip_bot_mention(text, bot_name).strip().split(None, 1)
    return bool(tokens) and tokens[0].lower() == REGISTER_COMMAND

View File

@@ -0,0 +1,111 @@
"""Teams bot message handling and response logic."""
from dataclasses import dataclass
from dataclasses import field
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import TeamsChannelConfig
from onyx.db.models import TeamsTeamConfig
from onyx.db.teams_bot import get_channel_config_by_teams_ids
from onyx.db.teams_bot import get_team_config_by_teams_id
from onyx.onyxbot.api_client import OnyxAPIClient
from onyx.onyxbot.exceptions import APIError
from onyx.onyxbot.teams.cards import build_answer_card
from onyx.onyxbot.teams.cards import build_error_card
from onyx.onyxbot.teams.utils import is_bot_mentioned
from onyx.onyxbot.teams.utils import strip_bot_mention
from onyx.utils.logger import setup_logger
logger = setup_logger()
@dataclass
class ShouldRespondContext:
"""Context for whether the bot should respond to a message."""
should_respond: bool
persona_id: int | None
tenant_id: str | None = field(default=None)
api_key: str | None = field(default=None)
def should_respond(
    activity_dict: dict,
    team_id: str | None,
    channel_id: str | None,
    tenant_id: str,
    bot_id: str,
) -> ShouldRespondContext:
    """Determine if bot should respond and which persona to use.

    This is a synchronous function that performs DB lookups; callers run
    it via ``asyncio.to_thread``.

    Returns:
        A context with should_respond=False when the team or channel is
        unknown/disabled, or when a required @mention is absent; otherwise
        should_respond=True plus the resolved persona id (or None).
    """
    no_response = ShouldRespondContext(should_respond=False, persona_id=None)
    if not team_id or not channel_id:
        # DM or group chat — respond if we have a tenant
        return ShouldRespondContext(should_respond=True, persona_id=None)
    with get_session_with_tenant(tenant_id=tenant_id) as db:
        team_config: TeamsTeamConfig | None = get_team_config_by_teams_id(db, team_id)
        if not team_config or not team_config.enabled:
            return no_response
        channel_config: TeamsChannelConfig | None = get_channel_config_by_teams_ids(
            db, team_id, channel_id
        )
        if not channel_config or not channel_config.enabled:
            return no_response
        # Determine persona (channel override or team default)
        # NOTE(review): `or` treats a falsy override id (0) as unset —
        # confirm persona ids can never be 0.
        persona_id = (
            channel_config.persona_override_id or team_config.default_persona_id
        )
        # Check mention requirement
        if channel_config.require_bot_mention:
            if not is_bot_mentioned(activity_dict, bot_id):
                return no_response
        return ShouldRespondContext(should_respond=True, persona_id=persona_id)
async def process_chat_message(
    text: str,
    api_key: str,
    persona_id: int | None,
    api_client: OnyxAPIClient,
    bot_name: str,
) -> dict:
    """Process a message and return an Adaptive Card response.

    Never raises: API failures and unexpected errors are logged and
    converted into an error card so the user always gets a reply.

    Returns:
        Adaptive Card dict for the response (answer card or error card).
    """
    try:
        # Strip bot mention from the message
        clean_text = strip_bot_mention(text, bot_name)
        if not clean_text:
            return build_error_card("Please include a message after the @mention.")
        # Send to Onyx API
        response = await api_client.send_chat_message(
            message=clean_text,
            api_key=api_key,
            persona_id=persona_id,
        )
        answer = response.answer or "I couldn't generate a response."
        return build_answer_card(answer, response)
    except APIError as e:
        logger.error(f"API error processing Teams message: {e}")
        return build_error_card(
            "Sorry, I encountered an error processing your message. "
            "Please try again later."
        )
    except Exception as e:
        # Deliberate catch-all: a bot reply path must not propagate.
        logger.exception(f"Error processing Teams chat message: {e}")
        return build_error_card(
            "Sorry, an unexpected error occurred. Please try again later."
        )

View File

@@ -0,0 +1,155 @@
"""HTTP server for Teams bot using aiohttp + Bot Framework adapter."""
import sys
from aiohttp import web
from botbuilder.core import BotFrameworkAdapter # type: ignore[import-untyped]
from botbuilder.core import BotFrameworkAdapterSettings
from botbuilder.core import TurnContext
from botbuilder.schema import Activity # type: ignore[import-untyped]
from onyx.configs.app_configs import TEAMS_BOT_PORT
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.teams_bot import get_teams_bot_config
from onyx.onyxbot.teams.bot import OnyxTeamsBot
from onyx.onyxbot.teams.constants import BOT_HEALTH_ENDPOINT
from onyx.onyxbot.teams.constants import BOT_MESSAGES_ENDPOINT
from onyx.onyxbot.teams.utils import get_bot_credentials_from_env
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_credentials() -> tuple[str, str, str | None] | None:
    """Get bot credentials from env vars or database.

    Env vars take priority. Falls back to DB config for self-hosted
    deployments that configure via admin UI.

    Returns:
        (app_id, app_secret, azure_tenant_id) or None when unconfigured.
    """
    env_creds = get_bot_credentials_from_env()
    if env_creds:
        return env_creds
    # Try database (for self-hosted deployments)
    try:
        from onyx.db.engine.tenant_utils import get_all_tenant_ids

        tenant_ids = get_all_tenant_ids()
        # First tenant with a config wins — presumably fine for the
        # single-tenant self-hosted case this fallback targets.
        for tenant_id in tenant_ids:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                config = get_teams_bot_config(db)
                if config:
                    # Access the decrypted value
                    app_secret = config.app_secret
                    if isinstance(app_secret, str):
                        return config.app_id, app_secret, config.azure_tenant_id
    except Exception as e:
        # Best-effort fallback: a DB failure here should not crash startup,
        # the caller reports missing credentials instead.
        logger.warning(f"Failed to load Teams bot config from DB: {e}")
    return None
async def _handle_messages(request: web.Request) -> web.Response:
    """Handle incoming Bot Framework Activities at POST /api/messages.

    Deserializes the Activity, runs it through the adapter (which validates
    the Authorization header), and forwards any InvokeResponse so invoke
    activities get a meaningful HTTP reply.
    """
    bot: OnyxTeamsBot = request.app["bot"]
    adapter: BotFrameworkAdapter = request.app["adapter"]
    # Bot Framework always posts JSON; reject anything else early.
    if request.content_type != "application/json":
        return web.Response(status=415, text="Unsupported media type")
    body = await request.json()
    activity = Activity().deserialize(body)
    auth_header = request.headers.get("Authorization", "")

    async def _turn_callback(turn_context: TurnContext) -> None:
        await bot.on_turn(turn_context)

    try:
        invoke_response = await adapter.process_activity(
            activity, auth_header, _turn_callback
        )
        # For invoke activities (messaging extensions, task modules),
        # process_activity returns an InvokeResponse with status/body
        # that must be forwarded to the Bot Framework.
        # NOTE(review): InvokeResponse.body may be a dict — confirm aiohttp
        # accepts it as `body`, or json-serialize before responding.
        if invoke_response:
            return web.Response(
                status=invoke_response.status,
                body=invoke_response.body,
                content_type="application/json",
            )
        return web.Response(status=200)
    except Exception as e:
        # Catch-all so the HTTP handler always returns a response.
        logger.exception(f"Error processing activity: {e}")
        return web.Response(status=500, text="Internal server error")
async def _handle_health(request: web.Request) -> web.Response:
    """Readiness probe: 200 once both the API client and cache report
    initialized, 503 otherwise."""
    bot: OnyxTeamsBot = request.app["bot"]
    if not (bot.api_client.is_initialized and bot.cache.is_initialized):
        return web.Response(status=503, text="Not ready")
    return web.Response(status=200, text="OK")
async def _on_startup(app: web.Application) -> None:
    """aiohttp startup hook: bring the Teams bot online before serving."""
    teams_bot: OnyxTeamsBot = app["bot"]
    await teams_bot.initialize()
    logger.info("Teams bot server started")
async def _on_shutdown(app: web.Application) -> None:
    """aiohttp shutdown hook: tear the Teams bot down cleanly."""
    teams_bot: OnyxTeamsBot = app["bot"]
    await teams_bot.shutdown()
    logger.info("Teams bot server stopped")
def create_app(
    app_id: str,
    app_secret: str,
) -> web.Application:
    """Build the aiohttp application hosting the Teams bot.

    Wires a BotFrameworkAdapter (authenticated with the given app
    credentials) and an OnyxTeamsBot into the app state, registers the
    message and health routes, and hooks bot lifecycle into server
    startup/shutdown.
    """
    adapter = BotFrameworkAdapter(
        BotFrameworkAdapterSettings(app_id=app_id, app_password=app_secret)
    )

    app = web.Application()
    app["bot"] = OnyxTeamsBot()
    app["adapter"] = adapter

    app.router.add_post(BOT_MESSAGES_ENDPOINT, _handle_messages)
    app.router.add_get(BOT_HEALTH_ENDPOINT, _handle_health)

    app.on_startup.append(_on_startup)
    app.on_shutdown.append(_on_shutdown)
    return app
def main() -> None:
    """Entry point for the Teams bot process.

    Exits with status 1 if no credentials can be resolved from either the
    environment or the database.
    """
    logger.info("Starting Teams bot...")
    creds = _get_credentials()
    if creds is None:
        logger.error(
            "Teams bot credentials not configured. "
            "Set TEAMS_BOT_APP_ID and TEAMS_BOT_APP_SECRET environment variables, "
            "or configure via the admin panel."
        )
        sys.exit(1)
    app_id, app_secret, _azure_tenant_id = creds
    logger.info(f"Teams bot starting with App ID: {app_id}")
    web.run_app(create_app(app_id, app_secret), host="0.0.0.0", port=TEAMS_BOT_PORT)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,83 @@
"""Utility functions for Teams bot."""
from onyx.configs.app_configs import TEAMS_BOT_APP_ID
from onyx.configs.app_configs import TEAMS_BOT_APP_SECRET
from onyx.configs.app_configs import TEAMS_BOT_AZURE_TENANT_ID
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_bot_credentials_from_env() -> tuple[str, str, str | None] | None:
    """Read bot credentials from environment-derived config values.

    Returns:
        (app_id, app_secret, azure_tenant_id), or None when either the
        app ID or the app secret is unset.
    """
    if TEAMS_BOT_APP_ID and TEAMS_BOT_APP_SECRET:
        return TEAMS_BOT_APP_ID, TEAMS_BOT_APP_SECRET, TEAMS_BOT_AZURE_TENANT_ID
    return None
def extract_team_id(activity: dict) -> str | None:
"""Extract the Teams team ID from an Activity's channelData.
Teams Activities include channelData.team.id for messages in team channels.
For 1:1 or group chats, this will be None.
"""
channel_data = activity.get("channelData", {})
team = channel_data.get("team")
if team:
return team.get("id")
return None
def extract_channel_id(activity: dict) -> str | None:
"""Extract the Teams channel ID from an Activity's channelData."""
channel_data = activity.get("channelData", {})
channel = channel_data.get("channel")
if channel:
return channel.get("id")
return None
def extract_team_name(activity: dict) -> str | None:
"""Extract the Teams team name from an Activity's channelData."""
channel_data = activity.get("channelData", {})
team = channel_data.get("team")
if team:
return team.get("name")
return None
def strip_bot_mention(text: str, bot_name: str) -> str:
    """Strip the bot's ``<at>...</at>`` @mention markup from message text.

    Teams embeds the mention in the text as ``<at>BotName</at>``. The
    named form is removed first (case-insensitively), then any remaining
    generic mention tags that some clients send.
    """
    import re

    without_named = re.sub(
        rf"<at>{re.escape(bot_name)}</at>",
        "",
        text,
        flags=re.IGNORECASE,
    )
    without_any = re.sub(r"<at>[^<]*</at>", "", without_named)
    return without_any.strip()
def is_bot_mentioned(activity: dict, bot_id: str) -> bool:
    """Return True if the activity's mention entities include bot_id.

    Teams reports mentions as entries of type "mention" in the activity's
    entities array.
    """
    return any(
        entity.get("type") == "mention"
        and entity.get("mentioned", {}).get("id") == bot_id
        for entity in activity.get("entities", [])
    )

View File

@@ -7,14 +7,13 @@ from PIL import UnidentifiedImageError
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.password_validation import is_file_password_protected
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -117,9 +116,7 @@ def estimate_image_tokens_for_upload(
pass
def categorize_uploaded_files(
files: list[UploadFile], db_session: Session
) -> CategorizedFiles:
def categorize_uploaded_files(files: list[UploadFile]) -> CategorizedFiles:
"""
Categorize uploaded files based on text extractability and tokenized length.
@@ -131,11 +128,11 @@ def categorize_uploaded_files(
"""
results = CategorizedFiles()
default_model = fetch_default_llm_model(db_session)
llm = get_default_llm()
model_name = default_model.name if default_model else None
provider_type = default_model.llm_provider.provider if default_model else None
tokenizer = get_tokenizer(model_name=model_name, provider_type=provider_type)
tokenizer = get_tokenizer(
model_name=llm.config.model_name, provider_type=llm.config.model_provider
)
# Check if threshold checks should be skipped
skip_threshold = False

View File

@@ -1,46 +1,16 @@
"""Discord registration key generation and parsing."""
import secrets
from urllib.parse import quote
from urllib.parse import unquote
from onyx.onyxbot.registration import generate_registration_key
from onyx.onyxbot.registration import parse_registration_key
from onyx.utils.logger import setup_logger
logger = setup_logger()
REGISTRATION_KEY_PREFIX: str = "discord_"
REGISTRATION_KEY_PREFIX: str = "discord"
def generate_discord_registration_key(tenant_id: str) -> str:
"""Generate a one-time registration key with embedded tenant_id.
Format: discord_<url_encoded_tenant_id>.<random_token>
Follows the same pattern as API keys for consistency.
"""
encoded_tenant = quote(tenant_id)
random_token = secrets.token_urlsafe(16)
logger.info(f"Generated Discord registration key for tenant {tenant_id}")
return f"{REGISTRATION_KEY_PREFIX}{encoded_tenant}.{random_token}"
"""Generate a one-time registration key with embedded tenant_id."""
return generate_registration_key(REGISTRATION_KEY_PREFIX, tenant_id)
def parse_discord_registration_key(key: str) -> str | None:
"""Parse registration key to extract tenant_id.
Returns tenant_id or None if invalid format.
"""
if not key.startswith(REGISTRATION_KEY_PREFIX):
return None
try:
key_body = key.removeprefix(REGISTRATION_KEY_PREFIX)
parts = key_body.split(".", 1)
if len(parts) != 2:
return None
encoded_tenant = parts[0]
tenant_id = unquote(encoded_tenant)
return tenant_id
except Exception:
return None
"""Parse registration key to extract tenant_id."""
return parse_registration_key(REGISTRATION_KEY_PREFIX, key)

View File

@@ -0,0 +1,279 @@
"""Teams bot admin API endpoints."""
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import TEAMS_BOT_APP_ID
from onyx.configs.constants import AuthType
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.teams_bot import create_team_config
from onyx.db.teams_bot import create_teams_bot_config
from onyx.db.teams_bot import delete_team_config
from onyx.db.teams_bot import delete_teams_bot_config
from onyx.db.teams_bot import delete_teams_service_api_key
from onyx.db.teams_bot import get_channel_config_by_internal_ids
from onyx.db.teams_bot import get_channel_configs
from onyx.db.teams_bot import get_team_config_by_internal_id
from onyx.db.teams_bot import get_team_configs
from onyx.db.teams_bot import get_teams_bot_config
from onyx.db.teams_bot import update_team_config
from onyx.db.teams_bot import update_teams_channel_config
from onyx.server.manage.teams_bot.models import TeamsBotConfigCreateRequest
from onyx.server.manage.teams_bot.models import TeamsBotConfigResponse
from onyx.server.manage.teams_bot.models import TeamsChannelConfigResponse
from onyx.server.manage.teams_bot.models import TeamsChannelConfigUpdateRequest
from onyx.server.manage.teams_bot.models import TeamsTeamConfigCreateResponse
from onyx.server.manage.teams_bot.models import TeamsTeamConfigResponse
from onyx.server.manage.teams_bot.models import TeamsTeamConfigUpdateRequest
from onyx.server.manage.teams_bot.utils import generate_teams_registration_key
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter(prefix="/manage/admin/teams-bot")
def _check_bot_config_api_access() -> None:
    """Raise 403 if bot config cannot be managed via this API.

    Bot config endpoints are disabled:
    - On Cloud (the bot registration is managed by Onyx)
    - When the TEAMS_BOT_APP_ID env var is set (managed via environment)
    """
    if AUTH_TYPE == AuthType.CLOUD:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Teams bot configuration is managed by Onyx on Cloud.",
        )
    if TEAMS_BOT_APP_ID:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Teams bot is configured via environment variables. API access disabled.",
        )
# === Bot Config ===
@router.get("/config")
def get_bot_config(
    _: None = Depends(_check_bot_config_api_access),
    __: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TeamsBotConfigResponse:
    """Get Teams bot config. Returns 403 on Cloud or if env vars set."""
    config = get_teams_bot_config(db_session)
    if config is None:
        return TeamsBotConfigResponse(configured=False)
    return TeamsBotConfigResponse(configured=True, created_at=config.created_at)
@router.post("/config")
def create_bot_request(
    request: TeamsBotConfigCreateRequest,
    _: None = Depends(_check_bot_config_api_access),
    __: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TeamsBotConfigResponse:
    """Create Teams bot config. Returns 403 on Cloud or if env vars set.

    Raises:
        HTTPException 409 if a config already exists (singleton row).
    """
    try:
        config = create_teams_bot_config(
            db_session,
            app_id=request.app_id,
            app_secret=request.app_secret,
            azure_tenant_id=request.azure_tenant_id,
        )
    except ValueError as e:
        # Chain the original error so tracebacks show the real cause.
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Teams bot config already exists. Delete it first to create a new one.",
        ) from e
    db_session.commit()
    return TeamsBotConfigResponse(
        configured=True,
        created_at=config.created_at,
    )
@router.delete("/config")
def delete_bot_config_endpoint(
    _: None = Depends(_check_bot_config_api_access),
    __: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict:
    """Delete Teams bot config.

    Also deletes the Teams service API key since the bot is being removed.
    """
    if not delete_teams_bot_config(db_session):
        raise HTTPException(status_code=404, detail="Bot config not found")
    # The service key only exists to serve the bot — remove it too.
    delete_teams_service_api_key(db_session)
    db_session.commit()
    return {"deleted": True}
# === Service API Key ===
@router.delete("/service-api-key")
def delete_service_api_key_endpoint(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict:
    """Delete the Teams service API key, 404 if none exists."""
    if not delete_teams_service_api_key(db_session):
        raise HTTPException(status_code=404, detail="Service API key not found")
    db_session.commit()
    return {"deleted": True}
# === Team Config ===
@router.get("/teams")
def list_team_configs(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[TeamsTeamConfigResponse]:
    """List all team configs (pending and registered)."""
    return [
        TeamsTeamConfigResponse.model_validate(config)
        for config in get_team_configs(db_session)
    ]
@router.post("/teams")
def create_team_request(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TeamsTeamConfigCreateResponse:
    """Create a new team config with a registration key.

    The raw key is returned exactly once in this response; only a
    registered team record remains afterwards.
    """
    registration_key = generate_teams_registration_key(get_current_tenant_id())
    config = create_team_config(db_session, registration_key)
    db_session.commit()
    return TeamsTeamConfigCreateResponse(
        id=config.id,
        registration_key=registration_key,
    )
@router.get("/teams/{config_id}")
def get_team_config(
    config_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TeamsTeamConfigResponse:
    """Get a specific team config by its internal ID, 404 if missing."""
    config = get_team_config_by_internal_id(db_session, internal_id=config_id)
    if config is None:
        raise HTTPException(status_code=404, detail="Team config not found")
    return TeamsTeamConfigResponse.model_validate(config)
@router.patch("/teams/{config_id}")
def update_team_request(
    config_id: int,
    request: TeamsTeamConfigUpdateRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TeamsTeamConfigResponse:
    """Update a team config's enabled flag and default persona."""
    existing = get_team_config_by_internal_id(db_session, internal_id=config_id)
    if existing is None:
        raise HTTPException(status_code=404, detail="Team config not found")
    updated = update_team_config(
        db_session,
        existing,
        enabled=request.enabled,
        default_persona_id=request.default_persona_id,
    )
    db_session.commit()
    return TeamsTeamConfigResponse.model_validate(updated)
@router.delete("/teams/{config_id}")
def delete_team_request(
    config_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> dict:
    """Delete a team config (invalidates its registration key).

    On Cloud, if this was the last team config, also deletes the service
    API key.
    """
    if not delete_team_config(db_session, config_id):
        raise HTTPException(status_code=404, detail="Team config not found")
    if AUTH_TYPE == AuthType.CLOUD and not get_team_configs(db_session):
        delete_teams_service_api_key(db_session)
    db_session.commit()
    return {"deleted": True}
# === Channel Config ===
@router.get("/teams/{config_id}/channels")
def list_channel_configs(
    config_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[TeamsChannelConfigResponse]:
    """List whitelisted channels for a team.

    404 if the team config doesn't exist; 400 if the team has not yet
    completed registration (no external team_id bound).
    """
    team_config = get_team_config_by_internal_id(db_session, internal_id=config_id)
    if team_config is None:
        raise HTTPException(status_code=404, detail="Team config not found")
    if not team_config.team_id:
        raise HTTPException(status_code=400, detail="Team not yet registered")
    return [
        TeamsChannelConfigResponse.model_validate(config)
        for config in get_channel_configs(db_session, config_id)
    ]
@router.patch("/teams/{team_config_id}/channels/{channel_config_id}")
def update_channel_request(
    team_config_id: int,
    channel_config_id: int,
    request: TeamsChannelConfigUpdateRequest,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> TeamsChannelConfigResponse:
    """Update mention/persona/enabled settings on a whitelisted channel."""
    existing = get_channel_config_by_internal_ids(
        db_session, team_config_id, channel_config_id
    )
    if existing is None:
        raise HTTPException(status_code=404, detail="Channel config not found")
    updated = update_teams_channel_config(
        db_session,
        existing,
        channel_name=existing.channel_name,  # Keep existing name
        require_bot_mention=request.require_bot_mention,
        persona_override_id=request.persona_override_id,
        enabled=request.enabled,
    )
    db_session.commit()
    return TeamsChannelConfigResponse.model_validate(updated)

View File

@@ -0,0 +1,69 @@
"""Pydantic models for Teams bot API."""
from datetime import datetime

from pydantic import BaseModel
from pydantic import ConfigDict
# === Bot Config ===
class TeamsBotConfigResponse(BaseModel):
    """Whether a Teams bot config exists; secrets are never echoed back."""

    configured: bool
    created_at: datetime | None = None

    # Pydantic v2 idiom — the nested `class Config` form is deprecated.
    model_config = ConfigDict(from_attributes=True)
class TeamsBotConfigCreateRequest(BaseModel):
    """Request body for registering the Teams bot's app credentials."""

    # Azure app registration credentials for the bot.
    app_id: str
    app_secret: str
    # Optional Azure tenant ID — presumably restricts auth to a single
    # tenant; confirm against the bot registration docs.
    azure_tenant_id: str | None = None
# === Team Config ===
class TeamsTeamConfigResponse(BaseModel):
    """A team config row: None team_id/registered_at means still pending."""

    id: int
    team_id: str | None
    team_name: str | None
    registered_at: datetime | None
    default_persona_id: int | None
    enabled: bool

    # Pydantic v2 idiom — the nested `class Config` form is deprecated.
    model_config = ConfigDict(from_attributes=True)
class TeamsTeamConfigCreateResponse(BaseModel):
    """Response for team-config creation; the raw key appears only here."""

    id: int
    registration_key: str  # Shown once!
class TeamsTeamConfigUpdateRequest(BaseModel):
    """PATCH body for a team config."""

    enabled: bool
    default_persona_id: int | None
# === Channel Config ===
class TeamsChannelConfigResponse(BaseModel):
    """A whitelisted channel's settings within a team config."""

    id: int
    team_config_id: int
    channel_id: str
    channel_name: str
    require_bot_mention: bool
    persona_override_id: int | None
    enabled: bool

    # Pydantic v2 idiom — the nested `class Config` form is deprecated.
    model_config = ConfigDict(from_attributes=True)
class TeamsChannelConfigUpdateRequest(BaseModel):
    """PATCH body for a channel config."""

    require_bot_mention: bool
    persona_override_id: int | None
    enabled: bool

View File

@@ -0,0 +1,16 @@
"""Teams registration key generation and parsing."""
from onyx.onyxbot.registration import generate_registration_key
from onyx.onyxbot.registration import parse_registration_key
# Prefix distinguishing Teams keys from other bots' registration keys.
REGISTRATION_KEY_PREFIX: str = "teams"


def generate_teams_registration_key(tenant_id: str) -> str:
    """Generate a one-time registration key with embedded tenant_id.

    Delegates to the shared onyxbot helper with the Teams prefix.
    """
    return generate_registration_key(REGISTRATION_KEY_PREFIX, tenant_id)
def parse_teams_registration_key(key: str) -> str | None:
    """Parse a Teams registration key to extract the tenant_id.

    Returns None when the key does not match the Teams prefix/format.
    """
    return parse_registration_key(REGISTRATION_KEY_PREFIX, key)

View File

@@ -32,6 +32,7 @@ class MessageOrigin(str, Enum):
SLACKBOT = "slackbot"
WIDGET = "widget"
DISCORDBOT = "discordbot"
TEAMSBOT = "teamsbot"
UNKNOWN = "unknown"
UNSET = "unset"

View File

@@ -8,3 +8,37 @@ dependencies = [
[tool.uv.sources]
onyx = { workspace = true }
[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"
mypy_path = "backend"
explicit_package_bases = true
disallow_untyped_defs = true
warn_unused_ignores = true
enable_error_code = ["possibly-undefined"]
strict_equality = true
# Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. VS Code extension with target ./backend)
exclude = [
"(?:^|/)generated/",
"(?:^|/)\\.venv/",
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/",
"(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/",
]
[[tool.mypy.overrides]]
module = "alembic.versions.*"
disable_error_code = ["var-annotated"]
[[tool.mypy.overrides]]
module = "alembic_tenants.versions.*"
disable_error_code = ["var-annotated"]
[[tool.mypy.overrides]]
module = "generated.*"
follow_imports = "silent"
ignore_errors = true
[[tool.mypy.overrides]]
module = "transformers.*"
follow_imports = "skip"
ignore_errors = true

View File

@@ -17,6 +17,7 @@ aiohappyeyeballs==2.6.1
aiohttp==3.13.3
# via
# aiobotocore
# botbuilder-integration-aiohttp
# discord-py
# litellm
# onyx
@@ -67,6 +68,8 @@ attrs==25.4.0
# zeep
authlib==1.6.6
# via fastmcp
azure-core==1.38.2
# via msrest
babel==2.17.0
# via courlan
backoff==2.2.1
@@ -88,6 +91,25 @@ beautifulsoup4==4.12.3
# unstructured
billiard==4.2.3
# via celery
botbuilder-core==4.17.1
# via
# botbuilder-integration-aiohttp
# onyx
botbuilder-integration-aiohttp==4.17.1
# via onyx
botbuilder-schema==4.17.1
# via
# botbuilder-core
# botbuilder-integration-aiohttp
# botframework-connector
# botframework-streaming
botframework-connector==4.17.1
# via
# botbuilder-core
# botbuilder-integration-aiohttp
# botframework-streaming
botframework-streaming==4.17.1
# via botbuilder-core
boto3==1.39.11
# via
# aiobotocore
@@ -123,6 +145,7 @@ certifi==2025.11.12
# httpx
# hubspot-api-client
# kubernetes
# msrest
# opensearch-py
# requests
# sentry-sdk
@@ -444,6 +467,7 @@ iniconfig==2.3.0
# via pytest
isodate==0.7.2
# via
# msrest
# python3-saml
# zeep
jaraco-classes==3.4.0
@@ -474,6 +498,8 @@ joblib==1.5.2
# via nltk
jsonpatch==1.33
# via langchain-core
jsonpickle==3.4.2
# via botbuilder-core
jsonpointer==3.0.0
# via jsonpatch
jsonref==1.1.0
@@ -528,7 +554,7 @@ lxml==5.3.0
# unstructured
# xmlsec
# zeep
lxml-html-clean==0.4.4
lxml-html-clean==0.4.3
# via lxml
magika==0.6.3
# via markitdown
@@ -573,12 +599,17 @@ mpmath==1.3.0
# via sympy
msal==1.34.0
# via
# botframework-connector
# office365-rest-python-client
# onyx
msgpack==1.1.2
# via distributed
msoffcrypto-tool==5.4.2
# via onyx
msrest==0.7.1
# via
# botbuilder-schema
# botframework-connector
multidict==6.7.0
# via
# aiobotocore
@@ -796,6 +827,7 @@ pygments==2.19.2
# via rich
pyjwt==2.11.0
# via
# botframework-connector
# fastapi-users
# mcp
# msal
@@ -809,7 +841,7 @@ pypandoc-binary==1.16.2
# via onyx
pyparsing==3.2.5
# via httplib2
pypdf==6.7.5
pypdf==6.7.3
# via
# onyx
# unstructured-client
@@ -922,6 +954,7 @@ regex==2025.11.3
requests==2.32.5
# via
# atlassian-python-api
# azure-core
# braintrust
# cohere
# dropbox
@@ -940,6 +973,7 @@ requests==2.32.5
# markitdown
# matrix-client
# msal
# msrest
# office365-rest-python-client
# onyx
# opensearch-py
@@ -967,6 +1001,7 @@ requests-oauthlib==1.3.1
# google-auth-oauthlib
# jira
# kubernetes
# msrest
# onyx
requests-toolbelt==1.0.0
# via
@@ -1111,6 +1146,7 @@ typing-extensions==4.15.0
# aiosignal
# alembic
# anyio
# azure-core
# boto3-stubs
# braintrust
# cohere
@@ -1177,6 +1213,7 @@ uritemplate==4.2.0
urllib3==2.6.3
# via
# asana
# botbuilder-schema
# botocore
# courlan
# distributed
@@ -1239,7 +1276,9 @@ xmlsec==1.3.14
xmltodict==1.0.2
# via ddtrace
yarl==1.22.0
# via aiohttp
# via
# aiohttp
# botbuilder-integration-aiohttp
zeep==4.3.2
# via simple-salesforce
zict==3.0.0

View File

@@ -46,10 +46,10 @@ def _make_task(
run_fn: MagicMock | None = None,
) -> _PeriodicTaskDef:
return _PeriodicTaskDef(
name=name if name is not None else f"test-{uuid4().hex[:8]}",
name=name or f"test-{uuid4().hex[:8]}",
interval_seconds=interval,
lock_id=lock_id if lock_id is not None else _TEST_LOCK_BASE,
run_fn=run_fn if run_fn is not None else MagicMock(),
lock_id=lock_id or _TEST_LOCK_BASE,
run_fn=run_fn or MagicMock(),
)

View File

@@ -1,66 +0,0 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
class ScimClient:
    """HTTP client for making authenticated SCIM v2 requests."""

    @staticmethod
    def _headers(raw_token: str) -> dict[str, str]:
        """Default headers plus a Bearer auth header for the SCIM token."""
        headers = dict(GENERAL_HEADERS)
        headers["Authorization"] = f"Bearer {raw_token}"
        return headers

    @staticmethod
    def get(path: str, raw_token: str) -> requests.Response:
        """Authenticated GET."""
        return requests.get(
            f"{API_SERVER_URL}/scim/v2{path}",
            headers=ScimClient._headers(raw_token),
            timeout=60,
        )

    @staticmethod
    def post(path: str, raw_token: str, json: dict) -> requests.Response:
        """Authenticated POST with a JSON body."""
        return requests.post(
            f"{API_SERVER_URL}/scim/v2{path}",
            json=json,
            headers=ScimClient._headers(raw_token),
            timeout=60,
        )

    @staticmethod
    def put(path: str, raw_token: str, json: dict) -> requests.Response:
        """Authenticated PUT with a JSON body."""
        return requests.put(
            f"{API_SERVER_URL}/scim/v2{path}",
            json=json,
            headers=ScimClient._headers(raw_token),
            timeout=60,
        )

    @staticmethod
    def patch(path: str, raw_token: str, json: dict) -> requests.Response:
        """Authenticated PATCH with a JSON body."""
        return requests.patch(
            f"{API_SERVER_URL}/scim/v2{path}",
            json=json,
            headers=ScimClient._headers(raw_token),
            timeout=60,
        )

    @staticmethod
    def delete(path: str, raw_token: str) -> requests.Response:
        """Authenticated DELETE."""
        return requests.delete(
            f"{API_SERVER_URL}/scim/v2{path}",
            headers=ScimClient._headers(raw_token),
            timeout=60,
        )

    @staticmethod
    def get_no_auth(path: str) -> requests.Response:
        """Unauthenticated GET — used to verify auth enforcement."""
        return requests.get(
            f"{API_SERVER_URL}/scim/v2{path}",
            headers=GENERAL_HEADERS,
            timeout=60,
        )

View File

@@ -1,6 +1,7 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestScimToken
from tests.integration.common_utils.test_models import DATestUser
@@ -50,3 +51,29 @@ class ScimTokenManager:
created_at=data["created_at"],
last_used_at=data.get("last_used_at"),
)
@staticmethod
def get_scim_headers(raw_token: str) -> dict[str, str]:
    """Default headers plus a Bearer auth header for the raw SCIM token."""
    headers = dict(GENERAL_HEADERS)
    headers["Authorization"] = f"Bearer {raw_token}"
    return headers
@staticmethod
def scim_get(
    path: str,
    raw_token: str,
) -> requests.Response:
    """Authenticated GET against the SCIM v2 API."""
    url = f"{API_SERVER_URL}/scim/v2{path}"
    return requests.get(
        url,
        headers=ScimTokenManager.get_scim_headers(raw_token),
        timeout=60,
    )
@staticmethod
def scim_get_no_auth(path: str) -> requests.Response:
    """Unauthenticated GET — used to verify auth enforcement."""
    url = f"{API_SERVER_URL}/scim/v2{path}"
    return requests.get(url, headers=GENERAL_HEADERS, timeout=60)

View File

@@ -1,160 +0,0 @@
"""Integration test for the full user-file lifecycle in no-vector-DB mode.
Covers: upload → COMPLETED → unlink from project → delete → gone.
The entire lifecycle is handled by FastAPI BackgroundTasks (no Celery workers
needed). The conftest-level ``pytestmark`` ensures these tests are skipped
when the server is running with vector DB enabled.
"""
import time
from uuid import UUID
import requests
from onyx.db.enums import UserFileStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.project import ProjectManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
POLL_INTERVAL_SECONDS = 1
POLL_TIMEOUT_SECONDS = 30


def _poll_file_status(
    file_id: UUID,
    user: DATestUser,
    target_status: UserFileStatus,
    timeout: int = POLL_TIMEOUT_SECONDS,
) -> None:
    """Poll GET /user/projects/file/{file_id} until it reaches *target_status*.

    Raises:
        TimeoutError: if the status is not observed within *timeout* seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        response = requests.get(
            f"{API_SERVER_URL}/user/projects/file/{file_id}",
            headers=user.headers,
        )
        if response.ok and response.json().get("status") == target_status.value:
            return
        time.sleep(POLL_INTERVAL_SECONDS)
    raise TimeoutError(
        f"File {file_id} did not reach {target_status.value} within {timeout}s"
    )
def _file_is_gone(file_id: UUID, user: DATestUser, timeout: int = 15) -> None:
    """Poll until GET /user/projects/file/{file_id} returns 404.

    Raises:
        TimeoutError: if the file is still reachable after *timeout* seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        response = requests.get(
            f"{API_SERVER_URL}/user/projects/file/{file_id}",
            headers=user.headers,
        )
        if response.status_code == 404:
            return
        time.sleep(POLL_INTERVAL_SECONDS)
    raise TimeoutError(
        f"File {file_id} still accessible after {timeout}s (expected 404)"
    )
def test_file_upload_process_delete_lifecycle(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """Full lifecycle: upload → COMPLETED → unlink → delete → 404.

    Validates that the API server handles all background processing
    (via FastAPI BackgroundTasks) without any Celery workers running.
    """
    project = ProjectManager.create(
        name="lifecycle-test", user_performing_action=admin_user
    )
    file_content = b"Integration test file content for lifecycle verification."
    upload_result = ProjectManager.upload_files(
        project_id=project.id,
        files=[("lifecycle.txt", file_content)],
        user_performing_action=admin_user,
    )
    assert upload_result.user_files, "Expected at least one file in upload response"
    user_file = upload_result.user_files[0]
    file_id = user_file.id
    # Background processing should move the file to COMPLETED on its own.
    _poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
    project_files = ProjectManager.get_project_files(project.id, admin_user)
    assert any(
        f.id == file_id for f in project_files
    ), "File should be listed in project files after processing"
    # Unlink the file from the project so the delete endpoint will proceed
    unlink_resp = requests.delete(
        f"{API_SERVER_URL}/user/projects/{project.id}/files/{file_id}",
        headers=admin_user.headers,
    )
    assert (
        unlink_resp.status_code == 204
    ), f"Expected 204 on unlink, got {unlink_resp.status_code}: {unlink_resp.text}"
    delete_resp = requests.delete(
        f"{API_SERVER_URL}/user/projects/file/{file_id}",
        headers=admin_user.headers,
    )
    assert (
        delete_resp.ok
    ), f"Delete request failed: {delete_resp.status_code} {delete_resp.text}"
    body = delete_resp.json()
    assert (
        body["has_associations"] is False
    ), f"File still has associations after unlink: {body}"
    # Poll until the file 404s — deletion may not be immediate.
    _file_is_gone(file_id, admin_user)
    project_files_after = ProjectManager.get_project_files(project.id, admin_user)
    assert not any(
        f.id == file_id for f in project_files_after
    ), "Deleted file should not appear in project files"
def test_delete_blocked_while_associated(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,  # noqa: ARG001
) -> None:
    """Deleting a file that still belongs to a project should return
    has_associations=True without actually deleting the file."""
    project = ProjectManager.create(
        name="assoc-test", user_performing_action=admin_user
    )
    upload_result = ProjectManager.upload_files(
        project_id=project.id,
        files=[("assoc.txt", b"associated file content")],
        user_performing_action=admin_user,
    )
    file_id = upload_result.user_files[0].id
    # Wait for background processing to finish before exercising delete.
    _poll_file_status(file_id, admin_user, UserFileStatus.COMPLETED)
    # Attempt to delete while still linked
    delete_resp = requests.delete(
        f"{API_SERVER_URL}/user/projects/file/{file_id}",
        headers=admin_user.headers,
    )
    assert delete_resp.ok
    body = delete_resp.json()
    assert body["has_associations"] is True, "Should report existing associations"
    assert project.name in body["project_names"]
    # File should still be accessible
    get_resp = requests.get(
        f"{API_SERVER_URL}/user/projects/file/{file_id}",
        headers=admin_user.headers,
    )
    assert get_resp.status_code == 200, "File should still exist after blocked delete"

View File

@@ -1,552 +0,0 @@
"""Integration tests for SCIM group provisioning endpoints.
Covers the full group lifecycle as driven by an IdP (Okta / Azure AD):
1. Create a group via POST /Groups
2. Retrieve a group via GET /Groups/{id}
3. List, filter, and paginate groups via GET /Groups
4. Replace a group via PUT /Groups/{id}
5. Patch a group (add/remove members, rename) via PATCH /Groups/{id}
6. Delete a group via DELETE /Groups/{id}
7. Error cases: duplicate name, not-found, invalid member IDs
All tests are parameterized across IdP request styles (Okta sends lowercase
PATCH ops; Entra sends capitalized ops like ``"Replace"``). The server
normalizes both — these tests verify that.
Auth tests live in test_scim_tokens.py.
User lifecycle tests live in test_scim_users.py.
"""
import pytest
import requests
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
# Standard SCIM 2.0 schema URNs used in request payloads.
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"


@pytest.fixture(scope="module", params=["okta", "entra"])
def idp_style(request: pytest.FixtureRequest) -> str:
    """Parameterized IdP style — runs every test with both Okta and Entra
    request formats (Okta: lowercase PATCH ops; Entra: capitalized)."""
    return request.param
@pytest.fixture(scope="module")
def scim_token(idp_style: str) -> str:
    """Create a single SCIM token shared across all tests in this module.

    Creating a new token revokes the previous one, so we create exactly once
    per IdP-style run and reuse. Uses UserManager directly to avoid
    fixture-scope conflicts with the function-scoped admin_user fixture.
    """
    from tests.integration.common_utils.constants import ADMIN_USER_NAME
    from tests.integration.common_utils.constants import GENERAL_HEADERS
    from tests.integration.common_utils.managers.user import build_email
    from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
    from tests.integration.common_utils.managers.user import UserManager
    from tests.integration.common_utils.test_models import DATestUser

    try:
        admin = UserManager.create(name=ADMIN_USER_NAME)
    except Exception:
        # Creation failed — presumably the admin already exists from an
        # earlier run, so fall back to logging in as that user.
        admin = UserManager.login_as_user(
            DATestUser(
                id="",
                email=build_email(ADMIN_USER_NAME),
                password=DEFAULT_PASSWORD,
                headers=GENERAL_HEADERS,
                role=UserRole.ADMIN,
                is_active=True,
            )
        )
    token = ScimTokenManager.create(
        name=f"scim-group-tests-{idp_style}",
        user_performing_action=admin,
    ).raw_token
    # raw_token is Optional on the model; creation must always return one.
    assert token is not None
    return token
def _make_group_resource(
    display_name: str,
    external_id: str | None = None,
    members: list[dict] | None = None,
) -> dict:
    """Assemble a minimal SCIM GroupResource payload.

    Optional fields are emitted only when their arguments were supplied, so
    an omitted argument leaves the key out of the payload entirely (an empty
    list is still emitted — it is a meaningful "no members" value).
    """
    payload: dict = {
        "schemas": [SCIM_GROUP_SCHEMA],
        "displayName": display_name,
    }
    for key, value in (("externalId", external_id), ("members", members)):
        if value is not None:
            payload[key] = value
    return payload
def _make_user_resource(email: str, external_id: str) -> dict:
    """Build the minimal SCIM UserResource payload used to create group members."""
    name_attr = {"givenName": "Test", "familyName": "User"}
    return {
        "schemas": [SCIM_USER_SCHEMA],
        "userName": email,
        "externalId": external_id,
        "name": name_attr,
        "active": True,
    }
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
    """Wrap raw patch operations in a SCIM PatchOp envelope.

    For the ``"entra"`` style, each operation's ``op`` is re-cased to Entra's
    capitalized form (e.g. ``"Replace"``); Okta keeps the lowercase original.
    The server's ``normalize_operation`` validator lowercases either form —
    exactly what these tests verify.
    """

    def _cased(operation: dict) -> dict:
        adjusted = dict(operation)
        if idp_style == "entra":
            adjusted["op"] = operation["op"].capitalize()
        return adjusted

    return {
        "schemas": [SCIM_PATCH_SCHEMA],
        "Operations": [_cased(operation) for operation in operations],
    }
def _create_scim_user(token: str, email: str, external_id: str) -> requests.Response:
    """POST a minimal SCIM user and return the raw HTTP response."""
    payload = _make_user_resource(email, external_id)
    return ScimClient.post("/Users", token, json=payload)
def _create_scim_group(
    token: str,
    display_name: str,
    external_id: str | None = None,
    members: list[dict] | None = None,
) -> requests.Response:
    """POST a new SCIM group and return the raw HTTP response."""
    payload = _make_group_resource(display_name, external_id, members)
    return ScimClient.post("/Groups", token, json=payload)
# ------------------------------------------------------------------
# Lifecycle: create → get → list → replace → patch → delete
# ------------------------------------------------------------------
def test_create_group(scim_token: str, idp_style: str) -> None:
    """POST /Groups creates a group and returns 201."""
    name = f"Engineering {idp_style}"
    response = _create_scim_group(scim_token, name, external_id=f"ext-eng-{idp_style}")
    assert response.status_code == 201
    group = response.json()
    # Inputs must round-trip, and the server must add its own fields.
    assert group["displayName"] == name
    assert group["externalId"] == f"ext-eng-{idp_style}"
    assert group["id"]  # integer ID assigned by server
    assert group["meta"]["resourceType"] == "Group"
def test_create_group_with_members(scim_token: str, idp_style: str) -> None:
    """POST /Groups with members populates the member list."""
    # Provision a real SCIM user first — members must reference existing
    # user IDs (a bogus ID is rejected; see test_create_group_with_invalid_member).
    user = _create_scim_user(
        scim_token, f"grp_member1_{idp_style}@example.com", f"ext-gm-{idp_style}"
    ).json()
    resp = _create_scim_group(
        scim_token,
        f"Backend Team {idp_style}",
        external_id=f"ext-backend-{idp_style}",
        members=[{"value": user["id"]}],
    )
    assert resp.status_code == 201
    body = resp.json()
    # The created group must echo back the member we supplied.
    member_ids = [m["value"] for m in body["members"]]
    assert user["id"] in member_ids
def test_get_group(scim_token: str, idp_style: str) -> None:
    """GET /Groups/{id} returns the group resource including members."""
    # Create a member user and a group containing them, then read it back.
    user = _create_scim_user(
        scim_token, f"grp_get_m_{idp_style}@example.com", f"ext-ggm-{idp_style}"
    ).json()
    created = _create_scim_group(
        scim_token,
        f"Frontend Team {idp_style}",
        external_id=f"ext-fe-{idp_style}",
        members=[{"value": user["id"]}],
    ).json()
    resp = ScimClient.get(f"/Groups/{created['id']}", scim_token)
    assert resp.status_code == 200
    body = resp.json()
    # Every field sent at creation time must round-trip through GET.
    assert body["id"] == created["id"]
    assert body["displayName"] == f"Frontend Team {idp_style}"
    assert body["externalId"] == f"ext-fe-{idp_style}"
    member_ids = [m["value"] for m in body["members"]]
    assert user["id"] in member_ids
def test_list_groups(scim_token: str, idp_style: str) -> None:
    """GET /Groups returns a ListResponse containing provisioned groups."""
    name = f"DevOps Team {idp_style}"
    _create_scim_group(scim_token, name, external_id=f"ext-devops-{idp_style}")
    response = ScimClient.get("/Groups", scim_token)
    assert response.status_code == 200
    listing = response.json()
    assert listing["totalResults"] >= 1
    # The group we just provisioned must appear in the listing.
    listed_names = [resource["displayName"] for resource in listing["Resources"]]
    assert name in listed_names
def test_list_groups_pagination(scim_token: str, idp_style: str) -> None:
    """GET /Groups with startIndex and count returns correct pagination."""
    # Ensure at least two groups exist so a count=1 page is a strict subset.
    for suffix in ("a", "b"):
        _create_scim_group(
            scim_token,
            f"Page Group {suffix.upper()} {idp_style}",
            external_id=f"ext-page-{suffix}-{idp_style}",
        )
    response = ScimClient.get("/Groups?startIndex=1&count=1", scim_token)
    assert response.status_code == 200
    page = response.json()
    # SCIM pagination metadata must reflect the requested window.
    assert page["startIndex"] == 1
    assert page["itemsPerPage"] == 1
    assert page["totalResults"] >= 2
    assert len(page["Resources"]) == 1
def test_filter_groups_by_display_name(scim_token: str, idp_style: str) -> None:
    """GET /Groups?filter=displayName eq '...' returns only matching groups."""
    name = f"Unique QA Team {idp_style}"
    _create_scim_group(scim_token, name, external_id=f"ext-qa-filter-{idp_style}")
    response = ScimClient.get(f'/Groups?filter=displayName eq "{name}"', scim_token)
    assert response.status_code == 200
    result = response.json()
    # Exactly one hit, and it is the group we just created.
    assert result["totalResults"] == 1
    assert result["Resources"][0]["displayName"] == name
def test_filter_groups_by_external_id(scim_token: str, idp_style: str) -> None:
    """GET /Groups?filter=externalId eq '...' returns the matching group."""
    ext_id = f"ext-unique-group-id-{idp_style}"
    _create_scim_group(
        scim_token, f"ExtId Filter Group {idp_style}", external_id=ext_id
    )
    response = ScimClient.get(f'/Groups?filter=externalId eq "{ext_id}"', scim_token)
    assert response.status_code == 200
    result = response.json()
    # externalId is unique per group, so the filter must match exactly one.
    assert result["totalResults"] == 1
    assert result["Resources"][0]["externalId"] == ext_id
def test_replace_group(scim_token: str, idp_style: str) -> None:
    """PUT /Groups/{id} replaces the group resource."""
    created = _create_scim_group(
        scim_token,
        f"Original Name {idp_style}",
        external_id=f"ext-replace-g-{idp_style}",
    ).json()
    user = _create_scim_user(
        scim_token, f"grp_replace_m_{idp_style}@example.com", f"ext-grm-{idp_style}"
    ).json()
    # Full replacement: a new displayName plus a member the group did not
    # previously have; externalId is kept identical.
    updated_resource = _make_group_resource(
        display_name=f"Renamed Group {idp_style}",
        external_id=f"ext-replace-g-{idp_style}",
        members=[{"value": user["id"]}],
    )
    resp = ScimClient.put(f"/Groups/{created['id']}", scim_token, json=updated_resource)
    assert resp.status_code == 200
    body = resp.json()
    assert body["displayName"] == f"Renamed Group {idp_style}"
    member_ids = [m["value"] for m in body["members"]]
    assert user["id"] in member_ids
def test_replace_group_clears_members(scim_token: str, idp_style: str) -> None:
    """PUT /Groups/{id} with empty members removes all memberships."""
    user = _create_scim_user(
        scim_token, f"grp_clear_m_{idp_style}@example.com", f"ext-gcm-{idp_style}"
    ).json()
    created = _create_scim_group(
        scim_token,
        f"Clear Members Group {idp_style}",
        external_id=f"ext-clear-g-{idp_style}",
        members=[{"value": user["id"]}],
    ).json()
    # Precondition: the group starts with exactly one member.
    assert len(created["members"]) == 1
    # Replace with an explicit empty member list (distinct from omitting
    # the key, which _make_group_resource does for members=None).
    resp = ScimClient.put(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_group_resource(
            f"Clear Members Group {idp_style}", f"ext-clear-g-{idp_style}", members=[]
        ),
    )
    assert resp.status_code == 200
    assert resp.json()["members"] == []
def test_patch_add_member(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} with op=add adds a member."""
    # Group starts with no members; the user is created separately.
    created = _create_scim_group(
        scim_token,
        f"Patch Add Group {idp_style}",
        external_id=f"ext-patch-add-{idp_style}",
    ).json()
    user = _create_scim_user(
        scim_token, f"grp_patch_add_{idp_style}@example.com", f"ext-gpa-{idp_style}"
    ).json()
    resp = ScimClient.patch(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_patch_request(
            [{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
            idp_style,
        ),
    )
    assert resp.status_code == 200
    # The PATCH response body reflects the post-patch membership.
    member_ids = [m["value"] for m in resp.json()["members"]]
    assert user["id"] in member_ids
def test_patch_remove_member(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} with op=remove removes a specific member."""
    user = _create_scim_user(
        scim_token, f"grp_patch_rm_{idp_style}@example.com", f"ext-gpr-{idp_style}"
    ).json()
    created = _create_scim_group(
        scim_token,
        f"Patch Remove Group {idp_style}",
        external_id=f"ext-patch-rm-{idp_style}",
        members=[{"value": user["id"]}],
    ).json()
    # Precondition: the group starts with exactly one member.
    assert len(created["members"]) == 1
    resp = ScimClient.patch(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_patch_request(
            [
                {
                    "op": "remove",
                    # SCIM value-filter path: targets only the member whose
                    # "value" equals this user's ID (RFC 7644 PATCH syntax).
                    "path": f'members[value eq "{user["id"]}"]',
                }
            ],
            idp_style,
        ),
    )
    assert resp.status_code == 200
    assert resp.json()["members"] == []
def test_patch_replace_members(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} with op=replace on members swaps the entire list."""
    user_a = _create_scim_user(
        scim_token, f"grp_repl_a_{idp_style}@example.com", f"ext-gra-{idp_style}"
    ).json()
    user_b = _create_scim_user(
        scim_token, f"grp_repl_b_{idp_style}@example.com", f"ext-grb-{idp_style}"
    ).json()
    # Group starts with only user A as a member.
    created = _create_scim_group(
        scim_token,
        f"Patch Replace Group {idp_style}",
        external_id=f"ext-patch-repl-{idp_style}",
        members=[{"value": user_a["id"]}],
    ).json()
    # Replace member list: swap A for B
    resp = ScimClient.patch(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_patch_request(
            [
                {
                    "op": "replace",
                    "path": "members",
                    "value": [{"value": user_b["id"]}],
                }
            ],
            idp_style,
        ),
    )
    assert resp.status_code == 200
    # replace must be a full swap: B present AND A gone.
    member_ids = [m["value"] for m in resp.json()["members"]]
    assert user_b["id"] in member_ids
    assert user_a["id"] not in member_ids
def test_patch_rename_group(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} with op=replace on displayName renames the group."""
    group = _create_scim_group(
        scim_token,
        f"Old Group Name {idp_style}",
        external_id=f"ext-rename-g-{idp_style}",
    ).json()
    new_name = f"New Group Name {idp_style}"
    rename_op = {"op": "replace", "path": "displayName", "value": new_name}
    patch_resp = ScimClient.patch(
        f"/Groups/{group['id']}",
        scim_token,
        json=_make_patch_request([rename_op], idp_style),
    )
    assert patch_resp.status_code == 200
    assert patch_resp.json()["displayName"] == new_name
    # The rename must be persisted, not just echoed back by PATCH.
    fetched = ScimClient.get(f"/Groups/{group['id']}", scim_token)
    assert fetched.json()["displayName"] == new_name
def test_delete_group(scim_token: str, idp_style: str) -> None:
    """DELETE /Groups/{id} removes the group."""
    group = _create_scim_group(
        scim_token,
        f"Delete Me Group {idp_style}",
        external_id=f"ext-del-g-{idp_style}",
    ).json()
    group_url = f"/Groups/{group['id']}"
    first_delete = ScimClient.delete(group_url, scim_token)
    assert first_delete.status_code == 204
    # The group is hard-deleted, so a repeat DELETE finds nothing (404).
    second_delete = ScimClient.delete(group_url, scim_token)
    assert second_delete.status_code == 404
def test_delete_group_preserves_members(scim_token: str, idp_style: str) -> None:
    """DELETE /Groups/{id} removes memberships but does not deactivate users."""
    user = _create_scim_user(
        scim_token, f"grp_del_member_{idp_style}@example.com", f"ext-gdm-{idp_style}"
    ).json()
    created = _create_scim_group(
        scim_token,
        f"Delete With Members {idp_style}",
        external_id=f"ext-del-wm-{idp_style}",
        members=[{"value": user["id"]}],
    ).json()
    # Deleting the group must cascade only to memberships, never to the
    # member users' accounts.
    resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token)
    assert resp.status_code == 204
    # User should still be active and retrievable
    user_resp = ScimClient.get(f"/Users/{user['id']}", scim_token)
    assert user_resp.status_code == 200
    assert user_resp.json()["active"] is True
# ------------------------------------------------------------------
# Error cases
# ------------------------------------------------------------------
def test_create_group_duplicate_name(scim_token: str, idp_style: str) -> None:
    """POST /Groups with an already-taken displayName returns 409."""
    name = f"Dup Name Group {idp_style}"
    # First creation succeeds; the second, with the same name but a
    # different externalId, must be rejected as a conflict.
    first = _create_scim_group(scim_token, name, external_id=f"ext-dup-g1-{idp_style}")
    assert first.status_code == 201
    second = _create_scim_group(scim_token, name, external_id=f"ext-dup-g2-{idp_style}")
    assert second.status_code == 409
def test_get_nonexistent_group(scim_token: str) -> None:
    """GET /Groups/{bad-id} returns 404."""
    missing = ScimClient.get("/Groups/999999999", scim_token)
    assert missing.status_code == 404
def test_create_group_with_invalid_member(scim_token: str, idp_style: str) -> None:
    """POST /Groups with a non-existent member UUID returns 400."""
    # The all-zero UUID is syntactically valid but matches no provisioned user.
    resp = _create_scim_group(
        scim_token,
        f"Bad Member Group {idp_style}",
        external_id=f"ext-bad-m-{idp_style}",
        members=[{"value": "00000000-0000-0000-0000-000000000000"}],
    )
    assert resp.status_code == 400
    # Error detail must mention the missing member (case-insensitive check).
    assert "not found" in resp.json()["detail"].lower()
def test_patch_add_nonexistent_member(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} adding a non-existent member returns 400."""
    created = _create_scim_group(
        scim_token,
        f"Patch Bad Member Group {idp_style}",
        external_id=f"ext-pbm-{idp_style}",
    ).json()
    # Same invalid-member behavior as at creation time: an all-zero UUID
    # matches no user and must be rejected, not silently skipped.
    resp = ScimClient.patch(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_patch_request(
            [
                {
                    "op": "add",
                    "path": "members",
                    "value": [{"value": "00000000-0000-0000-0000-000000000000"}],
                }
            ],
            idp_style,
        ),
    )
    assert resp.status_code == 400
    assert "not found" in resp.json()["detail"].lower()
def test_patch_add_duplicate_member_is_idempotent(
    scim_token: str, idp_style: str
) -> None:
    """PATCH /Groups/{id} adding an already-present member succeeds silently."""
    user = _create_scim_user(
        scim_token, f"grp_dup_add_{idp_style}@example.com", f"ext-gda-{idp_style}"
    ).json()
    created = _create_scim_group(
        scim_token,
        f"Idempotent Add Group {idp_style}",
        external_id=f"ext-idem-g-{idp_style}",
        members=[{"value": user["id"]}],
    ).json()
    # Precondition: the user is already the group's only member.
    assert len(created["members"]) == 1
    # Add same member again
    resp = ScimClient.patch(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_patch_request(
            [{"op": "add", "path": "members", "value": [{"value": user["id"]}]}],
            idp_style,
        ),
    )
    # Re-adding must neither error nor create a duplicate membership row.
    assert resp.status_code == 200
    assert len(resp.json()["members"]) == 1  # still just one member

View File

@@ -15,7 +15,6 @@ import time
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
@@ -40,7 +39,7 @@ def test_scim_token_lifecycle(admin_user: DATestUser) -> None:
assert active == token.model_copy(update={"raw_token": None})
# Token works for SCIM requests
response = ScimClient.get("/Users", token.raw_token)
response = ScimTokenManager.scim_get("/Users", token.raw_token)
assert response.status_code == 200
body = response.json()
assert "Resources" in body
@@ -55,7 +54,7 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
)
assert first.raw_token is not None
response = ScimClient.get("/Users", first.raw_token)
response = ScimTokenManager.scim_get("/Users", first.raw_token)
assert response.status_code == 200
# Create second token — should revoke first
@@ -70,22 +69,25 @@ def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None:
assert active == second.model_copy(update={"raw_token": None})
# First token rejected, second works
assert ScimClient.get("/Users", first.raw_token).status_code == 401
assert ScimClient.get("/Users", second.raw_token).status_code == 200
assert ScimTokenManager.scim_get("/Users", first.raw_token).status_code == 401
assert ScimTokenManager.scim_get("/Users", second.raw_token).status_code == 200
def test_scim_request_without_token_rejected(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""SCIM endpoints reject requests with no Authorization header."""
assert ScimClient.get_no_auth("/Users").status_code == 401
assert ScimTokenManager.scim_get_no_auth("/Users").status_code == 401
def test_scim_request_with_bad_token_rejected(
admin_user: DATestUser, # noqa: ARG001
) -> None:
"""SCIM endpoints reject requests with an invalid token."""
assert ScimClient.get("/Users", "onyx_scim_bogus_token_value").status_code == 401
assert (
ScimTokenManager.scim_get("/Users", "onyx_scim_bogus_token_value").status_code
== 401
)
def test_non_admin_cannot_create_token(
@@ -137,7 +139,7 @@ def test_service_discovery_no_auth_required(
) -> None:
"""Service discovery endpoints work without any authentication."""
for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]:
response = ScimClient.get_no_auth(path)
response = ScimTokenManager.scim_get_no_auth(path)
assert response.status_code == 200, f"{path} returned {response.status_code}"
@@ -156,7 +158,7 @@ def test_last_used_at_updated_after_scim_request(
assert active.last_used_at is None
# Make a SCIM request, then verify last_used_at is set
assert ScimClient.get("/Users", token.raw_token).status_code == 200
assert ScimTokenManager.scim_get("/Users", token.raw_token).status_code == 200
time.sleep(0.5)
active_after = ScimTokenManager.get_active(user_performing_action=admin_user)

View File

@@ -1,520 +0,0 @@
"""Integration tests for SCIM user provisioning endpoints.
Covers the full user lifecycle as driven by an IdP (Okta / Azure AD):
1. Create a user via POST /Users
2. Retrieve a user via GET /Users/{id}
3. List, filter, and paginate users via GET /Users
4. Replace a user via PUT /Users/{id}
5. Patch a user (deactivate/reactivate) via PATCH /Users/{id}
6. Delete a user via DELETE /Users/{id}
7. Error cases: missing externalId, duplicate email, not-found, seat limit
All tests are parameterized across IdP request styles:
- **Okta**: lowercase PATCH ops, minimal payloads (core schema only).
- **Entra**: capitalized ops (``"Replace"``), enterprise extension data
(department, manager), and structured email arrays.
The server normalizes both — these tests verify that all IdP-specific fields
are accepted and round-tripped correctly.
Auth, revoked-token, and service-discovery tests live in test_scim_tokens.py.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import pytest
import redis
import requests
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import PlanType
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PORT
from onyx.server.settings.models import ApplicationStatus
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_ENTERPRISE_USER_SCHEMA = (
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
_LICENSE_REDIS_KEY = "public:license:metadata"
@pytest.fixture(scope="module", params=["okta", "entra"])
def idp_style(request: pytest.FixtureRequest) -> str:
"""Parameterized IdP style — runs every test with both Okta and Entra request formats."""
return request.param
@pytest.fixture(scope="module")
def scim_token(idp_style: str) -> str:
"""Create a single SCIM token shared across all tests in this module.
Creating a new token revokes the previous one, so we create exactly once
per IdP-style run and reuse. Uses UserManager directly to avoid
fixture-scope conflicts with the function-scoped admin_user fixture.
"""
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
try:
admin = UserManager.create(name=ADMIN_USER_NAME)
except Exception:
admin = UserManager.login_as_user(
DATestUser(
id="",
email=build_email(ADMIN_USER_NAME),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
token = ScimTokenManager.create(
name=f"scim-user-tests-{idp_style}",
user_performing_action=admin,
).raw_token
assert token is not None
return token
def _make_user_resource(
email: str,
external_id: str,
given_name: str = "Test",
family_name: str = "User",
active: bool = True,
idp_style: str = "okta",
department: str | None = None,
manager_id: str | None = None,
) -> dict:
"""Build a SCIM UserResource payload appropriate for the IdP style.
Entra sends richer payloads including enterprise extension data (department,
manager), structured email arrays, and the enterprise schema URN. Okta sends
minimal payloads with just core user fields.
"""
resource: dict = {
"schemas": [SCIM_USER_SCHEMA],
"userName": email,
"externalId": external_id,
"name": {
"givenName": given_name,
"familyName": family_name,
},
"active": active,
}
if idp_style == "entra":
dept = department or "Engineering"
mgr = manager_id or "mgr-ext-001"
resource["schemas"].append(SCIM_ENTERPRISE_USER_SCHEMA)
resource[SCIM_ENTERPRISE_USER_SCHEMA] = {
"department": dept,
"manager": {"value": mgr},
}
resource["emails"] = [
{"value": email, "type": "work", "primary": True},
]
return resource
def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict:
"""Build a SCIM PatchOp payload, applying IdP-specific operation casing.
Entra sends capitalized operations (e.g. ``"Replace"`` instead of
``"replace"``). The server's ``normalize_operation`` validator lowercases
them — these tests verify that both casings are accepted.
"""
cased_operations = []
for operation in operations:
cased = dict(operation)
if idp_style == "entra":
cased["op"] = operation["op"].capitalize()
cased_operations.append(cased)
return {
"schemas": [SCIM_PATCH_SCHEMA],
"Operations": cased_operations,
}
def _create_scim_user(
token: str,
email: str,
external_id: str,
idp_style: str = "okta",
) -> requests.Response:
return ScimClient.post(
"/Users",
token,
json=_make_user_resource(email, external_id, idp_style=idp_style),
)
def _assert_entra_extension(
body: dict,
expected_department: str = "Engineering",
expected_manager: str = "mgr-ext-001",
) -> None:
"""Assert that Entra enterprise extension fields round-tripped correctly."""
assert SCIM_ENTERPRISE_USER_SCHEMA in body["schemas"]
ext = body[SCIM_ENTERPRISE_USER_SCHEMA]
assert ext["department"] == expected_department
assert ext["manager"]["value"] == expected_manager
def _assert_entra_emails(body: dict, expected_email: str) -> None:
"""Assert that structured email metadata round-tripped correctly."""
emails = body["emails"]
assert len(emails) >= 1
work_email = next(e for e in emails if e.get("type") == "work")
assert work_email["value"] == expected_email
assert work_email["primary"] is True
# ------------------------------------------------------------------
# Lifecycle: create -> get -> list -> replace -> patch -> delete
# ------------------------------------------------------------------
def test_create_user(scim_token: str, idp_style: str) -> None:
"""POST /Users creates a provisioned user and returns 201."""
email = f"scim_create_{idp_style}@example.com"
ext_id = f"ext-create-{idp_style}"
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
assert resp.status_code == 201
body = resp.json()
assert body["userName"] == email
assert body["externalId"] == ext_id
assert body["active"] is True
assert body["id"] # UUID assigned by server
assert body["meta"]["resourceType"] == "User"
assert body["name"]["givenName"] == "Test"
assert body["name"]["familyName"] == "User"
if idp_style == "entra":
_assert_entra_extension(body)
_assert_entra_emails(body, email)
def test_get_user(scim_token: str, idp_style: str) -> None:
"""GET /Users/{id} returns the user resource with all stored fields."""
email = f"scim_get_{idp_style}@example.com"
ext_id = f"ext-get-{idp_style}"
created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
resp = ScimClient.get(f"/Users/{created['id']}", scim_token)
assert resp.status_code == 200
body = resp.json()
assert body["id"] == created["id"]
assert body["userName"] == email
assert body["externalId"] == ext_id
assert body["name"]["givenName"] == "Test"
assert body["name"]["familyName"] == "User"
if idp_style == "entra":
_assert_entra_extension(body)
_assert_entra_emails(body, email)
def test_list_users(scim_token: str, idp_style: str) -> None:
"""GET /Users returns a ListResponse containing provisioned users."""
email = f"scim_list_{idp_style}@example.com"
_create_scim_user(scim_token, email, f"ext-list-{idp_style}", idp_style)
resp = ScimClient.get("/Users", scim_token)
assert resp.status_code == 200
body = resp.json()
assert body["totalResults"] >= 1
emails = [r["userName"] for r in body["Resources"]]
assert email in emails
def test_list_users_pagination(scim_token: str, idp_style: str) -> None:
"""GET /Users with startIndex and count returns correct pagination."""
_create_scim_user(
scim_token,
f"scim_page1_{idp_style}@example.com",
f"ext-page-1-{idp_style}",
idp_style,
)
_create_scim_user(
scim_token,
f"scim_page2_{idp_style}@example.com",
f"ext-page-2-{idp_style}",
idp_style,
)
resp = ScimClient.get("/Users?startIndex=1&count=1", scim_token)
assert resp.status_code == 200
body = resp.json()
assert body["startIndex"] == 1
assert body["itemsPerPage"] == 1
assert body["totalResults"] >= 2
assert len(body["Resources"]) == 1
def test_filter_users_by_username(scim_token: str, idp_style: str) -> None:
"""GET /Users?filter=userName eq '...' returns only matching users."""
email = f"scim_filter_{idp_style}@example.com"
_create_scim_user(scim_token, email, f"ext-filter-{idp_style}", idp_style)
resp = ScimClient.get(f'/Users?filter=userName eq "{email}"', scim_token)
assert resp.status_code == 200
body = resp.json()
assert body["totalResults"] == 1
assert body["Resources"][0]["userName"] == email
def test_replace_user(scim_token: str, idp_style: str) -> None:
"""PUT /Users/{id} replaces the user resource including enterprise fields."""
email = f"scim_replace_{idp_style}@example.com"
ext_id = f"ext-replace-{idp_style}"
created = _create_scim_user(scim_token, email, ext_id, idp_style).json()
updated_resource = _make_user_resource(
email=email,
external_id=ext_id,
given_name="Updated",
family_name="Name",
idp_style=idp_style,
department="Product",
)
resp = ScimClient.put(f"/Users/{created['id']}", scim_token, json=updated_resource)
assert resp.status_code == 200
body = resp.json()
assert body["name"]["givenName"] == "Updated"
assert body["name"]["familyName"] == "Name"
if idp_style == "entra":
_assert_entra_extension(body, expected_department="Product")
_assert_entra_emails(body, email)
def test_patch_deactivate_user(scim_token: str, idp_style: str) -> None:
"""PATCH /Users/{id} with active=false deactivates the user."""
created = _create_scim_user(
scim_token,
f"scim_deactivate_{idp_style}@example.com",
f"ext-deactivate-{idp_style}",
idp_style,
).json()
assert created["active"] is True
resp = ScimClient.patch(
f"/Users/{created['id']}",
scim_token,
json=_make_patch_request(
[{"op": "replace", "path": "active", "value": False}], idp_style
),
)
assert resp.status_code == 200
assert resp.json()["active"] is False
# Confirm via GET
get_resp = ScimClient.get(f"/Users/{created['id']}", scim_token)
assert get_resp.json()["active"] is False
def test_patch_reactivate_user(scim_token: str, idp_style: str) -> None:
"""PATCH active=true reactivates a previously deactivated user."""
created = _create_scim_user(
scim_token,
f"scim_reactivate_{idp_style}@example.com",
f"ext-reactivate-{idp_style}",
idp_style,
).json()
# Deactivate
deactivate_resp = ScimClient.patch(
f"/Users/{created['id']}",
scim_token,
json=_make_patch_request(
[{"op": "replace", "path": "active", "value": False}], idp_style
),
)
assert deactivate_resp.status_code == 200
assert deactivate_resp.json()["active"] is False
# Reactivate
resp = ScimClient.patch(
f"/Users/{created['id']}",
scim_token,
json=_make_patch_request(
[{"op": "replace", "path": "active", "value": True}], idp_style
),
)
assert resp.status_code == 200
assert resp.json()["active"] is True
def test_delete_user(scim_token: str, idp_style: str) -> None:
"""DELETE /Users/{id} deactivates and removes the SCIM mapping."""
created = _create_scim_user(
scim_token,
f"scim_delete_{idp_style}@example.com",
f"ext-delete-{idp_style}",
idp_style,
).json()
resp = ScimClient.delete(f"/Users/{created['id']}", scim_token)
assert resp.status_code == 204
# Second DELETE returns 404 per RFC 7644 §3.6 (mapping removed)
resp2 = ScimClient.delete(f"/Users/{created['id']}", scim_token)
assert resp2.status_code == 404
# ------------------------------------------------------------------
# Error cases
# ------------------------------------------------------------------
def test_create_user_missing_external_id(scim_token: str, idp_style: str) -> None:
"""POST /Users without externalId succeeds (RFC 7643: externalId is optional)."""
email = f"scim_no_extid_{idp_style}@example.com"
resp = ScimClient.post(
"/Users",
scim_token,
json={
"schemas": [SCIM_USER_SCHEMA],
"userName": email,
"active": True,
},
)
assert resp.status_code == 201
body = resp.json()
assert body["userName"] == email
assert body.get("externalId") is None
def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None:
"""POST /Users with an already-taken email returns 409."""
email = f"scim_dup_{idp_style}@example.com"
resp1 = _create_scim_user(scim_token, email, f"ext-dup-1-{idp_style}", idp_style)
assert resp1.status_code == 201
resp2 = _create_scim_user(scim_token, email, f"ext-dup-2-{idp_style}", idp_style)
assert resp2.status_code == 409
def test_get_nonexistent_user(scim_token: str) -> None:
"""GET /Users/{bad-id} returns 404."""
resp = ScimClient.get("/Users/00000000-0000-0000-0000-000000000000", scim_token)
assert resp.status_code == 404
def test_filter_users_by_external_id(scim_token: str, idp_style: str) -> None:
"""GET /Users?filter=externalId eq '...' returns the matching user."""
ext_id = f"ext-unique-filter-id-{idp_style}"
_create_scim_user(
scim_token, f"scim_extfilter_{idp_style}@example.com", ext_id, idp_style
)
resp = ScimClient.get(f'/Users?filter=externalId eq "{ext_id}"', scim_token)
assert resp.status_code == 200
body = resp.json()
assert body["totalResults"] == 1
assert body["Resources"][0]["externalId"] == ext_id
# ------------------------------------------------------------------
# Seat-limit enforcement
# ------------------------------------------------------------------
def _seed_license(r: redis.Redis, seats: int) -> None:
"""Write a LicenseMetadata entry into Redis with the given seat cap."""
now = datetime.now(timezone.utc)
metadata = LicenseMetadata(
tenant_id="public",
organization_name="Test Org",
seats=seats,
used_seats=0, # check_seat_availability recalculates from DB
plan_type=PlanType.ANNUAL,
issued_at=now,
expires_at=now + timedelta(days=365),
status=ApplicationStatus.ACTIVE,
source=LicenseSource.MANUAL_UPLOAD,
)
r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
def test_create_user_seat_limit(scim_token: str, idp_style: str) -> None:
"""POST /Users returns 403 when the seat limit is reached."""
r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
# admin_user already occupies 1 seat; cap at 1 -> full
_seed_license(r, seats=1)
try:
resp = _create_scim_user(
scim_token,
f"scim_blocked_{idp_style}@example.com",
f"ext-blocked-{idp_style}",
idp_style,
)
assert resp.status_code == 403
assert "seat" in resp.json()["detail"].lower()
finally:
r.delete(_LICENSE_REDIS_KEY)
def test_reactivate_user_seat_limit(scim_token: str, idp_style: str) -> None:
    """PATCH active=true returns 403 when the seat limit is reached."""
    # Create and deactivate a user (before license is seeded)
    user = _create_scim_user(
        scim_token,
        f"scim_reactivate_blocked_{idp_style}@example.com",
        f"ext-reactivate-blocked-{idp_style}",
        idp_style,
    ).json()
    assert user["active"] is True

    deactivated = ScimClient.patch(
        f"/Users/{user['id']}",
        scim_token,
        json=_make_patch_request(
            [{"op": "replace", "path": "active", "value": False}], idp_style
        ),
    )
    assert deactivated.status_code == 200
    assert deactivated.json()["active"] is False

    redis_client = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
    # Seed license capped at current active users -> reactivation should fail
    _seed_license(redis_client, seats=1)
    try:
        reactivated = ScimClient.patch(
            f"/Users/{user['id']}",
            scim_token,
            json=_make_patch_request(
                [{"op": "replace", "path": "active", "value": True}], idp_style
            ),
        )
        assert reactivated.status_code == 403
        assert "seat" in reactivated.json()["detail"].lower()
    finally:
        # Remove the license entry so later tests regain unlimited seats.
        redis_client.delete(_LICENSE_REDIS_KEY)

View File

@@ -1,20 +1,11 @@
"""Tests for license database CRUD operations."""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
from ee.onyx.db.license import check_seat_availability
from ee.onyx.db.license import delete_license
from ee.onyx.db.license import get_license
from ee.onyx.db.license import upsert_license
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import PlanType
from onyx.db.models import License
from onyx.server.settings.models import ApplicationStatus
class TestGetLicense:
@@ -109,108 +100,3 @@ class TestDeleteLicense:
assert result is False
mock_session.delete.assert_not_called()
mock_session.commit.assert_not_called()
def _make_license_metadata(seats: int = 10) -> LicenseMetadata:
    """Build a minimal active annual license capped at ``seats`` seats."""
    issued = datetime.now(timezone.utc)
    return LicenseMetadata(
        tenant_id="public",
        seats=seats,
        used_seats=0,
        plan_type=PlanType.ANNUAL,
        issued_at=issued,
        expires_at=issued + timedelta(days=365),
        status=ApplicationStatus.ACTIVE,
        source=LicenseSource.MANUAL_UPLOAD,
    )
class TestCheckSeatAvailabilitySelfHosted:
    """Seat checks for self-hosted (MULTI_TENANT=False)."""

    @patch("ee.onyx.db.license.get_license_metadata", return_value=None)
    def test_no_license_means_unlimited(self, _mock_meta: MagicMock) -> None:
        # No license installed -> no seat cap is enforced.
        outcome = check_seat_availability(MagicMock(), seats_needed=1)
        assert outcome.available is True

    @patch("ee.onyx.db.license.get_used_seats", return_value=5)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_seats_available(self, mock_meta: MagicMock, _mock_used: MagicMock) -> None:
        # 5 of 10 seats occupied -> one more fits.
        mock_meta.return_value = _make_license_metadata(seats=10)
        outcome = check_seat_availability(MagicMock(), seats_needed=1)
        assert outcome.available is True

    @patch("ee.onyx.db.license.get_used_seats", return_value=10)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_seats_full_blocks_creation(
        self, mock_meta: MagicMock, _mock_used: MagicMock
    ) -> None:
        # All seats taken -> blocked, with an explanatory message.
        mock_meta.return_value = _make_license_metadata(seats=10)
        outcome = check_seat_availability(MagicMock(), seats_needed=1)
        assert outcome.available is False
        assert outcome.error_message is not None
        assert "10 of 10" in outcome.error_message

    @patch("ee.onyx.db.license.get_used_seats", return_value=10)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_exactly_at_capacity_allows_no_more(
        self, mock_meta: MagicMock, _mock_used: MagicMock
    ) -> None:
        """Filling to 100% is allowed; exceeding is not."""
        mock_meta.return_value = _make_license_metadata(seats=10)
        outcome = check_seat_availability(MagicMock(), seats_needed=1)
        assert outcome.available is False

    @patch("ee.onyx.db.license.get_used_seats", return_value=9)
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_filling_to_capacity_is_allowed(
        self, mock_meta: MagicMock, _mock_used: MagicMock
    ) -> None:
        # 9 of 10 used -> the final seat may still be claimed.
        mock_meta.return_value = _make_license_metadata(seats=10)
        outcome = check_seat_availability(MagicMock(), seats_needed=1)
        assert outcome.available is True
class TestCheckSeatAvailabilityMultiTenant:
    """Seat checks for multi-tenant cloud (MULTI_TENANT=True).

    Verifies that get_used_seats takes the MULTI_TENANT branch
    and delegates to get_tenant_count.
    """

    @patch("ee.onyx.db.license.MULTI_TENANT", True)
    @patch(
        "ee.onyx.server.tenants.user_mapping.get_tenant_count",
        return_value=5,
    )
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_seats_available_multi_tenant(
        self,
        mock_meta: MagicMock,
        mock_tenant_count: MagicMock,
    ) -> None:
        # 5 of 10 seats used -> there is room for one more.
        mock_meta.return_value = _make_license_metadata(seats=10)
        outcome = check_seat_availability(
            MagicMock(), seats_needed=1, tenant_id="tenant-abc"
        )
        assert outcome.available is True
        mock_tenant_count.assert_called_once_with("tenant-abc")

    @patch("ee.onyx.db.license.MULTI_TENANT", True)
    @patch(
        "ee.onyx.server.tenants.user_mapping.get_tenant_count",
        return_value=10,
    )
    @patch("ee.onyx.db.license.get_license_metadata")
    def test_seats_full_multi_tenant(
        self,
        mock_meta: MagicMock,
        mock_tenant_count: MagicMock,
    ) -> None:
        # Tenant already at its 10-seat cap -> blocked with a message.
        mock_meta.return_value = _make_license_metadata(seats=10)
        outcome = check_seat_availability(
            MagicMock(), seats_needed=1, tenant_id="tenant-abc"
        )
        assert outcome.available is False
        assert outcome.error_message is not None
        mock_tenant_count.assert_called_once_with("tenant-abc")

View File

@@ -12,11 +12,12 @@ import aiohttp
import pytest
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError
from onyx.onyxbot.api_client import OnyxAPIClient
from onyx.onyxbot.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.exceptions import APIConnectionError
from onyx.onyxbot.exceptions import APIResponseError
from onyx.onyxbot.exceptions import APITimeoutError
from onyx.server.query_and_chat.models import MessageOrigin
class MockAsyncContextManager:
@@ -43,7 +44,7 @@ class TestClientLifecycle:
@pytest.mark.asyncio
async def test_initialize_creates_session(self) -> None:
"""initialize() creates aiohttp session."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
assert client._session is None
with patch("aiohttp.ClientSession") as mock_session_class:
@@ -57,13 +58,13 @@ class TestClientLifecycle:
def test_is_initialized_before_init(self) -> None:
"""is_initialized returns False before initialize()."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
assert client.is_initialized is False
@pytest.mark.asyncio
async def test_is_initialized_after_init(self) -> None:
"""is_initialized returns True after initialize()."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
with patch("aiohttp.ClientSession"):
await client.initialize()
@@ -73,7 +74,7 @@ class TestClientLifecycle:
@pytest.mark.asyncio
async def test_close_closes_session(self) -> None:
"""close() closes session and resets is_initialized."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_session = AsyncMock()
with patch("aiohttp.ClientSession", return_value=mock_session):
@@ -88,7 +89,7 @@ class TestClientLifecycle:
@pytest.mark.asyncio
async def test_send_message_not_initialized(self) -> None:
"""send_chat_message() before initialize() raises APIConnectionError."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
with pytest.raises(APIConnectionError) as exc_info:
await client.send_chat_message("test", "api_key")
@@ -102,7 +103,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_success(self) -> None:
"""Valid request returns ChatFullResponse."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
response_data = {
"answer": "Test response",
@@ -133,7 +134,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_with_persona(self) -> None:
"""persona_id is passed to API."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
response_data = {"answer": "Response", "citations": [], "error_msg": None}
@@ -164,7 +165,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_401_error(self) -> None:
"""Invalid API key returns APIResponseError with 401."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_response = MagicMock()
mock_response.status = 401
@@ -184,7 +185,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_403_error(self) -> None:
"""Persona not accessible returns APIResponseError with 403."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_response = MagicMock()
mock_response.status = 403
@@ -204,7 +205,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_timeout(self) -> None:
"""Request timeout raises APITimeoutError."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_session = MagicMock()
mock_session.post = MagicMock(
@@ -221,7 +222,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_connection_error(self) -> None:
"""Network failure raises APIConnectionError."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_session = MagicMock()
mock_session.post = MagicMock(
@@ -240,7 +241,7 @@ class TestSendChatMessage:
@pytest.mark.asyncio
async def test_send_message_server_error(self) -> None:
"""500 response raises APIResponseError with 500."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_response = MagicMock()
mock_response.status = 500
@@ -265,7 +266,7 @@ class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_success(self) -> None:
"""Server healthy returns True."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_response = MagicMock()
mock_response.status = 200
@@ -283,7 +284,7 @@ class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_failure(self) -> None:
"""Server unhealthy returns False."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_response = MagicMock()
mock_response.status = 503
@@ -301,7 +302,7 @@ class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_timeout(self) -> None:
"""Request times out returns False."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_session = MagicMock()
mock_session.get = MagicMock(
@@ -318,7 +319,7 @@ class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_not_initialized(self) -> None:
"""Health check before initialize returns False."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
result = await client.health_check()
assert result is False
@@ -330,7 +331,7 @@ class TestResponseParsing:
@pytest.mark.asyncio
async def test_response_malformed_json(self) -> None:
"""API returns invalid JSON raises exception."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
mock_response = MagicMock()
mock_response.status = 200
@@ -349,7 +350,7 @@ class TestResponseParsing:
@pytest.mark.asyncio
async def test_response_with_error_msg(self) -> None:
"""200 status but error_msg present - warning logged, response returned."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
response_data = {
"answer": "Partial response",
@@ -381,7 +382,7 @@ class TestResponseParsing:
@pytest.mark.asyncio
async def test_response_empty_answer(self) -> None:
"""answer field is empty string - handled gracefully."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
response_data = {
"answer": "",
@@ -416,18 +417,18 @@ class TestClientConfiguration:
def test_default_timeout(self) -> None:
"""Client uses API_REQUEST_TIMEOUT by default."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
assert client._timeout == API_REQUEST_TIMEOUT
def test_custom_timeout(self) -> None:
"""Client accepts custom timeout."""
client = OnyxAPIClient(timeout=60)
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT, timeout=60)
assert client._timeout == 60
@pytest.mark.asyncio
async def test_double_initialize_warning(self) -> None:
"""Calling initialize() twice logs warning but doesn't error."""
client = OnyxAPIClient()
client = OnyxAPIClient(origin=MessageOrigin.DISCORDBOT)
with patch("aiohttp.ClientSession") as mock_session_class:
mock_session = MagicMock()

View File

@@ -18,7 +18,7 @@ class TestCacheInitialization:
def test_cache_starts_empty(self) -> None:
"""New cache manager has empty caches."""
cache = DiscordCacheManager()
assert cache._guild_tenants == {}
assert cache._entity_tenants == {}
assert cache._api_keys == {}
assert cache.is_initialized is False
@@ -37,14 +37,14 @@ class TestCacheInitialization:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
),
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config1, mock_config2],
@@ -61,10 +61,10 @@ class TestCacheInitialization:
await cache.refresh_all()
assert cache.is_initialized is True
assert 111111 in cache._guild_tenants
assert 222222 in cache._guild_tenants
assert cache._guild_tenants[111111] == "tenant1"
assert cache._guild_tenants[222222] == "tenant1"
assert 111111 in cache._entity_tenants
assert 222222 in cache._entity_tenants
assert cache._entity_tenants[111111] == "tenant1"
assert cache._entity_tenants[222222] == "tenant1"
@pytest.mark.asyncio
async def test_cache_refresh_provisions_api_key(self) -> None:
@@ -77,14 +77,14 @@ class TestCacheInitialization:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
),
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config],
@@ -110,7 +110,7 @@ class TestCacheLookups:
def test_get_tenant_returns_correct(self) -> None:
"""Lookup registered guild returns correct tenant ID."""
cache = DiscordCacheManager()
cache._guild_tenants[123456] = "tenant1"
cache._entity_tenants[123456] = "tenant1"
result = cache.get_tenant(123456)
assert result == "tenant1"
@@ -140,7 +140,7 @@ class TestCacheLookups:
def test_get_all_guild_ids(self) -> None:
"""After loading returns all cached guild IDs."""
cache = DiscordCacheManager()
cache._guild_tenants = {111: "t1", 222: "t2", 333: "t1"}
cache._entity_tenants = {111: "t1", 222: "t2", 333: "t1"}
result = cache.get_all_guild_ids()
assert set(result) == {111, 222, 333}
@@ -159,7 +159,7 @@ class TestCacheUpdates:
mock_config.enabled = True
with (
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config],
@@ -187,7 +187,7 @@ class TestCacheUpdates:
mock_config.enabled = False # Disabled!
with (
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config],
@@ -205,7 +205,7 @@ class TestCacheUpdates:
def test_remove_guild(self) -> None:
"""remove_guild() removes guild from cache."""
cache = DiscordCacheManager()
cache._guild_tenants[111111] = "tenant1"
cache._entity_tenants[111111] = "tenant1"
cache.remove_guild(111111)
@@ -214,13 +214,13 @@ class TestCacheUpdates:
def test_clear_removes_all(self) -> None:
"""clear() empties all caches."""
cache = DiscordCacheManager()
cache._guild_tenants = {111: "t1", 222: "t2"}
cache._entity_tenants = {111: "t1", 222: "t2"}
cache._api_keys = {"t1": "key1", "t2": "key2"}
cache._initialized = True
cache.clear()
assert cache._guild_tenants == {}
assert cache._entity_tenants == {}
assert cache._api_keys == {}
assert cache.is_initialized is False
@@ -239,7 +239,7 @@ class TestThreadSafety:
call_count = 0
async def slow_refresh() -> tuple[list[int], str]:
async def slow_refresh(_tenant_id: str) -> tuple[list[int], str]:
nonlocal call_count
call_count += 1
# Simulate slow operation
@@ -248,11 +248,11 @@ class TestThreadSafety:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
),
patch.object(cache, "_load_tenant_data", side_effect=slow_refresh),
@@ -271,7 +271,7 @@ class TestThreadSafety:
async def test_concurrent_read_write(self) -> None:
"""Read during refresh doesn't cause exceptions."""
cache = DiscordCacheManager()
cache._guild_tenants[111111] = "tenant1"
cache._entity_tenants[111111] = "tenant1"
async def read_loop() -> None:
for _ in range(10):
@@ -280,7 +280,7 @@ class TestThreadSafety:
async def write_loop() -> None:
for i in range(10):
cache._guild_tenants[200000 + i] = f"tenant{i}"
cache._entity_tenants[200000 + i] = f"tenant{i}"
await asyncio.sleep(0.001)
# Should not raise any exceptions
@@ -301,14 +301,14 @@ class TestAPIKeyProvisioning:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
),
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config],
@@ -339,14 +339,14 @@ class TestAPIKeyProvisioning:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
),
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config],
@@ -392,14 +392,14 @@ class TestGatedTenantHandling:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1", "tenant2"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: gated_tenants,
),
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
side_effect=mock_get_configs,
@@ -416,9 +416,9 @@ class TestGatedTenantHandling:
await cache.refresh_all()
# Only tenant1 should be loaded (tenant2 is gated)
assert "tenant1" in cache._api_keys and 111111 in cache._guild_tenants
assert "tenant1" in cache._api_keys and 111111 in cache._entity_tenants
# tenant2's guilds should NOT be in cache
assert "tenant2" not in cache._api_keys and 222222 not in cache._guild_tenants
assert "tenant2" not in cache._api_keys and 222222 not in cache._entity_tenants
@pytest.mark.asyncio
async def test_gated_check_calls_ee_function(self) -> None:
@@ -427,14 +427,14 @@ class TestGatedTenantHandling:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
) as mock_ee,
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[],
@@ -459,14 +459,14 @@ class TestGatedTenantHandling:
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(), # No gated tenants
),
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
patch(
"onyx.onyxbot.discord.cache.get_guild_configs",
return_value=[mock_config],
@@ -499,16 +499,16 @@ class TestCacheErrorHandling:
nonlocal call_count
call_count += 1
if tenant_id == "tenant1":
raise Exception("Tenant 1 error")
raise ConnectionError("Tenant 1 connection failed")
return ([222222], "api_key")
with (
patch(
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
"onyx.onyxbot.cache.get_all_tenant_ids",
return_value=["tenant1", "tenant2"],
),
patch(
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
"onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
return_value=lambda: set(),
),
patch.object(cache, "_load_tenant_data", side_effect=mock_load),

View File

@@ -0,0 +1,105 @@
"""Fixtures for Teams bot unit tests."""
import random
from collections.abc import Callable
from unittest.mock import MagicMock
import pytest
@pytest.fixture
def mock_team_config_enabled() -> MagicMock:
    """Team config that is enabled."""
    cfg = MagicMock()
    cfg.configure_mock(
        id=1,
        team_id="team-abc-123",
        enabled=True,
        default_persona_id=1,
    )
    return cfg
@pytest.fixture
def mock_team_config_disabled() -> MagicMock:
    """Team config that is disabled."""
    cfg = MagicMock()
    cfg.configure_mock(
        id=2,
        team_id="team-abc-123",
        enabled=False,
        default_persona_id=None,
    )
    return cfg
@pytest.fixture
def mock_channel_config_factory() -> Callable[..., MagicMock]:
    """Factory fixture for creating channel configs with various settings."""

    def _make_config(
        enabled: bool = True,
        require_bot_mention: bool = True,
        persona_override_id: int | None = None,
    ) -> MagicMock:
        cfg = MagicMock()
        cfg.configure_mock(
            # Randomized id keeps configs distinct across calls.
            id=random.randint(1, 1000),
            channel_id="19:channel-xyz@thread.tacv2",
            enabled=enabled,
            require_bot_mention=require_bot_mention,
            persona_override_id=persona_override_id,
        )
        return cfg

    return _make_config
@pytest.fixture
def sample_activity_dict() -> dict:
    """Sample Teams Activity as a dict.

    Models a channel message that @-mentions the bot: the raw
    ``<at>...</at>`` markup appears in ``text``, and the same mention is
    repeated in ``entities`` with the bot's id matching ``recipient``.
    """
    return {
        "type": "message",
        "text": "<at>Onyx</at> What is our deployment process?",
        "from": {
            # The human sender of the message.
            "id": "29:user-id-123",
            "name": "Test User",
        },
        "recipient": {
            # The bot itself; same id is referenced by the mention entity below.
            "id": "28:bot-id-456",
            "name": "Onyx",
        },
        "channelData": {
            # Team/channel context — present here, empty for DMs
            # (compare sample_dm_activity_dict).
            "team": {
                "id": "team-abc-123",
                "name": "Engineering",
            },
            "channel": {
                "id": "19:channel-xyz@thread.tacv2",
                "name": "general",
            },
        },
        "entities": [
            {
                "type": "mention",
                "mentioned": {
                    "id": "28:bot-id-456",
                    "name": "Onyx",
                },
                "text": "<at>Onyx</at>",
            }
        ],
    }
@pytest.fixture
def sample_dm_activity_dict() -> dict:
    """Sample Teams DM Activity (no team context).

    Same envelope as a channel message, but ``channelData`` and
    ``entities`` are empty: no team/channel info and no bot mention.
    """
    return {
        "type": "message",
        "text": "Hello bot",
        "from": {
            "id": "29:user-id-123",
            "name": "Test User",
        },
        "recipient": {
            "id": "28:bot-id-456",
            "name": "Onyx",
        },
        # Empty channelData/entities is what distinguishes a DM here.
        "channelData": {},
        "entities": [],
    }

View File

@@ -0,0 +1,214 @@
"""Unit tests for Teams bot cache manager."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.onyxbot.teams.cache import TeamsCacheManager
class TestCacheInitialization:
    """Tests for cache initialization."""

    def test_cache_starts_empty(self) -> None:
        # A fresh manager holds no team->tenant mappings or API keys.
        cache = TeamsCacheManager()
        assert cache._entity_tenants == {}
        assert cache._api_keys == {}
        assert cache.is_initialized is False

    @pytest.mark.asyncio
    async def test_cache_refresh_all_loads_teams(self) -> None:
        """refresh_all() maps every enabled team config to its tenant."""
        cache = TeamsCacheManager()
        mock_config1 = MagicMock()
        mock_config1.team_id = "team-111"
        mock_config1.enabled = True
        mock_config2 = MagicMock()
        mock_config2.team_id = "team-222"
        mock_config2.enabled = True
        with (
            # One tenant, no gated tenants.
            patch(
                "onyx.onyxbot.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.teams.cache.get_team_configs",
                return_value=[mock_config1, mock_config2],
            ),
            patch(
                "onyx.onyxbot.teams.cache.provision_teams_service_api_key",
                return_value="test_api_key",
            ),
        ):
            # get_session_with_tenant is used as a context manager; stub
            # __enter__/__exit__ so `with` works on the MagicMock.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        assert cache.is_initialized is True
        assert "team-111" in cache._entity_tenants
        assert "team-222" in cache._entity_tenants
        assert cache._entity_tenants["team-111"] == "tenant1"

    @pytest.mark.asyncio
    async def test_cache_refresh_provisions_api_key(self) -> None:
        """refresh_all() provisions and caches an API key per tenant."""
        cache = TeamsCacheManager()
        mock_config = MagicMock()
        mock_config.team_id = "team-111"
        mock_config.enabled = True
        with (
            patch(
                "onyx.onyxbot.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.teams.cache.get_team_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.teams.cache.provision_teams_service_api_key",
                return_value="new_api_key",
            ) as mock_provision,
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        # The provisioned key must be cached under the owning tenant.
        assert cache._api_keys.get("tenant1") == "new_api_key"
        mock_provision.assert_called()
class TestCacheLookups:
    """Tests for cache lookup operations."""

    def test_get_tenant_returns_correct(self) -> None:
        manager = TeamsCacheManager()
        manager._entity_tenants["team-123"] = "tenant1"
        assert manager.get_tenant("team-123") == "tenant1"

    def test_get_tenant_returns_none_unknown(self) -> None:
        # Lookups for unregistered teams come back as None, not an error.
        manager = TeamsCacheManager()
        assert manager.get_tenant("unknown-team") is None

    def test_get_api_key_returns_correct(self) -> None:
        manager = TeamsCacheManager()
        manager._api_keys["tenant1"] = "api_key_123"
        assert manager.get_api_key("tenant1") == "api_key_123"

    def test_get_api_key_returns_none_unknown(self) -> None:
        manager = TeamsCacheManager()
        assert manager.get_api_key("unknown_tenant") is None

    def test_get_all_team_ids(self) -> None:
        manager = TeamsCacheManager()
        manager._entity_tenants = {"t1": "tenant1", "t2": "tenant2", "t3": "tenant1"}
        assert set(manager.get_all_team_ids()) == {"t1", "t2", "t3"}
class TestCacheUpdates:
    """Tests for cache update operations."""

    @pytest.mark.asyncio
    async def test_refresh_team_adds_new(self) -> None:
        """refresh_team() inserts a previously unknown team into the cache."""
        cache = TeamsCacheManager()
        mock_config = MagicMock()
        mock_config.team_id = "team-111"
        mock_config.enabled = True
        with (
            patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.teams.cache.get_team_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.teams.cache.provision_teams_service_api_key",
                return_value="api_key",
            ),
        ):
            # Stub the DB session context manager used during refresh.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_team("team-111", "tenant1")
        assert cache.get_tenant("team-111") == "tenant1"

    def test_remove_team(self) -> None:
        """remove_team() drops the team -> tenant mapping."""
        cache = TeamsCacheManager()
        cache._entity_tenants["team-111"] = "tenant1"
        cache.remove_team("team-111")
        assert cache.get_tenant("team-111") is None

    def test_clear_removes_all(self) -> None:
        """clear() empties both caches and resets the initialized flag."""
        cache = TeamsCacheManager()
        cache._entity_tenants = {"t1": "tenant1", "t2": "tenant2"}
        cache._api_keys = {"tenant1": "key1", "tenant2": "key2"}
        cache._initialized = True
        cache.clear()
        assert cache._entity_tenants == {}
        assert cache._api_keys == {}
        assert cache.is_initialized is False
class TestGatedTenantHandling:
    """Tests for gated tenant filtering."""

    @pytest.mark.asyncio
    async def test_refresh_skips_gated_tenants(self) -> None:
        """refresh_all() ignores tenants reported as gated by the EE hook."""
        cache = TeamsCacheManager()
        gated_tenants = {"tenant2"}
        mock_config = MagicMock()
        mock_config.team_id = "team-111"
        mock_config.enabled = True
        with (
            patch(
                "onyx.onyxbot.cache.get_all_tenant_ids",
                return_value=["tenant1", "tenant2"],
            ),
            # EE hook reports tenant2 as gated.
            patch(
                "onyx.onyxbot.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: gated_tenants,
            ),
            patch("onyx.onyxbot.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.teams.cache.get_team_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.teams.cache.provision_teams_service_api_key",
                return_value="api_key",
            ),
        ):
            # Stub the DB session context manager used during refresh.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        # Only the non-gated tenant gets an API key.
        assert "tenant1" in cache._api_keys
        assert "tenant2" not in cache._api_keys

View File

@@ -0,0 +1,86 @@
"""Unit tests for Teams bot Adaptive Card builders."""
from unittest.mock import MagicMock
from onyx.onyxbot.teams.cards import build_answer_card
from onyx.onyxbot.teams.cards import build_error_card
from onyx.onyxbot.teams.cards import build_welcome_card
class TestBuildAnswerCard:
    """Tests for answer card generation."""

    @staticmethod
    def _response_with(citations: list, documents: list) -> MagicMock:
        # Minimal stand-in for a chat response carrying citation data.
        fake = MagicMock()
        fake.citation_info = citations
        fake.top_documents = documents
        return fake

    def test_basic_answer(self) -> None:
        card = build_answer_card("Hello world")
        assert card["type"] == "AdaptiveCard"
        assert card["version"] == "1.3"
        body = card["body"]
        assert len(body) == 1
        assert body[0]["text"] == "Hello world"

    def test_answer_with_citations(self) -> None:
        citation = MagicMock()
        citation.citation_number = 1
        citation.document_id = "doc1"
        document = MagicMock()
        document.document_id = "doc1"
        document.semantic_identifier = "Design Doc"
        document.link = "https://example.com/doc1"
        card = build_answer_card(
            "Answer text", self._response_with([citation], [document])
        )
        # Body should have: answer + "Sources:" header + citation
        body = card["body"]
        assert len(body) == 3
        assert "Sources" in body[1]["text"]
        assert "Design Doc" in body[2]["text"]

    def test_answer_no_citations(self) -> None:
        card = build_answer_card("Answer text", self._response_with([], []))
        assert len(card["body"]) == 1

    def test_answer_citation_without_link(self) -> None:
        citation = MagicMock()
        citation.citation_number = 1
        citation.document_id = "doc1"
        document = MagicMock()
        document.document_id = "doc1"
        document.semantic_identifier = "Internal Doc"
        document.link = None
        card = build_answer_card(
            "Answer text", self._response_with([citation], [document])
        )
        citation_text = card["body"][2]["text"]
        assert "Internal Doc" in citation_text
        # Should not contain a markdown link since link is None
        assert "http" not in citation_text
class TestBuildErrorCard:
    """Tests for error card generation."""

    def test_error_card(self) -> None:
        card = build_error_card("Something went wrong")
        assert card["type"] == "AdaptiveCard"
        first_block = card["body"][0]
        assert first_block["text"] == "Something went wrong"
        # Errors are rendered in the Adaptive Card "Attention" color.
        assert first_block["color"] == "Attention"
class TestBuildWelcomeCard:
    """Tests for welcome card generation."""

    def test_welcome_card(self) -> None:
        """Welcome card has a greeting block and registration instructions."""
        result = build_welcome_card()
        assert result["type"] == "AdaptiveCard"
        body = result["body"]
        assert len(body) == 2
        assert "Welcome" in body[0]["text"]
        assert "register" in body[1]["text"]

View File

@@ -0,0 +1,272 @@
"""Unit tests for Teams bot should_respond logic."""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.onyxbot.teams.handle_message import should_respond
class TestBasicShouldRespond:
    """Tests for basic should_respond decision logic.

    Each test patches out the DB session and the two config-lookup helpers in
    onyx.onyxbot.teams.handle_message; the shared boilerplate lives in
    _call_should_respond so individual tests only describe their configs.
    """

    @staticmethod
    def _make_team_config(
        enabled: bool = True, default_persona_id: int = 1
    ) -> MagicMock:
        """Build a mock team config with the attributes should_respond reads."""
        config = MagicMock()
        config.enabled = enabled
        config.default_persona_id = default_persona_id
        return config

    @staticmethod
    def _make_channel_config(
        enabled: bool = True,
        require_bot_mention: bool = False,
        persona_override_id: int | None = None,
    ) -> MagicMock:
        """Build a mock channel config with the attributes should_respond reads."""
        config = MagicMock()
        config.enabled = enabled
        config.require_bot_mention = require_bot_mention
        config.persona_override_id = persona_override_id
        return config

    def _call_should_respond(
        self,
        team_config: MagicMock | None,
        channel_config: MagicMock | None,
        activity_dict: dict | None = None,
        team_id: str | None = "team-123",
        channel_id: str | None = "channel-456",
        bot_id: str = "bot-id",
    ):
        """Invoke should_respond with DB session and config lookups patched.

        Centralizes the patching boilerplate previously duplicated in every
        test: the session context manager yields a MagicMock db, and the
        team/channel config lookups return the supplied mocks (None means
        "not found").
        """
        with patch(
            "onyx.onyxbot.teams.handle_message.get_session_with_tenant"
        ) as mock_session:
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.teams.handle_message.get_team_config_by_teams_id",
                    return_value=team_config,
                ),
                patch(
                    "onyx.onyxbot.teams.handle_message.get_channel_config_by_teams_ids",
                    return_value=channel_config,
                ),
            ):
                return should_respond(
                    activity_dict=activity_dict if activity_dict is not None else {},
                    team_id=team_id,
                    channel_id=channel_id,
                    tenant_id="tenant1",
                    bot_id=bot_id,
                )

    def test_team_disabled_returns_false(self) -> None:
        """Team config enabled=false returns False."""
        team_config = self._make_team_config(enabled=False)
        result = self._call_should_respond(team_config, None)
        assert result.should_respond is False

    def test_team_enabled_channel_enabled_no_mention_required(self) -> None:
        """Team + channel enabled, require_bot_mention=false returns True."""
        team_config = self._make_team_config(default_persona_id=2)
        channel_config = self._make_channel_config()
        result = self._call_should_respond(team_config, channel_config)
        assert result.should_respond is True
        assert result.persona_id == 2

    def test_channel_disabled_returns_false(self) -> None:
        """Channel config enabled=false returns False."""
        team_config = self._make_team_config()
        channel_config = self._make_channel_config(enabled=False)
        result = self._call_should_respond(team_config, channel_config)
        assert result.should_respond is False

    def test_channel_not_found_returns_false(self) -> None:
        """No channel config returns False (not whitelisted)."""
        team_config = self._make_team_config()
        result = self._call_should_respond(team_config, None)
        assert result.should_respond is False

    def test_require_mention_true_with_mention(
        self, sample_activity_dict: dict
    ) -> None:
        """require_bot_mention=true with @mention returns True."""
        team_config = self._make_team_config()
        channel_config = self._make_channel_config(require_bot_mention=True)
        result = self._call_should_respond(
            team_config,
            channel_config,
            activity_dict=sample_activity_dict,
            team_id="team-abc-123",
            channel_id="19:channel-xyz@thread.tacv2",
            bot_id="28:bot-id-456",
        )
        assert result.should_respond is True

    def test_require_mention_true_no_mention(self) -> None:
        """require_bot_mention=true without @mention returns False."""
        team_config = self._make_team_config()
        channel_config = self._make_channel_config(require_bot_mention=True)
        result = self._call_should_respond(
            team_config, channel_config, activity_dict={"entities": []}
        )
        assert result.should_respond is False

    def test_dm_no_team_returns_true(self) -> None:
        """DM (no team_id or channel_id) returns True."""
        # DMs are decided without consulting team/channel config, so no
        # patching is needed here (matches the original test's direct call).
        result = should_respond(
            activity_dict={},
            team_id=None,
            channel_id=None,
            tenant_id="tenant1",
            bot_id="bot-id",
        )
        assert result.should_respond is True

    def test_persona_override_takes_priority(self) -> None:
        """Channel persona override takes priority over team default."""
        team_config = self._make_team_config(default_persona_id=1)
        channel_config = self._make_channel_config(persona_override_id=5)
        result = self._call_should_respond(team_config, channel_config)
        assert result.should_respond is True
        assert result.persona_id == 5

View File

@@ -0,0 +1,104 @@
"""Unit tests for Teams bot utility functions."""
from onyx.onyxbot.teams.utils import extract_channel_id
from onyx.onyxbot.teams.utils import extract_team_id
from onyx.onyxbot.teams.utils import extract_team_name
from onyx.onyxbot.teams.utils import is_bot_mentioned
from onyx.onyxbot.teams.utils import strip_bot_mention
from onyx.server.manage.teams_bot.utils import generate_teams_registration_key
from onyx.server.manage.teams_bot.utils import parse_teams_registration_key
class TestExtractIds:
    """Tests for ID extraction from Activity dicts."""

    def test_extract_team_id_present(self, sample_activity_dict: dict) -> None:
        """The team id is recovered from a channel-message activity."""
        team_id = extract_team_id(sample_activity_dict)
        assert team_id == "team-abc-123"

    def test_extract_team_id_missing(self, sample_dm_activity_dict: dict) -> None:
        """An activity without team data yields None."""
        assert extract_team_id(sample_dm_activity_dict) is None

    def test_extract_channel_id_present(self, sample_activity_dict: dict) -> None:
        """The channel id is recovered from a channel-message activity."""
        channel_id = extract_channel_id(sample_activity_dict)
        assert channel_id == "19:channel-xyz@thread.tacv2"

    def test_extract_channel_id_missing(self, sample_dm_activity_dict: dict) -> None:
        """An activity without channel data yields None."""
        assert extract_channel_id(sample_dm_activity_dict) is None

    def test_extract_team_name(self, sample_activity_dict: dict) -> None:
        """The team display name is recovered when present."""
        assert extract_team_name(sample_activity_dict) == "Engineering"

    def test_extract_team_name_missing(self, sample_dm_activity_dict: dict) -> None:
        """An activity without team data yields no team name."""
        assert extract_team_name(sample_dm_activity_dict) is None
class TestStripBotMention:
    """Tests for bot mention stripping."""

    def test_strip_named_mention(self) -> None:
        """A leading mention tag is removed, leaving the question text."""
        stripped = strip_bot_mention("<at>Onyx</at> What is our process?", "Onyx")
        assert stripped == "What is our process?"

    def test_strip_case_insensitive(self) -> None:
        """Mention matching ignores case differences in the bot name."""
        stripped = strip_bot_mention("<at>onyx</at> Hello", "Onyx")
        assert stripped == "Hello"

    def test_strip_no_mention(self) -> None:
        """Text without a mention is returned unchanged."""
        stripped = strip_bot_mention("Just a normal message", "Onyx")
        assert stripped == "Just a normal message"

    def test_strip_multiple_mentions(self) -> None:
        """Every occurrence of the mention tag is removed."""
        stripped = strip_bot_mention("<at>Onyx</at> hello <at>Onyx</at>", "Onyx")
        assert stripped == "hello"

    def test_strip_empty_result(self) -> None:
        """A message that is only a mention strips down to an empty string."""
        assert strip_bot_mention("<at>Onyx</at>", "Onyx") == ""
class TestIsBotMentioned:
    """Tests for bot mention detection."""

    def test_bot_mentioned(self, sample_activity_dict: dict) -> None:
        """The bot's own id in the activity's entities is detected."""
        assert is_bot_mentioned(sample_activity_dict, "28:bot-id-456") is True

    def test_bot_not_mentioned(self, sample_dm_activity_dict: dict) -> None:
        """The DM fixture does not count as a mention of the bot."""
        assert is_bot_mentioned(sample_dm_activity_dict, "28:bot-id-456") is False

    def test_different_bot_mentioned(self, sample_activity_dict: dict) -> None:
        """A mention of some other bot id is not a mention of ours."""
        assert is_bot_mentioned(sample_activity_dict, "other-bot-id") is False

    def test_no_entities(self) -> None:
        """An empty entities list means nobody was mentioned."""
        assert is_bot_mentioned({"entities": []}, "any-id") is False
class TestRegistrationKeys:
    """Tests for registration key generation and parsing."""

    def test_generate_and_parse_roundtrip(self) -> None:
        """Parsing a freshly generated key recovers the tenant id."""
        generated = generate_teams_registration_key("tenant1")
        assert parse_teams_registration_key(generated) == "tenant1"

    def test_generate_has_correct_prefix(self) -> None:
        """Generated keys are namespaced with the teams_ prefix."""
        generated = generate_teams_registration_key("tenant1")
        assert generated.startswith("teams_")

    def test_parse_invalid_prefix(self) -> None:
        """A key carrying a foreign prefix is rejected."""
        assert parse_teams_registration_key("discord_tenant1.token") is None

    def test_parse_no_separator(self) -> None:
        """A key lacking the separator after the prefix is rejected."""
        assert parse_teams_registration_key("teams_noseparator") is None

    def test_parse_empty_string(self) -> None:
        """An empty string is not a valid key."""
        assert parse_teams_registration_key("") is None

    def test_generate_url_encodes_tenant(self) -> None:
        """Tenant ids containing spaces survive the encode/decode round trip."""
        generated = generate_teams_registration_key("tenant with spaces")
        assert parse_teams_registration_key(generated) == "tenant with spaces"

    def test_generate_unique_keys(self) -> None:
        """Repeated generation for the same tenant yields distinct keys."""
        first = generate_teams_registration_key("tenant1")
        second = generate_teams_registration_key("tenant1")
        assert first != second

View File

@@ -19,7 +19,6 @@ from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.entra import EntraProvider
from ee.onyx.server.scim.providers.okta import OktaProvider
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
@@ -27,10 +26,6 @@ from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
# Every supported SCIM provider must appear here so that all endpoint tests
# run against it. When adding a new provider, add its class to this list.
SCIM_PROVIDERS: list[type[ScimProvider]] = [OktaProvider, EntraProvider]
@pytest.fixture
def mock_db_session() -> MagicMock:
@@ -46,10 +41,10 @@ def mock_token() -> MagicMock:
return token
@pytest.fixture(params=SCIM_PROVIDERS, ids=[p.__name__ for p in SCIM_PROVIDERS])
def provider(request: pytest.FixtureRequest) -> ScimProvider:
"""Parameterized provider — runs each test with every provider in SCIM_PROVIDERS."""
return request.param()
@pytest.fixture
def provider() -> ScimProvider:
"""An OktaProvider instance for endpoint tests."""
return OktaProvider()
@pytest.fixture

View File

@@ -1,11 +1,11 @@
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from ee.onyx.server.scim.auth import _hash_scim_token
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.auth import SCIM_TOKEN_PREFIX
from ee.onyx.server.scim.auth import ScimAuthError
from ee.onyx.server.scim.auth import verify_scim_token
@@ -60,7 +60,7 @@ class TestVerifyScimToken:
def test_missing_header_raises_401(self) -> None:
request = self._make_request(None)
dal = self._make_dal()
with pytest.raises(ScimAuthError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
verify_scim_token(request, dal)
assert exc_info.value.status_code == 401
assert "Missing" in str(exc_info.value.detail)
@@ -68,7 +68,7 @@ class TestVerifyScimToken:
def test_wrong_prefix_raises_401(self) -> None:
request = self._make_request("Bearer on_some_api_key")
dal = self._make_dal()
with pytest.raises(ScimAuthError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
verify_scim_token(request, dal)
assert exc_info.value.status_code == 401
@@ -76,7 +76,7 @@ class TestVerifyScimToken:
raw, _, _ = generate_scim_token()
request = self._make_request(f"Bearer {raw}")
dal = self._make_dal(token=None)
with pytest.raises(ScimAuthError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
verify_scim_token(request, dal)
assert exc_info.value.status_code == 401
assert "Invalid" in str(exc_info.value.detail)
@@ -87,7 +87,7 @@ class TestVerifyScimToken:
mock_token = MagicMock()
mock_token.is_active = False
dal = self._make_dal(token=mock_token)
with pytest.raises(ScimAuthError) as exc_info:
with pytest.raises(HTTPException) as exc_info:
verify_scim_token(request, dal)
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail)

View File

@@ -109,7 +109,7 @@ class TestOktaProvider:
result = provider.build_user_resource(user, None)
assert result.name == ScimName(
givenName="Madonna", familyName="", formatted="Madonna"
givenName="Madonna", familyName=None, formatted="Madonna"
)
def test_build_user_resource_no_name(self) -> None:
@@ -117,7 +117,7 @@ class TestOktaProvider:
user = _make_mock_user(personal_name=None)
result = provider.build_user_resource(user, None)
assert result.name == ScimName(givenName="", familyName="", formatted="")
assert result.name is None
assert result.displayName is None
def test_build_user_resource_scim_username_preserves_case(self) -> None:

View File

@@ -214,16 +214,13 @@ class TestCreateUser:
mock_dal.add_user.assert_called_once()
mock_dal.commit.assert_called_once()
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_missing_external_id_creates_user_without_mapping(
def test_missing_external_id_returns_400(
self,
mock_seats: MagicMock, # noqa: ARG002
mock_db_session: MagicMock,
mock_token: MagicMock,
mock_dal: MagicMock,
mock_dal: MagicMock, # noqa: ARG002
provider: ScimProvider,
) -> None:
mock_dal.get_user_by_email.return_value = None
resource = make_scim_user(externalId=None)
result = create_user(
@@ -233,11 +230,7 @@ class TestCreateUser:
db_session=mock_db_session,
)
parsed = parse_scim_user(result, status=201)
assert parsed.userName is not None
mock_dal.add_user.assert_called_once()
mock_dal.create_user_mapping.assert_not_called()
mock_dal.commit.assert_called_once()
assert_scim_error(result, 400)
@patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
def test_duplicate_email_returns_409(

View File

@@ -126,9 +126,7 @@ Resources:
- Effect: Allow
Action:
- secretsmanager:GetSecretValue
Resource:
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
- !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret-*
Resource: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password-*
Outputs:
OutputEcsCluster:

View File

@@ -167,12 +167,10 @@ Resources:
- ImportedNamespace: !ImportValue
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
- Name: AUTH_TYPE
Value: basic
Value: disabled
Secrets:
- Name: POSTGRES_PASSWORD
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
- Name: USER_AUTH_SECRET
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
VolumesFrom: []
SystemControls: []

View File

@@ -166,11 +166,9 @@ Resources:
- ImportedNamespace: !ImportValue
Fn::Sub: "${Environment}-onyx-cluster-OnyxNamespaceName"
- Name: AUTH_TYPE
Value: basic
Value: disabled
Secrets:
- Name: POSTGRES_PASSWORD
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/postgres/user/password
- Name: USER_AUTH_SECRET
ValueFrom: !Sub arn:aws:secretsmanager:${AWS::Region}:${AWS::AccountId}:secret:${Environment}/onyx/user-auth-secret
VolumesFrom: []
SystemControls: []

View File

@@ -16,15 +16,12 @@
# This overlay:
# - Moves Vespa (index), both model servers, and code-interpreter to profiles
# so they do not start by default
# - Moves the background worker to the "background" profile (the API server
# handles all background work via FastAPI BackgroundTasks)
# - Makes the depends_on references to removed services optional
# - Sets DISABLE_VECTOR_DB=true on the api_server
# - Makes the depends_on references to those services optional
# - Sets DISABLE_VECTOR_DB=true on backend services
#
# To selectively bring services back:
# --profile vectordb Vespa + indexing model server
# --profile inference Inference model server
# --profile background Background worker (Celery)
# --profile code-interpreter Code interpreter
# =============================================================================
@@ -46,20 +43,20 @@ services:
- DISABLE_VECTOR_DB=true
- FILE_STORE_BACKEND=postgres
# Move the background worker to a profile so it does not start by default.
# The API server handles all background work in NO_VECTOR_DB mode.
background:
profiles: ["background"]
depends_on:
index:
condition: service_started
required: false
inference_model_server:
condition: service_started
required: false
indexing_model_server:
condition: service_started
required: false
inference_model_server:
condition: service_started
required: false
environment:
- DISABLE_VECTOR_DB=true
- FILE_STORE_BACKEND=postgres
# Move Vespa and indexing model server to a profile so they do not start.
index:

View File

@@ -65,7 +65,10 @@ services:
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- TEAMS_BOT_APP_ID=${TEAMS_BOT_APP_ID:-}
- TEAMS_BOT_APP_SECRET=${TEAMS_BOT_APP_SECRET:-}
- TEAMS_BOT_AZURE_TENANT_ID=${TEAMS_BOT_AZURE_TENANT_ID:-}
# API Server connection for Discord/Teams bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
env_file:

View File

@@ -87,7 +87,10 @@ services:
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- TEAMS_BOT_APP_ID=${TEAMS_BOT_APP_ID:-}
- TEAMS_BOT_APP_SECRET=${TEAMS_BOT_APP_SECRET:-}
- TEAMS_BOT_AZURE_TENANT_ID=${TEAMS_BOT_AZURE_TENANT_ID:-}
# API Server connection for Discord/Teams bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
- PERSISTENT_DOCUMENT_STORAGE_PATH=${PERSISTENT_DOCUMENT_STORAGE_PATH:-/app/file-system}

View File

@@ -161,7 +161,10 @@ services:
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- TEAMS_BOT_APP_ID=${TEAMS_BOT_APP_ID:-}
- TEAMS_BOT_APP_SECRET=${TEAMS_BOT_APP_SECRET:-}
- TEAMS_BOT_AZURE_TENANT_ID=${TEAMS_BOT_AZURE_TENANT_ID:-}
# API Server connection for Discord/Teams bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
# Onyx Craft configuration (set up automatically on container startup)

View File

@@ -103,6 +103,14 @@ MINIO_ROOT_PASSWORD=minioadmin
## Command prefix for bot commands (default: "!")
# DISCORD_BOT_INVOKE_CHAR=!
## Teams Bot Configuration
## The Teams bot allows users to interact with Onyx from Microsoft Teams
## App ID and Secret from Azure Bot Service registration
# TEAMS_BOT_APP_ID=
# TEAMS_BOT_APP_SECRET=
## Azure tenant ID (optional, for single-tenant bots)
# TEAMS_BOT_AZURE_TENANT_ID=
## Celery Configuration
# CELERY_BROKER_POOL_LIMIT=
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=

View File

@@ -19,6 +19,6 @@ dependencies:
version: 5.4.0
- name: code-interpreter
repository: https://onyx-dot-app.github.io/python-sandbox/
version: 0.3.1
digest: sha256:4965b6ea3674c37163832a2192cd3bc8004f2228729fca170af0b9f457e8f987
generated: "2026-03-02T15:29:39.632344-08:00"
version: 0.3.0
digest: sha256:cf8f01906d46034962c6ce894770621ee183ac761e6942951118aeb48540eddd
generated: "2026-02-24T10:59:38.78318-08:00"

View File

@@ -45,6 +45,6 @@ dependencies:
repository: https://charts.min.io/
condition: minio.enabled
- name: code-interpreter
version: 0.3.1
version: 0.3.0
repository: https://onyx-dot-app.github.io/python-sandbox/
condition: codeInterpreter.enabled

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_beat.replicaCount) 0) }}
{{- if gt (int .Values.celery_beat.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_heavy.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_heavy.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_heavy.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_light.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_light.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_light.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_monitoring.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_monitoring.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_monitoring.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_primary.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_primary.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_primary.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (ne (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
{{- if and (.Values.celery_worker_user_file_processing.autoscaling.enabled) (eq (include "onyx.autoscaling.engine" .) "keda") }}
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:

View File

@@ -1,4 +1,4 @@
{{- if and .Values.vectorDB.enabled (gt (int .Values.celery_worker_user_file_processing.replicaCount) 0) }}
{{- if gt (int .Values.celery_worker_user_file_processing.replicaCount) 0 }}
apiVersion: apps/v1
kind: Deployment
metadata:

View File

@@ -0,0 +1,26 @@
{{- if .Values.teamsbot.enabled }}
# Service to expose the Teams bot /api/messages endpoint.
# Unlike Discord (outbound WebSocket only), Teams requires an inbound HTTP endpoint
# that Azure Bot Service can POST Activities to.
apiVersion: v1
kind: Service
metadata:
  name: {{ include "onyx.fullname" . }}-teamsbot
  labels:
    {{- include "onyx.labels" . | nindent 4 }}
    {{- with .Values.teamsbot.deploymentLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  # Defaults to ClusterIP; override teamsbot.service.type when Azure Bot
  # Service must reach the pod without an ingress in front.
  type: {{ .Values.teamsbot.service.type | default "ClusterIP" }}
  ports:
    - port: {{ .Values.teamsbot.service.port | default 80 }}
      # "http" is the named container port on the teamsbot deployment.
      targetPort: http
      protocol: TCP
      name: http
  # Selector mirrors the teamsbot deployment's pod labels: shared selector
  # labels plus any teamsbot-specific deploymentLabels.
  selector:
    {{- include "onyx.selectorLabels" . | nindent 4 }}
    {{- if .Values.teamsbot.deploymentLabels }}
    {{- toYaml .Values.teamsbot.deploymentLabels | nindent 4 }}
    {{- end }}
{{- end }}

View File

@@ -0,0 +1,131 @@
{{- if .Values.teamsbot.enabled }}
# Teams bot receives webhooks via HTTP POST - supports multiple replicas behind a load balancer.
# Unlike Discord (WebSocket, single replica), Teams is horizontally scalable.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "onyx.fullname" . }}-teamsbot
  labels:
    {{- include "onyx.labels" . | nindent 4 }}
    {{- with .Values.teamsbot.deploymentLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  replicas: {{ .Values.teamsbot.replicaCount | default 1 }}
  # Zero-downtime rollout: start one new pod before any old pod is removed.
  strategy:
    type: RollingUpdate
    rollingUpdate:
      maxSurge: 1
      maxUnavailable: 0
  selector:
    matchLabels:
      {{- include "onyx.selectorLabels" . | nindent 6 }}
      {{- if .Values.teamsbot.deploymentLabels }}
      {{- toYaml .Values.teamsbot.deploymentLabels | nindent 6 }}
      {{- end }}
  template:
    metadata:
      annotations:
        # Hash of the shared configmap so pods roll when its contents change.
        checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
        {{- with .Values.teamsbot.podAnnotations }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
      labels:
        {{- include "onyx.labels" . | nindent 8 }}
        {{- with .Values.teamsbot.deploymentLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
        {{- with .Values.teamsbot.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "onyx.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.teamsbot.podSecurityContext | nindent 8 }}
      {{- with .Values.teamsbot.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.teamsbot.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.teamsbot.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      containers:
        - name: teamsbot
          securityContext:
            {{- toYaml .Values.teamsbot.securityContext | nindent 12 }}
          image: "{{ .Values.teamsbot.image.repository }}:{{ .Values.teamsbot.image.tag | default .Values.global.version }}"
          imagePullPolicy: {{ .Values.global.pullPolicy }}
          command: ["python", "onyx/onyxbot/teams/server.py"]
          ports:
            - name: http
              containerPort: {{ .Values.teamsbot.port | default 3978 }}
              protocol: TCP
          # Both probes hit the bot server's /health endpoint on the named port.
          livenessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 15
            periodSeconds: 30
          readinessProbe:
            httpGet:
              path: /health
              port: http
            initialDelaySeconds: 10
            periodSeconds: 10
          resources:
            {{- toYaml .Values.teamsbot.resources | nindent 12 }}
          envFrom:
            - configMapRef:
                name: {{ .Values.config.envConfigMapName }}
          env:
            {{- include "onyx.envSecrets" . | nindent 12}}
            # Teams bot App ID
            # NOTE(review): if both teamsbot.appId and teamsbot.appIdSecretName
            # are set, TEAMS_BOT_APP_ID is emitted twice (presumably the later,
            # secret-backed entry wins) — confirm these values are intended to
            # be mutually exclusive. Same applies to the App Secret below.
            {{- if .Values.teamsbot.appId }}
            - name: TEAMS_BOT_APP_ID
              value: {{ .Values.teamsbot.appId | quote }}
            {{- end }}
            {{- if .Values.teamsbot.appIdSecretName }}
            - name: TEAMS_BOT_APP_ID
              valueFrom:
                secretKeyRef:
                  name: {{ .Values.teamsbot.appIdSecretName }}
                  key: {{ .Values.teamsbot.appIdSecretKey | default "app-id" }}
            {{- end }}
            # Teams bot App Secret
            {{- if .Values.teamsbot.appSecret }}
            - name: TEAMS_BOT_APP_SECRET
              value: {{ .Values.teamsbot.appSecret | quote }}
            {{- end }}
            {{- if .Values.teamsbot.appSecretSecretName }}
            - name: TEAMS_BOT_APP_SECRET
              valueFrom:
                secretKeyRef:
                  name: {{ .Values.teamsbot.appSecretSecretName }}
                  key: {{ .Values.teamsbot.appSecretSecretKey | default "app-secret" }}
            {{- end }}
            # Azure tenant ID (optional, for single-tenant bots)
            {{- if .Values.teamsbot.azureTenantId }}
            - name: TEAMS_BOT_AZURE_TENANT_ID
              value: {{ .Values.teamsbot.azureTenantId | quote }}
            {{- end }}
            # Bot port
            # Kept in sync with containerPort above via the same value/default.
            - name: TEAMS_BOT_PORT
              value: {{ .Values.teamsbot.port | default 3978 | quote }}
          {{- with .Values.teamsbot.volumeMounts }}
          volumeMounts:
            {{- toYaml . | nindent 12 }}
          {{- end }}
      {{- with .Values.teamsbot.volumes }}
      volumes:
        {{- toYaml . | nindent 8 }}
      {{- end }}
{{- end }}

Some files were not shown because too many files have changed in this diff Show More