Compare commits

...

60 Commits

Author SHA1 Message Date
Justin Tahara
9fb76042a2 fix(celery): Guardrail for User File Processing (#8633) 2026-03-01 10:30:03 -08:00
Nikolas Garza
caad67a34a fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 17:27:24 -08:00
dependabot[bot]
c33437488f chore(deps): Bump mistune from 0.8.4 to 3.1.4 in /backend (#6407)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 17:27:24 -08:00
Jamison Lahman
9f66ee7240 chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 16:26:37 -08:00
justin-tahara
e6ef2b5074 Fixing mypy 2026-02-09 15:47:10 -08:00
justin-tahara
74132175a8 Fixing mypy 2026-02-09 15:47:10 -08:00
Justin Tahara
29f707ee2d fix(posthog): Chat metrics for Cloud (#8278) 2026-02-09 15:47:10 -08:00
Justin Tahara
f0eb86fb9f fix(ui): Updating Dropdown Modal component (#8033) 2026-02-06 11:59:09 -08:00
Justin Tahara
b422496a4c fix(agents): Removing Label Dependency (#8189) 2026-02-06 11:39:09 -08:00
Justin Tahara
31d6a45b23 chore(chat): Cleaning Error Codes + Tests (#8186) 2026-02-06 11:02:41 -08:00
Justin Tahara
36f3ac1ec5 feat: onyx discord bot - supervisord and kube deployment (#7706) 2026-02-02 15:05:21 -08:00
Wenxi Onyx
74f5b3025a fix: discord svg (can't cherry-pick) 2026-02-02 10:03:39 -08:00
Justin Tahara
c18545d74c feat(desktop): Ensure that UI reflects Light/Dark Toggle (#7684) 2026-02-02 10:03:39 -08:00
Justin Tahara
48171e3700 fix(ui): Agent Saving with other people files (#8095) 2026-02-02 10:03:39 -08:00
Wenxi
f5a5709876 feat: onyx discord bot - frontend (#7497) 2026-02-02 10:03:39 -08:00
Justin Tahara
85868b1b83 fix(desktop): Remove Global Shortcuts (#7914) 2026-01-30 13:46:20 -08:00
Justin Tahara
8dc14c23e6 fix(asana): Workspace Team ID mismatch (#7674) 2026-01-30 13:19:02 -08:00
Jamison Lahman
23821cc0e8 chore(mypy): fix mypy cache issues switching between HEAD and release (#7732) 2026-01-27 15:52:57 -08:00
Jamison Lahman
b359e13281 fix(citations): enable citation sidebar w/ web_search-only assistants (#7888) 2026-01-27 13:26:29 -08:00
Justin Tahara
717f410a4a fix(llm): Hide private models from Agent Creation (#7873) 2026-01-27 12:21:06 -08:00
SubashMohan
ada0946a62 fix(layout): adjust footer margin and prevent page refresh on chatsession drop (#7759) 2026-01-27 11:57:18 -08:00
Jamison Lahman
eb2ac8f5a3 fix(fe): inline code text wraps (#7574) 2026-01-27 11:33:03 -08:00
Nikolas Garza
fbeb57c592 fix(slack): Extract person names and filter garbage in query expansion (#7632) 2026-01-27 11:26:52 -08:00
Nikolas Garza
d6da9c9b85 fix: scroll to bottom when loading existing conversations (#7614) 2026-01-27 11:26:52 -08:00
Nikolas Garza
5aea2e223e fix(billing): remove grandfathered pricing option when subscription lapses (#7583) 2026-01-27 11:26:52 -08:00
Nikolas Garza
1ff91de07e fix: deflake chat user journey test (#7646) 2026-01-27 11:18:27 -08:00
Nikolas Garza
b3dbc69faf fix(tests): use crawler-friendly search query in Exa integration test (#7746) 2026-01-27 11:13:01 -08:00
Yuhong Sun
431597b0f9 fix: LiteLLM Azure models don't stream (#7761) 2026-01-27 10:49:17 -08:00
Yuhong Sun
51b4e5f2fb fix: Azure OpenAI Tool Calls (#7727) 2026-01-27 10:49:17 -08:00
Justin Tahara
9afa04a26b fix(ui): Coda Logo (#7656) 2026-01-26 17:43:54 -08:00
Justin Tahara
70a3a9c0cd fix(ui): User Groups Connectors Fix (#7658) 2026-01-26 17:43:45 -08:00
Justin Tahara
080165356c fix(ui): First Connector Result (#7657) 2026-01-26 17:43:35 -08:00
Justin Tahara
3ae974bdf6 fix(ui): Fix Token Rate Limits Page (#7659) 2026-01-26 17:42:57 -08:00
Justin Tahara
1471658151 fix(vertex ai): Extra Args for Opus 4.5 (#7586) 2026-01-26 17:42:43 -08:00
Justin Tahara
3e85e9c1a3 feat(desktop): Domain Configuration (#7655) 2026-01-26 17:12:33 -08:00
Justin Tahara
851033be5f feat(desktop): Properly Sign Mac App (#7608) 2026-01-26 17:12:24 -08:00
Jamison Lahman
91e974a6cc chore(desktop): make artifact filename version-agnostic (#7679) 2026-01-26 16:20:39 -08:00
Jamison Lahman
38ba4f8a1c chore(deployments): fix region (#7640) 2026-01-26 16:20:39 -08:00
Jamison Lahman
6f02473064 chore(deployments): fetch secrets from AWS (#7584) 2026-01-26 16:20:39 -08:00
Nikolas Garza
f89432009f fix(fe): show scroll-down button when user scrolls up during streaming (#7562) 2026-01-20 07:07:55 +00:00
Jamison Lahman
8ab2bab34e chore(fe): fix sticky header parent height (#7561) 2026-01-20 06:18:32 +00:00
Jamison Lahman
59e0d62512 chore(fe): align assistant icon with chat bar (#7537) 2026-01-19 19:47:18 -08:00
Jamison Lahman
a1471b16a5 fix(fe): chat header is sticky and transparent (#7487) 2026-01-19 19:20:03 -08:00
Yuhong Sun
9d3811cb58 fix: prompt tuning (#7550) 2026-01-19 19:04:18 -08:00
Yuhong Sun
3cd9505383 feat: Memory initial (#7547) 2026-01-19 18:57:13 -08:00
Nikolas Garza
d11829b393 refactor: proxy customer portal session through control plane (#7544) 2026-01-20 01:24:30 +00:00
Nikolas Garza
f6e068e914 feat(billing): add annual pricing support to subscription checkout (#7506) 2026-01-20 00:17:18 +00:00
roshan
0c84edd980 feat: onyx embeddable widget (#7427)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-20 00:01:10 +00:00
Wenxi
2b274a7683 feat: onyx discord bot - discord client (#7496) 2026-01-20 00:00:20 +00:00
Wenxi
ddd91f2d71 feat: onyx discord bot - api client and cache manager (#7495) 2026-01-19 23:15:17 +00:00
Yuhong Sun
a7c7da0dfc fix: tool call handling for weak models (#7538) 2026-01-19 13:37:00 -08:00
Evan Lohn
b00a3e8b5d fix(test): confluence group sync (#7536) 2026-01-19 21:20:48 +00:00
Raunak Bhagat
d77d1a48f1 fix: Line item fixes (#7513) 2026-01-19 20:25:35 +00:00
Raunak Bhagat
7b4fc6729c fix: Popover size fix (#7521) 2026-01-19 18:44:29 +00:00
Nikolas Garza
1f113c86ef feat(ee): license enforcement middleware (#7483) 2026-01-19 18:03:39 +00:00
Raunak Bhagat
8e38ba3e21 refactor: Fix some onboarding inaccuracies (#7511) 2026-01-19 04:33:27 +00:00
Raunak Bhagat
bb9708a64f refactor: Small styling / prop-naming refactors (#7503) 2026-01-19 02:49:27 +00:00
Raunak Bhagat
8cae97e145 fix: Fix connector-setup modal (#7502) 2026-01-19 00:29:36 +00:00
Wenxi
7e4abca224 feat: onyx discord bot - backend, crud, and apis (#7494) 2026-01-18 23:13:58 +00:00
Yuhong Sun
233a91ea65 chore: drop dead table (#7500) 2026-01-17 20:05:22 -08:00
208 changed files with 20770 additions and 2371 deletions

View File

@@ -8,7 +8,9 @@ on:
# Set restrictive default permissions for all jobs. Jobs that need more permissions
# should explicitly declare them.
permissions: {}
permissions:
# Required for OIDC authentication with AWS
id-token: write # zizmor: ignore[excessive-permissions]
env:
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -150,16 +152,30 @@ jobs:
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
runs-on: ubuntu-slim
timeout-minutes: 10
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: "• check-version-tag"
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
@@ -168,6 +184,7 @@ jobs:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
permissions:
id-token: write
contents: write
actions: read
strategy:
@@ -185,12 +202,33 @@ jobs:
runs-on: ${{ matrix.platform }}
timeout-minutes: 90
environment: release
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]
- name: Configure AWS credentials
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
APPLE_ID, deploy/apple-id
APPLE_PASSWORD, deploy/apple-password
APPLE_CERTIFICATE, deploy/apple-certificate
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
KEYCHAIN_PASSWORD, deploy/keychain-password
APPLE_TEAM_ID, deploy/apple-team-id
parse-json-secrets: true
- name: install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-')
run: |
@@ -285,15 +323,40 @@ jobs:
Write-Host "Versions set to: $VERSION"
- name: Import Apple Developer Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security set-keychain-settings -t 3600 -u build.keychain
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
security find-identity -v -p codesigning build.keychain
- name: Verify Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
echo "Certificate imported."
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
APPLE_ID: ${{ env.APPLE_ID }}
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
with:
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: "See the assets to download this version and install."
releaseDraft: true
prerelease: false
assetNamePattern: "[name]_[arch][ext]"
args: ${{ matrix.args }}
build-web-amd64:
@@ -305,6 +368,7 @@ jobs:
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -317,6 +381,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -331,8 +409,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -363,6 +441,7 @@ jobs:
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -375,6 +454,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -389,8 +482,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -423,19 +516,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -471,6 +579,7 @@ jobs:
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -483,6 +592,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -497,8 +620,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -537,6 +660,7 @@ jobs:
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -549,6 +673,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -563,8 +701,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -605,19 +743,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -650,6 +803,7 @@ jobs:
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -662,6 +816,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -676,8 +844,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -707,6 +875,7 @@ jobs:
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -719,6 +888,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -733,8 +916,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -766,19 +949,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -815,6 +1013,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -827,6 +1026,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -843,8 +1056,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -879,6 +1092,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -891,6 +1105,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -907,8 +1135,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -944,19 +1172,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -994,11 +1237,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1014,8 +1272,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1034,11 +1292,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1054,8 +1327,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1074,6 +1347,7 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
@@ -1084,6 +1358,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1100,8 +1388,8 @@ jobs:
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1121,11 +1409,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1141,8 +1444,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1170,12 +1473,26 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true
- name: Determine failed jobs
id: failed-jobs
shell: bash
@@ -1241,7 +1558,7 @@ jobs:
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
title: "🚨 Deployment Workflow Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

View File

@@ -151,6 +151,24 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",

View File

@@ -0,0 +1,116 @@
"""Add Discord bot tables
Revision ID: 8b5ce697290e
Revises: a1b2c3d4e5f7
Create Date: 2025-01-14
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8b5ce697290e"
down_revision = "a1b2c3d4e5f7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
    """Create the three Discord bot tables.

    Schema: a singleton bot-level config holding the (encrypted) bot token,
    a per-guild config keyed by Discord guild ID, and a per-channel config
    that belongs to a guild config (one guild -> many channels).
    """
    # DiscordBotConfig (singleton table - one per tenant)
    op.create_table(
        "discord_bot_config",
        sa.Column(
            "id",
            sa.String(),
            primary_key=True,
            server_default=sa.text("'SINGLETON'"),
        ),
        sa.Column("bot_token", sa.LargeBinary(), nullable=False),  # EncryptedString
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        # Enforce the one-row-per-tenant invariant at the database level.
        sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
    )
    # DiscordGuildConfig
    op.create_table(
        "discord_guild_config",
        sa.Column("id", sa.Integer(), primary_key=True),
        # guild_id/guild_name are nullable: a row may exist (with a
        # registration key) before the bot is actually registered to a guild.
        sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
        sa.Column("guild_name", sa.String(), nullable=True),
        sa.Column("registration_key", sa.String(), nullable=False, unique=True),
        sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "default_persona_id",
            sa.Integer(),
            # Keep the guild config if its persona is deleted.
            sa.ForeignKey("persona.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column(
            "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
        ),
    )
    # DiscordChannelConfig
    op.create_table(
        "discord_channel_config",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column(
            "guild_config_id",
            sa.Integer(),
            # Channel configs are owned by their guild config row.
            sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("channel_id", sa.BigInteger(), nullable=False),
        sa.Column("channel_name", sa.String(), nullable=False),
        sa.Column(
            "channel_type",
            sa.String(20),
            server_default=sa.text("'text'"),
            nullable=False,
        ),
        sa.Column(
            "is_private",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.Column(
            "thread_only_mode",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.Column(
            "require_bot_invocation",
            sa.Boolean(),
            server_default=sa.text("true"),
            nullable=False,
        ),
        sa.Column(
            "persona_override_id",
            sa.Integer(),
            sa.ForeignKey("persona.id", ondelete="SET NULL"),
            nullable=True,
        ),
        # Channels default to disabled (guilds default to enabled above).
        sa.Column(
            "enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
        ),
    )
    # Unique constraint: one config per channel per guild
    op.create_unique_constraint(
        "uq_discord_channel_guild_channel",
        "discord_channel_config",
        ["guild_config_id", "channel_id"],
    )
def downgrade() -> None:
    """Drop the Discord bot tables.

    Order matters: channel config references guild config via FK, so tables
    are dropped child-first.
    """
    for table_name in (
        "discord_channel_config",
        "discord_guild_config",
        "discord_bot_config",
    ):
        op.drop_table(table_name)

View File

@@ -0,0 +1,47 @@
"""drop agent_search_metrics table
Revision ID: a1b2c3d4e5f7
Revises: 73e9983e5091
Create Date: 2026-01-17
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f7"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None
def upgrade() -> None:
    # The agent search metrics feature was removed; drop its backing table.
    op.drop_table("agent__search_metrics")
def downgrade() -> None:
    """Recreate the agent__search_metrics table with its schema as of removal."""
    op.create_table(
        "agent__search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.UUID(), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        # Metrics rows are removed along with their owning user.
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )

View File

@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
@@ -129,3 +128,8 @@ MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
GATED_TENANTS_KEY = "gated_tenants"
# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)

View File

@@ -16,6 +16,9 @@ from ee.onyx.server.enterprise_settings.api import (
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
add_license_enforcement_middleware,
)
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
)
@@ -83,6 +86,10 @@ def get_application() -> FastAPI:
if MULTI_TENANT:
add_api_server_tenant_id_middleware(application, logger)
# Add license enforcement middleware (runs after tenant tracking)
# This blocks access when license is expired/gated
add_license_enforcement_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes

View File

@@ -0,0 +1,102 @@
"""Middleware to enforce license status application-wide."""
import logging
from collections.abc import Awaitable
from collections.abc import Callable
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.server.tenants.product_gating import is_tenant_gated
from onyx.server.settings.models import ApplicationStatus
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /tenants/billing-* - Manage subscription to resolve gating
ALLOWED_PATH_PREFIXES = {
"/auth",
"/license",
"/health",
"/me",
"/settings",
"/enterprise-settings",
"/tenants/billing-information",
"/tenants/create-customer-portal-session",
"/tenants/create-subscription-session",
}
def _is_path_allowed(path: str) -> bool:
"""Check if path is in allowlist (prefix match)."""
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
def add_license_enforcement_middleware(
    app: FastAPI, logger: logging.LoggerAdapter
) -> None:
    """Register an HTTP middleware on `app` that returns 402 for requests
    made while the tenant's license is expired or the tenant is gated.

    Allowlisted path prefixes (auth, license, health, settings, billing) are
    never blocked so users can still resolve billing problems. Redis errors
    fail open rather than locking users out.
    """
    logger.info("License enforcement middleware registered")

    @app.middleware("http")
    async def enforce_license(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Block requests when license is expired/gated."""
        # Feature flag: enforcement disabled means every request passes.
        if not LICENSE_ENFORCEMENT_ENABLED:
            return await call_next(request)

        path = request.url.path
        # Strip a leading "/api" so the allowlist matches both direct and
        # proxied request paths.
        if path.startswith("/api"):
            path = path[4:]

        if _is_path_allowed(path):
            return await call_next(request)

        is_gated = False
        tenant_id = get_current_tenant_id()

        if MULTI_TENANT:
            # Cloud: gating state is tracked in a shared Redis set.
            try:
                is_gated = is_tenant_gated(tenant_id)
            except RedisError as e:
                logger.warning(f"Failed to check tenant gating status: {e}")
                # Fail open - don't block users due to Redis connectivity issues
                is_gated = False
        else:
            # Self-hosted: gate on cached license metadata.
            try:
                metadata = get_cached_license_metadata(tenant_id)
                if metadata:
                    if metadata.status == ApplicationStatus.GATED_ACCESS:
                        is_gated = True
                else:
                    # No license metadata = gated for self-hosted EE
                    is_gated = True
            except RedisError as e:
                logger.warning(f"Failed to check license metadata: {e}")
                # Fail open - don't block users due to Redis connectivity issues
                is_gated = False

        if is_gated:
            logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
            # 402 Payment Required signals a billing/license problem to the UI.
            return JSONResponse(
                status_code=402,
                content={
                    "detail": {
                        "error": "license_expired",
                        "message": "Your subscription has expired. Please update your billing.",
                    }
                },
            )

        return await call_next(request)

View File

@@ -0,0 +1,54 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Statuses that indicate a billing/license problem - propagate these to settings
_GATED_STATUSES = frozenset(
{
ApplicationStatus.GATED_ACCESS,
ApplicationStatus.GRACE_PERIOD,
ApplicationStatus.PAYMENT_REMINDER,
}
)
def apply_license_status_to_settings(settings: Settings) -> Settings:
    """EE version: checks license status for self-hosted deployments.

    For self-hosted, looks up license metadata and overrides
    application_status if the license is missing or indicates a problem
    (expired, grace period, etc.). Multi-tenant (cloud) settings already
    carry the correct status from the control plane, and disabled
    enforcement leaves settings untouched.
    """
    # Cloud deployments and disabled enforcement both pass through unchanged.
    if not LICENSE_ENFORCEMENT_ENABLED or MULTI_TENANT:
        return settings

    tenant_id = get_current_tenant_id()
    try:
        metadata = get_cached_license_metadata(tenant_id)
        if not metadata:
            # No license = gated access for self-hosted EE
            settings.application_status = ApplicationStatus.GATED_ACCESS
        elif metadata.status in _GATED_STATUSES:
            settings.application_status = metadata.status
    except RedisError as e:
        # Fail open: leave the incoming settings unchanged on cache errors.
        logger.warning(f"Failed to check license metadata for settings: {e}")

    return settings

View File

@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal
import requests
import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
def fetch_stripe_checkout_session(
tenant_id: str,
billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
payload = {
"tenant_id": tenant_id,
"billing_period": billing_period,
}
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()["sessionId"]
@@ -70,24 +76,46 @@ def fetch_billing_information(
return BillingInformation(**response_data)
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
"""
Fetch a Stripe customer portal session URL from the control plane.
NOTE: This is currently only used for multi-tenant (cloud) deployments.
Self-hosted proxy endpoints will be added in a future phase.
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
payload = {"tenant_id": tenant_id}
if return_url:
payload["return_url"] = return_url
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()["url"]
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Send a request to the control service to register the number of users for a tenant.
Update the number of seats for a tenant's subscription.
Preserves the existing price (monthly, annual, or grandfathered).
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
subscription_item = subscription["items"]["data"][0]
# Use existing price to preserve the customer's current plan
current_price_id = subscription_item.price.id
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"id": subscription_item.id,
"price": current_price_id,
"quantity": number_of_users,
}
],

View File

@@ -1,15 +1,14 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
@@ -23,7 +22,6 @@ from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@@ -82,21 +80,17 @@ async def billing_information(
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
"""
Create a Stripe customer portal session via the control plane.
NOTE: This is currently only used for multi-tenant (cloud) deployments.
Self-hosted proxy endpoints will be added in a future phase.
"""
tenant_id = get_current_tenant_id()
return_url = f"{WEB_DOMAIN}/admin/billing"
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
portal_url = fetch_customer_portal_session(tenant_id, return_url)
return {"url": portal_url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@@ -104,15 +98,18 @@ async def create_customer_portal_session(
@router.post("/create-subscription-session")
async def create_subscription_session(
request: CreateSubscriptionSessionRequest | None = None,
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
logger.exception("Failed to create subscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
sessionId: str
class CreateSubscriptionSessionRequest(BaseModel):
"""Request to create a subscription checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int

View File

@@ -65,3 +65,9 @@ def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
def is_tenant_gated(tenant_id: str) -> bool:
    """Fast O(1) check if tenant is in gated set (multi-tenant only)."""
    # The gated-tenants set lives under the cloud tenant's namespace; read
    # it via the replica client since this is a read-only membership check.
    redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))

View File

@@ -12,6 +12,7 @@ from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -9,6 +9,7 @@ from onyx.chat.citation_processor import CitationMode
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
@@ -38,6 +39,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
@@ -51,6 +53,78 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _try_fallback_tool_extraction(
    llm_step_result: LlmStepResult,
    tool_choice: ToolChoiceOptions,
    fallback_extraction_attempted: bool,
    tool_defs: list[dict],
    turn_index: int,
) -> tuple[LlmStepResult, bool]:
    """Attempt to extract tool calls from response text as a fallback.

    This is a last resort fallback for low quality LLMs or those that don't have
    tool calling from the serving layer. Also triggers if there's reasoning but
    no answer and no tool calls.

    Args:
        llm_step_result: The result from the LLM step
        tool_choice: The tool choice option used for this step
        fallback_extraction_attempted: Whether fallback extraction was already attempted
        tool_defs: List of tool definitions
        turn_index: The current turn index for placement

    Returns:
        Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call)
    """
    # Only one fallback attempt is permitted per loop; bail out if it has
    # already been used.
    if fallback_extraction_attempted:
        return llm_step_result, False

    no_tool_calls = (
        not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0
    )
    reasoning_but_no_answer_or_tools = (
        llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
    )
    # Trigger either when tool calls were required but absent, or when the
    # model produced only reasoning (no answer, no tool calls).
    should_try_fallback = (
        tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
    ) or reasoning_but_no_answer_or_tools
    if not should_try_fallback:
        return llm_step_result, False

    # Try to extract from answer first, then fall back to reasoning
    extracted_tool_calls: list[ToolCallKickoff] = []
    if llm_step_result.answer:
        extracted_tool_calls = extract_tool_calls_from_response_text(
            response_text=llm_step_result.answer,
            tool_definitions=tool_defs,
            placement=Placement(turn_index=turn_index),
        )
    if not extracted_tool_calls and llm_step_result.reasoning:
        extracted_tool_calls = extract_tool_calls_from_response_text(
            response_text=llm_step_result.reasoning,
            tool_definitions=tool_defs,
            placement=Placement(turn_index=turn_index),
        )

    if extracted_tool_calls:
        logger.info(
            f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
            f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
        )
        # Keep the original reasoning/answer; only the tool calls are replaced.
        return (
            LlmStepResult(
                reasoning=llm_step_result.reasoning,
                answer=llm_step_result.answer,
                tool_calls=extracted_tool_calls,
            ),
            True,
        )

    # Nothing was extracted, but an attempt was made - report True so the
    # caller does not retry.
    return llm_step_result, True
# Hardcoded opinionated value; might break down to something like:
# Cycle 1: Calls web_search for something
# Cycle 2: Calls open_url for some results
@@ -352,6 +426,7 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
@@ -470,10 +545,11 @@ def run_llm_loop(
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
# It also pre-processes the tool calls in preparation for running them
tool_defs = [tool.tool_definition() for tool in final_tools]
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
history=truncated_message_history,
tool_definitions=[tool.tool_definition() for tool in final_tools],
tool_definitions=tool_defs,
tool_choice=tool_choice,
llm=llm,
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
@@ -488,6 +564,19 @@ def run_llm_loop(
if has_reasoned:
reasoning_cycles += 1
# Fallback extraction for LLMs that don't support tool calling natively or are lower quality
# and might incorrectly output tool calls in other channels
llm_step_result, attempted = _try_fallback_tool_extraction(
llm_step_result=llm_step_result,
tool_choice=tool_choice,
fallback_extraction_attempted=fallback_extraction_attempted,
tool_defs=tool_defs,
turn_index=llm_cycle_count + reasoning_cycles,
)
if attempted:
# To prevent the case of excessive looping with bad models, we only allow one fallback attempt
fallback_extraction_attempted = True
# Save citation mapping after each LLM step for incremental state updates
state_container.set_citation_mapping(citation_processor.citation_to_doc)
@@ -580,6 +669,12 @@ def run_llm_loop(
):
generated_images = tool_response.rich_response.generated_images
saved_response = (
tool_response.rich_response
if isinstance(tool_response.rich_response, str)
else tool_response.llm_facing_response
)
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
turn_index=llm_cycle_count + reasoning_cycles,
@@ -589,7 +684,7 @@ def run_llm_loop(
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
tool_call_response=saved_response,
search_docs=search_docs,
generated_images=generated_images,
)
@@ -645,7 +740,12 @@ def run_llm_loop(
should_cite_documents = True
if not llm_step_result or not llm_step_result.answer:
raise RuntimeError("LLM did not return an answer.")
raise RuntimeError(
"The LLM did not return an answer. "
"Typically this is an issue with LLMs that do not support tool calling natively, "
"or the model serving API is not configured correctly. "
"This may also happen with models that are lower quality outputting invalid tool calls."
)
emitter.emit(
Packet(

View File

@@ -49,6 +49,7 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import find_all_json_objects
logger = setup_logger()
@@ -278,6 +279,144 @@ def _extract_tool_call_kickoffs(
return tool_calls
def extract_tool_calls_from_response_text(
    response_text: str | None,
    tool_definitions: list[dict],
    placement: Placement,
) -> list[ToolCallKickoff]:
    """Fallback extraction of tool calls from raw LLM response text.

    When an LLM was expected to emit proper tool calls but instead wrote
    them out as plain text, scan the text for JSON objects and match them
    against the available tool definitions.

    Args:
        response_text: The LLM's text response to search for tool calls.
        tool_definitions: List of tool definitions to match against.
        placement: Placement information for the tool calls.

    Returns:
        List of ToolCallKickoff objects for any matched tool calls.
    """
    if not response_text or not tool_definitions:
        return []

    # Index function-style tool definitions by name for quick lookup
    available_tools: dict[str, dict] = {}
    for definition in tool_definitions:
        if definition.get("type") != "function" or "function" not in definition:
            continue
        func_def = definition["function"]
        name = func_def.get("name")
        if name:
            available_tools[name] = func_def

    if not available_tools:
        return []

    extracted: list[ToolCallKickoff] = []
    for candidate in find_all_json_objects(response_text):
        match = _try_match_json_to_tool(candidate, available_tools)
        if match is None:
            continue
        matched_name, matched_args = match
        extracted.append(
            ToolCallKickoff(
                tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
                tool_name=matched_name,
                tool_args=matched_args,
                placement=Placement(
                    turn_index=placement.turn_index,
                    # tab_index is the position among the extracted calls
                    tab_index=len(extracted),
                    sub_turn_index=placement.sub_turn_index,
                ),
            )
        )

    logger.info(
        f"Extracted {len(extracted)} tool call(s) from response text as fallback"
    )
    return extracted
def _try_match_json_to_tool(
json_obj: dict[str, Any],
tool_name_to_def: dict[str, dict],
) -> tuple[str, dict[str, Any]] | None:
"""Try to match a JSON object to a tool definition.
Supports several formats:
1. Direct tool call format: {"name": "tool_name", "arguments": {...}}
2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}}
3. Tool name as key: {"tool_name": {...arguments...}}
4. Arguments matching a tool's parameter schema
Args:
json_obj: The JSON object to match
tool_name_to_def: Map of tool names to their function definitions
Returns:
Tuple of (tool_name, tool_args) if matched, None otherwise
"""
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
tool_name = json_obj["name"]
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
if "function" in json_obj and isinstance(json_obj["function"], dict):
func_obj = json_obj["function"]
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
tool_name = func_obj["name"]
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
if isinstance(arguments, str):
try:
arguments = json.loads(arguments)
except json.JSONDecodeError:
arguments = {}
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 3: Tool name as key {"tool_name": {...arguments...}}
for tool_name in tool_name_to_def:
if tool_name in json_obj:
arguments = json_obj[tool_name]
if isinstance(arguments, dict):
return (tool_name, arguments)
# Format 4: Check if the JSON object matches a tool's parameter schema
for tool_name, func_def in tool_name_to_def.items():
params = func_def.get("parameters", {})
properties = params.get("properties", {})
required = params.get("required", [])
if not properties:
continue
# Check if all required parameters are present (empty required = all optional)
if all(req in json_obj for req in required):
# Check if any of the tool's properties are in the JSON object
matching_props = [prop for prop in properties if prop in json_obj]
if matching_props:
# Filter to only include known properties
filtered_args = {k: v for k, v in json_obj.items() if k in properties}
return (tool_name, filtered_args)
return None
def translate_history_to_llm_format(
history: list[ChatMessageSimple],
llm_config: LLMConfig,

View File

@@ -86,10 +86,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -362,21 +358,20 @@ def handle_stream_message_objects(
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
# Track user message in PostHog for analytics
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(
distinct_id=user.email if user else tenant_id,
event="user_message_sent",
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=(
user.email
if user and not getattr(user, "is_anonymous", False)
else tenant_id
),
event=MilestoneRecordType.USER_MESSAGE_SENT,
properties={
"origin": new_msg_req.origin.value,
"has_files": len(new_msg_req.file_descriptors) > 0,
"has_project": chat_session.project_id is not None,
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
"deep_research": new_msg_req.deep_research,
"tenant_id": tenant_id,
},
)

View File

@@ -18,6 +18,7 @@ from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
@@ -28,6 +29,7 @@ from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -178,8 +180,9 @@ def build_system_prompt(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
)
+ OPEN_URLS_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt
@@ -193,6 +196,7 @@ def build_system_prompt(
has_generate_image = any(
isinstance(tool, ImageGenerationTool) for tool in tools
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
if has_web_search or has_internal_search or include_all_guidance:
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
@@ -222,4 +226,7 @@ def build_system_prompt(
if has_generate_image or include_all_guidance:
system_prompt += GENERATE_IMAGE_GUIDANCE
if has_memory or include_all_guidance:
system_prompt += MEMORY_GUIDANCE
return system_prompt

View File

@@ -1011,3 +1011,8 @@ INSTANCE_TYPE = (
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
)
## Discord Bot Configuration
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")

View File

@@ -93,6 +93,7 @@ SSL_CERT_FILE = "bundle.pem"
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
@@ -152,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -340,6 +352,7 @@ class MilestoneRecordType(str, Enum):
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
USER_MESSAGE_SENT = "user_message_sent"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
@@ -422,6 +435,9 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.workspace_id = asana_workspace_id.strip()
if asana_project_ids:
project_ids = [
project_id.strip()
for project_id in asana_project_ids.split(",")
if project_id.strip()
]
self.project_ids_to_index = project_ids or None
else:
self.project_ids_to_index = None
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(

View File

@@ -567,6 +567,23 @@ def extract_content_words_from_recency_query(
return content_words_filtered[:MAX_CONTENT_WORDS]
def _is_valid_keyword_query(line: str) -> bool:
"""Check if a line looks like a valid keyword query vs explanatory text.
Returns False for lines that appear to be LLM explanations rather than keywords.
"""
# Reject lines that start with parentheses (explanatory notes)
if line.startswith("("):
return False
# Reject lines that are too long (likely sentences, not keywords)
# Keywords should be short - reject if > 50 chars or > 6 words
if len(line) > 50 or len(line.split()) > 6:
return False
return True
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
"""Use LLM to expand query into multiple search variations.
@@ -589,10 +606,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
response_clean = _parse_llm_code_block_response(response)
# Split into lines and filter out empty lines
rephrased_queries = [
raw_queries = [
line.strip() for line in response_clean.split("\n") if line.strip()
]
# Filter out lines that look like explanatory text rather than keywords
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
# Log if we filtered out garbage
if len(raw_queries) != len(rephrased_queries):
filtered_out = set(raw_queries) - set(rephrased_queries)
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
# If no queries generated, use empty query
if not rephrased_queries:
logger.debug("No content keywords extracted from query expansion")

View File

@@ -0,0 +1,451 @@
"""CRUD operations for Discord bot models."""
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DISCORD_SERVICE_API_KEY_NAME
from onyx.db.api_key import insert_api_key
from onyx.db.models import ApiKey
from onyx.db.models import DiscordBotConfig
from onyx.db.models import DiscordChannelConfig
from onyx.db.models import DiscordGuildConfig
from onyx.db.models import User
from onyx.db.utils import DiscordChannelView
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
logger = setup_logger()
# === DiscordBotConfig ===
def get_discord_bot_config(db_session: Session) -> DiscordBotConfig | None:
    """Return this tenant's singleton Discord bot config, or None if unset."""
    stmt = select(DiscordBotConfig).limit(1)
    return db_session.scalar(stmt)
def create_discord_bot_config(
    db_session: Session,
    bot_token: str,
) -> DiscordBotConfig:
    """Create the singleton Discord bot config for this tenant.

    The check constraint on id='SINGLETON' ensures only one config per tenant.

    Args:
        db_session: Tenant-scoped database session.
        bot_token: Raw Discord bot token to persist.

    Returns:
        The newly created (flushed) config row.

    Raises:
        ValueError: If a config already exists — detected up front, and again
            via the DB integrity error to close the check-then-act race.
    """
    existing = get_discord_bot_config(db_session)
    if existing:
        raise ValueError("Discord bot config already exists")

    config = DiscordBotConfig(bot_token=bot_token)
    db_session.add(config)
    try:
        db_session.flush()
    except IntegrityError as err:
        # Race condition: another request created the config concurrently.
        # Roll back so the session stays usable; chain the original error
        # so the root cause is preserved in tracebacks.
        db_session.rollback()
        raise ValueError("Discord bot config already exists") from err
    return config
def delete_discord_bot_config(db_session: Session) -> bool:
    """Delete the tenant's Discord bot config; True if a row was removed."""
    deleted = db_session.execute(delete(DiscordBotConfig)).rowcount  # type: ignore[attr-defined]
    db_session.flush()
    return deleted > 0
# === Discord Service API Key ===
def get_discord_service_api_key(db_session: Session) -> ApiKey | None:
    """Look up the Discord bot's service API key row, if one exists."""
    stmt = select(ApiKey).where(ApiKey.name == DISCORD_SERVICE_API_KEY_NAME)
    return db_session.scalar(stmt)
def get_or_create_discord_service_api_key(
    db_session: Session,
    tenant_id: str,
) -> str:
    """Return the raw Discord service API key, creating or rotating it.

    The Discord bot uses this key to authenticate with the Onyx API pods
    when sending chat requests. Only the hash is persisted, so an existing
    row cannot yield the raw key back — when a row is found, the key is
    rotated in place and the fresh raw value is returned. This is safe
    because the Discord bot is the only consumer of this key.

    Args:
        db_session: Database session for the tenant.
        tenant_id: The tenant ID (used for logging/context).

    Returns:
        The raw API key string (not hashed).

    Raises:
        RuntimeError: If API key creation fails.
    """
    existing = get_discord_service_api_key(db_session)
    if existing is not None:
        logger.debug(
            f"Found existing Discord service API key for tenant {tenant_id} that isn't in cache, "
            "regenerating to update cache"
        )
        rotated_key = generate_api_key(tenant_id)
        existing.hashed_api_key = hash_api_key(rotated_key)
        existing.api_key_display = build_displayable_api_key(rotated_key)
        db_session.flush()
        return rotated_key

    # No existing key — mint a fresh one under a service account
    logger.info(f"Creating Discord service API key for tenant {tenant_id}")
    descriptor = insert_api_key(
        db_session=db_session,
        api_key_args=APIKeyArgs(
            name=DISCORD_SERVICE_API_KEY_NAME,
            role=UserRole.LIMITED,  # Limited role is sufficient for chat requests
        ),
        user_id=None,  # Service account, no owner
    )
    if not descriptor.api_key:
        raise RuntimeError(
            f"Failed to create Discord service API key for tenant {tenant_id}"
        )
    return descriptor.api_key
def delete_discord_service_api_key(db_session: Session) -> bool:
    """Remove the Discord service API key and its backing service-account user.

    Called when:
    - Bot config is deleted (self-hosted)
    - All guild configs are deleted (Cloud)

    Args:
        db_session: Database session for the tenant.

    Returns:
        True if the key was deleted, False if it didn't exist.
    """
    key_row = get_discord_service_api_key(db_session)
    if key_row is None:
        return False

    # The key is backed by a service-account user; delete both rows.
    owner = db_session.scalar(
        select(User).where(User.id == key_row.user_id)  # type: ignore[arg-type]
    )
    db_session.delete(key_row)
    if owner is not None:
        db_session.delete(owner)
    db_session.flush()
    logger.info("Deleted Discord service API key")
    return True
# === DiscordGuildConfig ===
def get_guild_configs(
    db_session: Session,
    include_channels: bool = False,
) -> list[DiscordGuildConfig]:
    """Fetch every guild config for this tenant.

    Args:
        include_channels: Eagerly load each guild's channel configs
            (joinedload) to avoid per-guild lazy-load queries.
    """
    base_stmt = select(DiscordGuildConfig)
    stmt = (
        base_stmt.options(joinedload(DiscordGuildConfig.channels))
        if include_channels
        else base_stmt
    )
    # .unique() is required when joinedload fans out parent rows
    return list(db_session.scalars(stmt).unique().all())
def get_guild_config_by_internal_id(
    db_session: Session,
    internal_id: int,
) -> DiscordGuildConfig | None:
    """Look up a guild config by its internal (primary-key) id."""
    stmt = select(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
    return db_session.scalar(stmt)
def get_guild_config_by_discord_id(
    db_session: Session,
    guild_id: int,
) -> DiscordGuildConfig | None:
    """Look up a guild config by its Discord guild (snowflake) id."""
    stmt = select(DiscordGuildConfig).where(DiscordGuildConfig.guild_id == guild_id)
    return db_session.scalar(stmt)
def get_guild_config_by_registration_key(
    db_session: Session,
    registration_key: str,
) -> DiscordGuildConfig | None:
    """Look up a guild config by its one-time registration key."""
    stmt = select(DiscordGuildConfig).where(
        DiscordGuildConfig.registration_key == registration_key
    )
    return db_session.scalar(stmt)
def create_guild_config(
    db_session: Session,
    registration_key: str,
) -> DiscordGuildConfig:
    """Create a pending guild config keyed by a registration key.

    guild_id stays NULL until a Discord admin completes registration.
    """
    pending = DiscordGuildConfig(registration_key=registration_key)
    db_session.add(pending)
    db_session.flush()
    return pending
def register_guild(
    db_session: Session,
    config: DiscordGuildConfig,
    guild_id: int,
    guild_name: str,
) -> DiscordGuildConfig:
    """Complete registration by setting guild_id and guild_name.

    Binds the Discord guild (snowflake id + display name) to the pending
    config row and stamps the registration time in UTC. The caller is
    responsible for having resolved `config` (e.g. via registration key).
    """
    config.guild_id = guild_id
    config.guild_name = guild_name
    config.registered_at = datetime.now(timezone.utc)
    db_session.flush()
    return config
def update_guild_config(
    db_session: Session,
    config: DiscordGuildConfig,
    enabled: bool,
    default_persona_id: int | None = None,
) -> DiscordGuildConfig:
    """Update guild config fields.

    NOTE: default_persona_id is always written — passing None (the default)
    clears any existing default persona, so callers must supply the full
    desired state on every update.
    """
    config.enabled = enabled
    config.default_persona_id = default_persona_id
    db_session.flush()
    return config
def delete_guild_config(
    db_session: Session,
    internal_id: int,
) -> bool:
    """Delete a guild config by internal id; channel configs cascade.

    Returns:
        True if a row was removed.
    """
    stmt = delete(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
    deleted = db_session.execute(stmt).rowcount  # type: ignore[attr-defined]
    db_session.flush()
    return deleted > 0
# === DiscordChannelConfig ===
def get_channel_configs(
    db_session: Session,
    guild_config_id: int,
) -> list[DiscordChannelConfig]:
    """Fetch every channel config belonging to the given guild config."""
    stmt = select(DiscordChannelConfig).where(
        DiscordChannelConfig.guild_config_id == guild_config_id
    )
    return list(db_session.scalars(stmt))
def get_channel_config_by_discord_ids(
    db_session: Session,
    guild_id: int,
    channel_id: int,
) -> DiscordChannelConfig | None:
    """Resolve a channel config from Discord snowflake ids (guild + channel)."""
    stmt = (
        select(DiscordChannelConfig)
        .join(DiscordGuildConfig)
        .where(DiscordGuildConfig.guild_id == guild_id)
        .where(DiscordChannelConfig.channel_id == channel_id)
    )
    return db_session.scalar(stmt)
def get_channel_config_by_internal_ids(
    db_session: Session,
    guild_config_id: int,
    channel_config_id: int,
) -> DiscordChannelConfig | None:
    """Resolve a channel config from internal ids (guild config + channel config)."""
    stmt = (
        select(DiscordChannelConfig)
        .where(DiscordChannelConfig.guild_config_id == guild_config_id)
        .where(DiscordChannelConfig.id == channel_config_id)
    )
    return db_session.scalar(stmt)
def update_discord_channel_config(
    db_session: Session,
    config: DiscordChannelConfig,
    channel_name: str,
    thread_only_mode: bool,
    require_bot_invocation: bool,
    enabled: bool,
    persona_override_id: int | None = None,
) -> DiscordChannelConfig:
    """Update channel config fields.

    NOTE: persona_override_id is always written — passing None (the default)
    clears any existing override, so callers must supply the full desired
    state on every update.
    """
    config.channel_name = channel_name
    config.require_bot_invocation = require_bot_invocation
    config.persona_override_id = persona_override_id
    config.enabled = enabled
    config.thread_only_mode = thread_only_mode
    db_session.flush()
    return config
def delete_discord_channel_config(
    db_session: Session,
    guild_config_id: int,
    channel_config_id: int,
) -> bool:
    """Delete one channel config, scoped to its guild config.

    Returns:
        True if a row was removed.
    """
    stmt = (
        delete(DiscordChannelConfig)
        .where(DiscordChannelConfig.guild_config_id == guild_config_id)
        .where(DiscordChannelConfig.id == channel_config_id)
    )
    deleted = db_session.execute(stmt).rowcount  # type: ignore[attr-defined]
    db_session.flush()
    return deleted > 0
def create_channel_config(
    db_session: Session,
    guild_config_id: int,
    channel_view: DiscordChannelView,
) -> DiscordChannelConfig:
    """Persist a config row for one Discord channel.

    Channels start with default settings (disabled by default — the column's
    server default); an admin enables them via the UI.
    """
    new_config = DiscordChannelConfig(
        guild_config_id=guild_config_id,
        channel_id=channel_view.channel_id,
        channel_name=channel_view.channel_name,
        channel_type=channel_view.channel_type,
        is_private=channel_view.is_private,
    )
    db_session.add(new_config)
    db_session.flush()
    return new_config
def bulk_create_channel_configs(
    db_session: Session,
    guild_config_id: int,
    channels: list[DiscordChannelView],
) -> list[DiscordChannelConfig]:
    """Create configs for each channel not already present in this guild.

    Returns:
        The newly created (flushed) channel configs; channels that already
        have a config are skipped.
    """
    # Channel IDs that already have a config row for this guild
    already_present = set(
        db_session.scalars(
            select(DiscordChannelConfig.channel_id).where(
                DiscordChannelConfig.guild_config_id == guild_config_id
            )
        ).all()
    )

    created = [
        DiscordChannelConfig(
            guild_config_id=guild_config_id,
            channel_id=view.channel_id,
            channel_name=view.channel_name,
            channel_type=view.channel_type,
            is_private=view.is_private,
        )
        for view in channels
        if view.channel_id not in already_present
    ]
    db_session.add_all(created)
    db_session.flush()
    return created
def sync_channel_configs(
    db_session: Session,
    guild_config_id: int,
    current_channels: list[DiscordChannelView],
) -> tuple[int, int, int]:
    """Reconcile stored channel configs against the live Discord channel list.

    - Creates configs for new channels (disabled by default)
    - Removes configs for deleted channels
    - Writes through name/type/privacy changes for surviving channels

    Returns: (added_count, removed_count, updated_count)
    """
    live_by_id = {view.channel_id: view for view in current_channels}

    existing_configs = get_channel_configs(db_session, guild_config_id)
    known_ids = {cfg.channel_id for cfg in existing_configs}

    # Create configs for channels we haven't seen before
    new_ids = set(live_by_id) - known_ids
    for channel_id in new_ids:
        create_channel_config(db_session, guild_config_id, live_by_id[channel_id])
    added_count = len(new_ids)

    # Drop configs whose channels disappeared; update the rest in place
    removed_count = 0
    updated_count = 0
    for cfg in existing_configs:
        live_view = live_by_id.get(cfg.channel_id)
        if live_view is None:
            db_session.delete(cfg)
            removed_count += 1
            continue

        changed = False
        if cfg.channel_name != live_view.channel_name:
            cfg.channel_name = live_view.channel_name
            changed = True
        if cfg.channel_type != live_view.channel_type:
            cfg.channel_type = live_view.channel_type
            changed = True
        if cfg.is_private != live_view.is_private:
            cfg.is_private = live_view.is_private
            changed = True
        if changed:
            updated_count += 1

    db_session.flush()
    return added_count, removed_count, updated_count

View File

@@ -26,6 +26,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import BigInteger
from sqlalchemy import Sequence
from sqlalchemy import String
@@ -2931,8 +2932,6 @@ class PersonaLabel(Base):
"Persona",
secondary=Persona__PersonaLabel.__table__,
back_populates="labels",
cascade="all, delete-orphan",
single_parent=True,
)
@@ -3038,6 +3037,124 @@ class SlackBot(Base):
)
class DiscordBotConfig(Base):
    """Global Discord bot configuration (one per tenant).

    Stores the bot token when not provided via DISCORD_BOT_TOKEN env var.
    Uses a fixed ID with check constraint to enforce only one row per tenant.
    """

    __tablename__ = "discord_bot_config"

    # Server default pins every row to the same 'SINGLETON' id, so a second
    # insert violates the primary key. NOTE(review): the docstring mentions a
    # check constraint; only the server default is visible here — presumably
    # the constraint lives in the migration. Confirm.
    id: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'SINGLETON'")
    )
    # Encrypted at rest via EncryptedString
    bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
class DiscordGuildConfig(Base):
    """Configuration for a Discord guild (server) connected to this tenant.

    registration_key is a one-time key used to link a Discord server to this tenant.
    Format: discord_<tenant_id>.<random_token>
    guild_id is NULL until the Discord admin runs !register with the key.
    """

    __tablename__ = "discord_guild_config"

    id: Mapped[int] = mapped_column(primary_key=True)

    # Discord snowflake - NULL until registered via command in Discord.
    # BigInteger: snowflakes exceed 32-bit integer range.
    guild_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True, unique=True)
    guild_name: Mapped[str | None] = mapped_column(String(256), nullable=True)

    # One-time registration key: discord_<tenant_id>.<random_token>
    registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    registered_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # Configuration
    # Deleting the persona clears the reference rather than blocking deletion
    default_persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
    )
    enabled: Mapped[bool] = mapped_column(
        Boolean, server_default=text("true"), nullable=False
    )

    # Relationships
    default_persona: Mapped["Persona | None"] = relationship(
        "Persona", foreign_keys=[default_persona_id]
    )
    # Channel configs are owned by the guild config: deleting the guild
    # config (or detaching a channel) deletes the channel rows
    channels: Mapped[list["DiscordChannelConfig"]] = relationship(
        back_populates="guild_config", cascade="all, delete-orphan"
    )
class DiscordChannelConfig(Base):
    """Per-channel configuration for Discord bot behavior.

    Used to whitelist specific channels and configure per-channel behavior.
    """

    __tablename__ = "discord_channel_config"

    id: Mapped[int] = mapped_column(primary_key=True)
    guild_config_id: Mapped[int] = mapped_column(
        ForeignKey("discord_guild_config.id", ondelete="CASCADE"), nullable=False
    )

    # Discord snowflake (BigInteger: snowflakes exceed 32-bit range)
    channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
    channel_name: Mapped[str] = mapped_column(String(), nullable=False)

    # Channel type from Discord (text, forum)
    channel_type: Mapped[str] = mapped_column(
        String(20), server_default=text("'text'"), nullable=False
    )

    # True if @everyone cannot view the channel
    is_private: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # If true, bot only responds to messages in threads
    # Otherwise, will reply in channel
    thread_only_mode: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # If true (default), bot only responds when @mentioned
    # If false, bot responds to ALL messages in this channel
    require_bot_invocation: Mapped[bool] = mapped_column(
        Boolean, server_default=text("true"), nullable=False
    )

    # Override the guild's default persona for this channel
    # (SET NULL: deleting the persona clears the override)
    persona_override_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
    )

    # Channels start disabled; an admin must enable them explicitly
    enabled: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # Relationships
    guild_config: Mapped["DiscordGuildConfig"] = relationship(back_populates="channels")
    persona_override: Mapped["Persona | None"] = relationship()

    # Constraints
    # Each Discord channel may be configured at most once per guild config
    __table_args__ = (
        UniqueConstraint(
            "guild_config_id", "channel_id", name="uq_discord_channel_guild_channel"
        ),
    )
class Milestone(Base):
# This table is used to track significant events for a deployment towards finding value
# The table is currently not used for features but it may be used in the future to inform
@@ -3115,25 +3232,6 @@ class FileRecord(Base):
)
class AgentSearchMetrics(Base):
__tablename__ = "agent__search_metrics"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
agent_type: Mapped[str] = mapped_column(String)
start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
base_duration_s: Mapped[float] = mapped_column(Float)
full_duration_s: Mapped[float] = mapped_column(Float)
base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
"""
************************************************************************
Enterprise Edition Models

View File

@@ -917,7 +917,9 @@ def upsert_persona(
existing_persona.icon_name = icon_name
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None

View File

@@ -40,3 +40,10 @@ class DocumentRow(BaseModel):
class SortOrder(str, Enum):
    """Sort direction for list endpoints (string-valued for query params)."""

    ASC = "asc"
    DESC = "desc"
class DiscordChannelView(BaseModel):
    """Lightweight view of a Discord channel, used when syncing channel configs."""

    # Discord snowflake id of the channel
    channel_id: int
    channel_name: str
    channel_type: str = "text"  # text, forum
    is_private: bool = False  # True if @everyone cannot view the channel

View File

@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.function_call_arguments.delta":
content_part: Optional[str] = parsed_chunk.get("delta", None)
if content_part:
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.completed":
# Final event signaling all output items (including parallel tool calls) are done
# Check if we already received tool calls via streaming events
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
# response.completed event so we need to throw it out here or there are duplicate tool calls.
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
response_data = parsed_chunk.get("response", {})
# Determine finish reason based on response content
finish_reason = "stop"
if response_data.get("output"):
for item in response_data["output"]:
if isinstance(item, dict) and item.get("type") == "function_call":
finish_reason = "tool_calls"
break
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason=finish_reason,
usage=None,
output_items = response_data.get("output", [])
# Check if there are function_call items in the output
has_function_calls = any(
isinstance(item, dict) and item.get("type") == "function_call"
for item in output_items
)
if has_function_calls and not has_streamed_tool_calls:
# Azure's Responses API returns all tool calls in response.completed
# without streaming them incrementally. Extract them here.
from litellm.types.utils import (
Delta,
ModelResponseStream,
StreamingChoices,
)
tool_calls = []
for idx, item in enumerate(output_items):
if isinstance(item, dict) and item.get("type") == "function_call":
tool_calls.append(
ChatCompletionToolCallChunk(
id=item.get("call_id"),
index=idx,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=item.get("name"),
arguments=item.get("arguments", ""),
),
)
)
return ModelResponseStream(
choices=[
StreamingChoices(
index=0,
delta=Delta(tool_calls=tool_calls),
finish_reason="tool_calls",
)
]
)
elif has_function_calls:
# Tool calls were already streamed, just signal completion
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="tool_calls",
usage=None,
)
else:
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=None,
)
else:
pass
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
def _patch_azure_responses_should_fake_stream() -> None:
    """
    Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
    By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
    not in its database. This causes Azure custom model deployments to buffer the entire
    response before yielding, resulting in poor time-to-first-token.
    Azure's Responses API supports native streaming, so we override this to always use
    real streaming (SyncResponsesAPIStreamingIterator).
    """
    # Imported lazily so litellm is only required when the patch is applied.
    from litellm.llms.azure.responses.transformation import (
        AzureOpenAIResponsesAPIConfig,
    )
    # Idempotence guard: the sentinel __name__ is assigned below, so a second
    # call sees it and returns without re-patching.
    if (
        getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
        == "_patched_should_fake_stream"
    ):
        return
    # Signature mirrors the method being replaced so callers are unaffected.
    def _patched_should_fake_stream(
        self: Any,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Azure Responses API supports native streaming - never fake it
        return False
    # Set the sentinel consumed by the guard above.
    _patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
    AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
_patch_openai_responses_transform_response()
_patch_azure_responses_should_fake_stream()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -301,6 +301,12 @@ class LitellmLLM(LLM):
)
is_ollama = self._model_provider == LlmProviderNames.OLLAMA_CHAT
is_mistral = self._model_provider == LlmProviderNames.MISTRAL
is_vertex_ai = self._model_provider == LlmProviderNames.VERTEX_AI
# Vertex Anthropic Opus 4.5 rejects output_config (LiteLLM maps reasoning_effort).
# Keep this guard until LiteLLM/Vertex accept the field for this model.
is_vertex_opus_4_5 = (
is_vertex_ai and "claude-opus-4-5" in self.config.model_name.lower()
)
#########################
# Build arguments
@@ -331,12 +337,16 @@ class LitellmLLM(LLM):
# Temperature
temperature = 1 if is_reasoning else self._temperature
if stream:
if stream and not is_vertex_opus_4_5:
optional_kwargs["stream_options"] = {"include_usage": True}
# Use configured default if not provided (if not set in env, low)
reasoning_effort = reasoning_effort or ReasoningEffort(DEFAULT_REASONING_EFFORT)
if is_reasoning and reasoning_effort != ReasoningEffort.OFF:
if (
is_reasoning
and reasoning_effort != ReasoningEffort.OFF
and not is_vertex_opus_4_5
):
if is_openai_model:
# OpenAI API does not accept reasoning params for GPT 5 chat models
# (neither reasoning nor reasoning_effort are accepted)

View File

@@ -96,6 +96,7 @@ from onyx.server.long_term_logs.long_term_logs_api import (
router as long_term_logs_router,
)
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
from onyx.server.manage.get_state import router as state_router
@@ -380,6 +381,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(
application, slack_bot_management_router
)
include_router_with_global_prefix_prepended(application, discord_bot_router)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, agents_router)

View File

@@ -0,0 +1,287 @@
# Discord Bot Multitenant Architecture
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
## Overview
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
```
┌─────────────────────────────────────────────────────────────────────┐
│ OnyxDiscordClient │
│ │
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
│ │ │ │ │ │
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
│ │ tenant_id → api_key │ │ message, │ │
│ │ │ │ api_key=<per-tenant>, │ │
│ └─────────────────────────┘ │ persona_id=... │ │
│ │ ) │ │
│ └─────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```
---
## Component Details
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
The `DiscordCacheManager` maintains two critical in-memory mappings:
```python
class DiscordCacheManager:
_guild_tenants: dict[int, str] # guild_id → tenant_id
_api_keys: dict[str, str] # tenant_id → api_key
_lock: asyncio.Lock # Concurrency control
```
#### Key Responsibilities
| Function | Purpose |
|----------|---------|
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
| `refresh_all()` | Full cache rebuild from database |
| `refresh_guild()` | Incremental update for single guild |
#### API Key Provisioning Strategy
API keys are **lazily provisioned** - only created when first needed:
```python
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
needs_key = tenant_id not in self._api_keys
with get_session_with_tenant(tenant_id) as db:
# Load guild configs
        configs = get_guild_configs(db)
guild_ids = [c.guild_id for c in configs if c.enabled]
# Only provision API key if not already cached
api_key = None
if needs_key:
api_key = get_or_create_discord_service_api_key(db, tenant_id)
return guild_ids, api_key
```
This optimization avoids repeated database calls for API key generation.
#### Concurrency Control
All write operations acquire an async lock to prevent race conditions:
```python
async def refresh_all(self) -> None:
async with self._lock:
# Safe to modify _guild_tenants and _api_keys
for tenant_id in get_all_tenant_ids():
guild_ids, api_key = await self._load_tenant_data(tenant_id)
# Update mappings...
```
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
---
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
#### Key Design: Per-Request API Key Injection
```python
class OnyxAPIClient:
async def send_chat_message(
self,
message: str,
api_key: str, # Injected per-request
persona_id: int | None,
...
) -> ChatFullResponse:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
}
# Make request...
```
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
```python
# Same client, different tenants
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
```
---
## Coordination Flow
### Message Processing Pipeline
When a Discord message arrives, the client coordinates cache and API client:
```python
async def on_message(self, message: Message) -> None:
guild_id = message.guild.id
# Step 1: Cache lookup - guild → tenant
tenant_id = self.cache.get_tenant(guild_id)
if not tenant_id:
return # Guild not registered
# Step 2: Cache lookup - tenant → API key
api_key = self.cache.get_api_key(tenant_id)
if not api_key:
logger.warning(f"No API key for tenant {tenant_id}")
return
# Step 3: API call with tenant-specific credentials
await process_chat_message(
message=message,
api_key=api_key, # Tenant-specific
persona_id=persona_id, # Tenant-specific
api_client=self.api_client,
)
```
### Startup Sequence
```python
async def setup_hook(self) -> None:
# 1. Initialize API client (create aiohttp session)
await self.api_client.initialize()
# 2. Populate cache with all tenants
await self.cache.refresh_all()
# 3. Start background refresh task
self._cache_refresh_task = self.loop.create_task(
self._periodic_cache_refresh() # Every 60 seconds
)
```
### Shutdown Sequence
```python
async def close(self) -> None:
# 1. Cancel background refresh
if self._cache_refresh_task:
self._cache_refresh_task.cancel()
# 2. Close Discord connection
await super().close()
# 3. Close API client session
await self.api_client.close()
# 4. Clear cache
self.cache.clear()
```
---
## Tenant Isolation Mechanisms
### 1. Per-Tenant API Keys
Each tenant has a dedicated service API key:
```python
# backend/onyx/db/discord_bot.py
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
existing = get_discord_service_api_key(db_session)
if existing:
return regenerate_key(existing)
# Create LIMITED role key (chat-only permissions)
return insert_api_key(
db_session=db_session,
api_key_args=APIKeyArgs(
name=DISCORD_SERVICE_API_KEY_NAME,
role=UserRole.LIMITED, # Minimal permissions
),
user_id=None, # Service account (system-owned)
).api_key
```
### 2. Database Context Variables
The cache uses context variables for proper tenant-scoped DB sessions:
```python
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
with get_session_with_tenant(tenant_id) as db:
# All DB operations scoped to this tenant
...
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
```
### 3. Enterprise Gating Support
Gated tenants are filtered during cache refresh:
```python
gated_tenants = fetch_ee_implementation_or_noop(
"onyx.server.tenants.product_gating",
"get_gated_tenants",
set(),
)()
for tenant_id in get_all_tenant_ids():
if tenant_id in gated_tenants:
continue # Skip gated tenants
```
---
## Cache Refresh Strategy
| Trigger | Method | Scope |
|---------|--------|-------|
| Startup | `refresh_all()` | All tenants |
| Periodic (60s) | `refresh_all()` | All tenants |
| Guild registration | `refresh_guild()` | Single tenant |
### Error Handling
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
- **Missing API key**: Bot silently ignores messages from that guild
- **Network errors**: Logged, cache continues with stale data until next refresh
---
## Key Design Insights
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
---
## File References
| Component | Path |
|-----------|------|
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |

View File

@@ -0,0 +1,215 @@
"""Async HTTP client for communicating with Onyx API pods."""
import aiohttp
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
logger = setup_logger()
class OnyxAPIClient:
    """Async HTTP client for sending chat requests to Onyx API pods.
    This client manages an aiohttp session for making non-blocking HTTP
    requests to the Onyx API server. It handles authentication with per-tenant
    API keys and multi-tenant routing.
    Usage:
        client = OnyxAPIClient()
        await client.initialize()
        try:
            response = await client.send_chat_message(
                message="What is our deployment process?",
                api_key="dn_xxx...",
                persona_id=1,
            )
            print(response.answer)
        finally:
            await client.close()
    """
    def __init__(
        self,
        timeout: int = API_REQUEST_TIMEOUT,
    ) -> None:
        """Initialize the API client.
        Args:
            timeout: Request timeout in seconds.
        """
        # Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL
        # TODO: Ideally, this override is only used when someone is launching an Onyx service independently
        self._base_url = build_api_server_url_for_http_requests(
            respect_env_override_if_set=True
        ).rstrip("/")
        self._timeout = timeout
        # Created lazily in initialize(); None means "not yet usable".
        self._session: aiohttp.ClientSession | None = None
    async def initialize(self) -> None:
        """Create the aiohttp session.
        Must be called before making any requests. The session is created
        with a total timeout and connection timeout.
        """
        # Idempotent: a repeat call warns instead of clobbering a live session.
        if self._session is not None:
            logger.warning("API client session already initialized")
            return
        timeout = aiohttp.ClientTimeout(
            total=self._timeout,
            connect=30,  # 30 seconds to establish connection
        )
        self._session = aiohttp.ClientSession(timeout=timeout)
        logger.info(f"API client initialized with base URL: {self._base_url}")
    async def close(self) -> None:
        """Close the aiohttp session.
        Should be called when shutting down the bot to properly release
        resources.
        """
        if self._session is not None:
            await self._session.close()
            # Reset so is_initialized reflects the closed state and a later
            # initialize() can build a fresh session.
            self._session = None
            logger.info("API client session closed")
    @property
    def is_initialized(self) -> bool:
        """Check if the session is initialized."""
        return self._session is not None
    async def send_chat_message(
        self,
        message: str,
        api_key: str,
        persona_id: int | None = None,
    ) -> ChatFullResponse:
        """Send a chat message to the Onyx API server and get a response.
        This method sends a non-streaming chat request to the API server. The response
        contains the complete answer with any citations and metadata.
        Args:
            message: The user's message to process.
            api_key: The API key for authentication.
            persona_id: Optional persona ID to use for the response.
        Returns:
            ChatFullResponse containing the answer, citations, and metadata.
        Raises:
            APIConnectionError: If unable to connect to the API.
            APITimeoutError: If the request times out.
            APIResponseError: If the API returns an error response.
        """
        if self._session is None:
            raise APIConnectionError(
                "API client not initialized. Call initialize() first."
            )
        url = f"{self._base_url}/chat/send-chat-message"
        # Build request payload
        request = SendMessageRequest(
            message=message,
            stream=False,
            origin=MessageOrigin.DISCORDBOT,
            chat_session_info=ChatSessionCreationRequest(
                # Falls back to persona 0 when no persona is configured.
                persona_id=persona_id if persona_id is not None else 0,
            ),
        )
        # Build headers
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        try:
            async with self._session.post(
                url,
                json=request.model_dump(mode="json"),
                headers=headers,
            ) as response:
                # Map well-known statuses to specific messages before the
                # generic >=500 / >=400 buckets.
                if response.status == 401:
                    raise APIResponseError(
                        "Authentication failed - invalid API key",
                        status_code=401,
                    )
                elif response.status == 403:
                    raise APIResponseError(
                        "Access denied - insufficient permissions",
                        status_code=403,
                    )
                elif response.status == 404:
                    raise APIResponseError(
                        "API endpoint not found",
                        status_code=404,
                    )
                elif response.status >= 500:
                    error_text = await response.text()
                    raise APIResponseError(
                        f"Server error: {error_text}",
                        status_code=response.status,
                    )
                elif response.status >= 400:
                    error_text = await response.text()
                    raise APIResponseError(
                        f"Request error: {error_text}",
                        status_code=response.status,
                    )
                # Parse successful response
                data = await response.json()
                response_obj = ChatFullResponse.model_validate(data)
                # A 2xx response can still carry an in-band error; log it but
                # hand the full response back to the caller.
                if response_obj.error_msg:
                    logger.warning(f"Chat API returned error: {response_obj.error_msg}")
                return response_obj
        except aiohttp.ClientConnectorError as e:
            logger.error(f"Failed to connect to API: {e}")
            raise APIConnectionError(
                f"Failed to connect to API at {self._base_url}: {e}"
            ) from e
        # NOTE(review): on Python < 3.11 aiohttp timeouts raise asyncio.TimeoutError,
        # which is not builtins.TimeoutError — confirm the minimum runtime version.
        except TimeoutError as e:
            logger.error(f"API request timed out after {self._timeout}s")
            raise APITimeoutError(
                f"Request timed out after {self._timeout} seconds"
            ) from e
        except aiohttp.ClientError as e:
            logger.error(f"HTTP client error: {e}")
            raise APIConnectionError(f"HTTP client error: {e}") from e
    async def health_check(self) -> bool:
        """Check if the API server is healthy.
        Returns:
            True if the API server is reachable and healthy, False otherwise.
        """
        if self._session is None:
            logger.warning("API client not initialized. Call initialize() first.")
            return False
        try:
            url = f"{self._base_url}/health"
            # Health checks use a short, independent 10s timeout rather than
            # the session-wide request timeout.
            async with self._session.get(
                url, timeout=aiohttp.ClientTimeout(total=10)
            ) as response:
                return response.status == 200
        except Exception as e:
            logger.warning(f"API server health check failed: {e}")
            return False

View File

@@ -0,0 +1,154 @@
"""Multi-tenant cache for Discord bot guild-tenant mappings and API keys."""
import asyncio
from onyx.db.discord_bot import get_guild_configs
from onyx.db.discord_bot import get_or_create_discord_service_api_key
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.onyxbot.discord.exceptions import CacheError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class DiscordCacheManager:
    """Caches guild->tenant mappings and tenant->API key mappings.
    Refreshed on startup, periodically (every 60s), and when guilds register.
    """
    def __init__(self) -> None:
        self._guild_tenants: dict[int, str] = {}  # guild_id -> tenant_id
        self._api_keys: dict[str, str] = {}  # tenant_id -> api_key
        # Serializes the write paths (refresh_all / refresh_guild); readers
        # do plain dict lookups and take no lock.
        self._lock = asyncio.Lock()
        self._initialized = False
    @property
    def is_initialized(self) -> bool:
        # True once refresh_all() has completed at least one full pass.
        return self._initialized
    async def refresh_all(self) -> None:
        """Full cache refresh from all tenants."""
        async with self._lock:
            logger.info("Starting Discord cache refresh")
            # Build replacement maps and swap them in at the end so readers
            # never observe a half-built cache.
            new_guild_tenants: dict[int, str] = {}
            new_api_keys: dict[str, str] = {}
            try:
                # Skip tenants gated by EE product gating (no-op empty set on OSS).
                gated = fetch_ee_implementation_or_noop(
                    "onyx.server.tenants.product_gating",
                    "get_gated_tenants",
                    set(),
                )()
                tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
                for tenant_id in tenant_ids:
                    if tenant_id in gated:
                        continue
                    context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
                    try:
                        guild_ids, api_key = await self._load_tenant_data(tenant_id)
                        if not guild_ids:
                            logger.debug(f"No guilds found for tenant {tenant_id}")
                            continue
                        if not api_key:
                            logger.warning(
                                "Discord service API key missing for tenant that has registered guilds. "
                                f"{tenant_id} will not be handled in this refresh cycle."
                            )
                            continue
                        for guild_id in guild_ids:
                            new_guild_tenants[guild_id] = tenant_id
                        new_api_keys[tenant_id] = api_key
                    except Exception as e:
                        # Per-tenant failures are logged and skipped so one bad
                        # tenant does not abort the whole refresh.
                        logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
                    finally:
                        CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
                self._guild_tenants = new_guild_tenants
                self._api_keys = new_api_keys
                self._initialized = True
                logger.info(
                    f"Cache refresh complete: {len(new_guild_tenants)} guilds, "
                    f"{len(new_api_keys)} tenants"
                )
            except Exception as e:
                logger.error(f"Cache refresh failed: {e}")
                raise CacheError(f"Failed to refresh cache: {e}") from e
    async def refresh_guild(self, guild_id: int, tenant_id: str) -> None:
        """Add a single guild to cache after registration."""
        # NOTE(review): unlike refresh_all(), this path does not set
        # CURRENT_TENANT_ID_CONTEXTVAR; _load_tenant_data passes tenant_id to
        # the session explicitly, but confirm no callee relies on the contextvar.
        async with self._lock:
            logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})")
            guild_ids, api_key = await self._load_tenant_data(tenant_id)
            if guild_id in guild_ids:
                self._guild_tenants[guild_id] = tenant_id
                if api_key:
                    self._api_keys[tenant_id] = api_key
                logger.info(f"Cache updated for guild {guild_id}")
            else:
                logger.warning(f"Guild {guild_id} not found or disabled")
    async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
        """Load guild IDs and provision API key if needed.
        Returns:
            (active_guild_ids, api_key) - api_key is the cached key if available,
            otherwise a newly created key. Returns ([], None) if no guilds found.
        """
        cached_key = self._api_keys.get(tenant_id)
        def _sync() -> tuple[list[int], str | None]:
            # Runs in a worker thread (asyncio.to_thread) so the blocking DB
            # session does not stall the event loop.
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                configs = get_guild_configs(db)
                guild_ids = [
                    config.guild_id
                    for config in configs
                    if config.enabled and config.guild_id is not None
                ]
                if not guild_ids:
                    return [], None
                # Only provision (and commit) a key when none is cached yet.
                if not cached_key:
                    new_key = get_or_create_discord_service_api_key(db, tenant_id)
                    db.commit()
                    return guild_ids, new_key
                return guild_ids, cached_key
        return await asyncio.to_thread(_sync)
    def get_tenant(self, guild_id: int) -> str | None:
        """Get tenant ID for a guild."""
        return self._guild_tenants.get(guild_id)
    def get_api_key(self, tenant_id: str) -> str | None:
        """Get API key for a tenant."""
        return self._api_keys.get(tenant_id)
    def remove_guild(self, guild_id: int) -> None:
        """Remove a guild from cache."""
        self._guild_tenants.pop(guild_id, None)
    def get_all_guild_ids(self) -> list[int]:
        """Get all cached guild IDs."""
        return list(self._guild_tenants.keys())
    def clear(self) -> None:
        """Clear all caches."""
        self._guild_tenants.clear()
        self._api_keys.clear()
        self._initialized = False

View File

@@ -0,0 +1,232 @@
"""Discord bot client with integrated message handling."""
import asyncio
import time
import discord
from discord.ext import commands
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL
from onyx.onyxbot.discord.handle_commands import handle_dm
from onyx.onyxbot.discord.handle_commands import handle_registration_command
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
from onyx.onyxbot.discord.handle_message import process_chat_message
from onyx.onyxbot.discord.handle_message import should_respond
from onyx.onyxbot.discord.utils import get_bot_token
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxDiscordClient(commands.Bot):
    """Discord bot client with integrated cache, API client, and message handling.
    This client handles:
    - Guild registration via !register command
    - Message processing with persona-based responses
    - Thread context for conversation continuity
    - Multi-tenant support via cached API keys
    """
    def __init__(self, command_prefix: str = DISCORD_BOT_INVOKE_CHAR) -> None:
        intents = discord.Intents.default()
        # Privileged intents: needed to read message text and member info.
        intents.message_content = True
        intents.members = True
        super().__init__(command_prefix=command_prefix, intents=intents)
        self.ready = False
        self.cache = DiscordCacheManager()
        self.api_client = OnyxAPIClient()
        self._cache_refresh_task: asyncio.Task | None = None
    # -------------------------------------------------------------------------
    # Lifecycle Methods
    # -------------------------------------------------------------------------
    async def setup_hook(self) -> None:
        """Called before on_ready. Initialize components."""
        logger.info("Initializing Discord bot components...")
        # Initialize API client
        await self.api_client.initialize()
        # Initial cache load
        await self.cache.refresh_all()
        # Start periodic cache refresh
        self._cache_refresh_task = self.loop.create_task(self._periodic_cache_refresh())
        logger.info("Discord bot components initialized")
    async def _periodic_cache_refresh(self) -> None:
        """Background task to refresh cache periodically."""
        while not self.is_closed():
            await asyncio.sleep(CACHE_REFRESH_INTERVAL)
            try:
                await self.cache.refresh_all()
            except Exception as e:
                # Keep the loop alive on failure; stale cache data is served
                # until the next successful refresh.
                logger.error(f"Cache refresh failed: {e}")
    async def on_ready(self) -> None:
        """Bot connected and ready."""
        # on_ready can fire again on reconnect; only log the first time.
        if self.ready:
            return
        if not self.user:
            raise RuntimeError("Critical error: Discord Bot user not found")
        logger.info(f"Discord Bot connected as {self.user} (ID: {self.user.id})")
        logger.info(f"Connected to {len(self.guilds)} guild(s)")
        logger.info(f"Cached {len(self.cache.get_all_guild_ids())} registered guild(s)")
        self.ready = True
    async def close(self) -> None:
        """Graceful shutdown."""
        logger.info("Shutting down Discord bot...")
        # Cancel cache refresh task
        if self._cache_refresh_task:
            self._cache_refresh_task.cancel()
            try:
                await self._cache_refresh_task
            except asyncio.CancelledError:
                pass
        # Close Discord connection first - stops new commands from triggering cache ops
        if not self.is_closed():
            await super().close()
        # Close API client
        await self.api_client.close()
        # Clear cache (safe now - no concurrent operations possible)
        self.cache.clear()
        self.ready = False
        logger.info("Discord bot shutdown complete")
    # -------------------------------------------------------------------------
    # Message Handling
    # -------------------------------------------------------------------------
    async def on_message(self, message: discord.Message) -> None:
        """Main message handler."""
        # mypy
        if not self.user:
            raise RuntimeError("Critical error: Discord Bot user not found")
        try:
            # Ignore bot messages
            if message.author.bot:
                return
            # Ignore thread starter messages (empty reference nodes that don't contain content)
            if message.type == discord.MessageType.thread_starter_message:
                return
            # Handle DMs
            if isinstance(message.channel, discord.DMChannel):
                await handle_dm(message)
                return
            # Must have a guild
            if not message.guild or not message.guild.id:
                return
            guild_id = message.guild.id
            # Check for registration command first
            if await handle_registration_command(message, self.cache):
                return
            # Look up guild in cache
            tenant_id = self.cache.get_tenant(guild_id)
            # Check for sync-channels command (requires registered guild)
            if await handle_sync_channels_command(message, tenant_id, self):
                return
            if not tenant_id:
                # Guild not registered, ignore
                return
            # Get API key
            api_key = self.cache.get_api_key(tenant_id)
            if not api_key:
                logger.warning(f"No API key cached for tenant {tenant_id}")
                return
            # Check if bot should respond
            should_respond_context = await should_respond(message, tenant_id, self.user)
            if not should_respond_context.should_respond:
                return
            logger.debug(
                f"Processing message: '{message.content[:50]}' in "
                f"#{getattr(message.channel, 'name', 'unknown')} ({message.guild.name}), "
                f"persona_id={should_respond_context.persona_id}"
            )
            # Process the message
            await process_chat_message(
                message=message,
                api_key=api_key,
                persona_id=should_respond_context.persona_id,
                thread_only_mode=should_respond_context.thread_only_mode,
                api_client=self.api_client,
                bot_user=self.user,
            )
        except Exception as e:
            # Never let a handler exception kill the event dispatcher.
            logger.exception(f"Error processing message: {e}")
# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
def main() -> None:
    """Main entry point for Discord bot."""
    from onyx.db.engine.sql_engine import SqlEngine
    from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
    logger.info("Starting Onyx Discord Bot...")
    # Initialize the database engine (required before any DB operations)
    SqlEngine.init_engine(pool_size=20, max_overflow=5)
    # Initialize EE features based on environment
    set_is_ee_based_on_env_variable()
    counter = 0
    while True:
        token = get_bot_token()
        if not token:
            # Dormant mode: poll for a token every 5s, logging roughly once
            # every 15 minutes (180 iterations * 5s) to avoid log spam.
            if counter % 180 == 0:
                logger.info(
                    "Discord bot is dormant. Waiting for token configuration..."
                )
            counter += 1
            time.sleep(5)
            continue
        counter = 0
        bot = OnyxDiscordClient()
        try:
            # bot.run() handles SIGINT/SIGTERM and calls close() automatically
            bot.run(token)
        except Exception:
            logger.exception("Fatal error in Discord bot")
            raise
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,19 @@
"""Discord bot constants."""
# API settings
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
# Cache settings
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
# Message settings
MAX_MESSAGE_LENGTH: int = 2000 # Discord's character limit
MAX_CONTEXT_MESSAGES: int = 10 # Max messages to include in conversation context
# Note: Discord.py's add_reaction() requires unicode emoji, not :name: format
THINKING_EMOJI: str = "🤔" # U+1F914 - Thinking Face
SUCCESS_EMOJI: str = "" # U+2705 - White Heavy Check Mark
ERROR_EMOJI: str = "" # U+274C - Cross Mark
# Command prefix
REGISTER_COMMAND: str = "register"
SYNC_CHANNELS_COMMAND: str = "sync-channels"

View File

@@ -0,0 +1,37 @@
"""Custom exception classes for Discord bot."""
class DiscordBotError(Exception):
    """Base exception for Discord bot errors."""


class RegistrationError(DiscordBotError):
    """Error during guild registration."""


class SyncChannelsError(DiscordBotError):
    """Error during channel sync."""


class APIError(DiscordBotError):
    """Base API error."""


class CacheError(DiscordBotError):
    """Error during cache operations."""


class APIConnectionError(APIError):
    """Failed to connect to API."""


class APITimeoutError(APIError):
    """Request timed out."""


class APIResponseError(APIError):
    """API returned an error response.

    Attributes:
        status_code: HTTP status code of the failed response, if known.
    """

    def __init__(self, message: str, status_code: int | None = None):
        super().__init__(message)
        # Kept on the instance so callers can branch on the HTTP status.
        self.status_code = status_code

View File

@@ -0,0 +1,437 @@
"""Discord bot command handlers for registration and channel sync."""
import asyncio
from datetime import datetime
from datetime import timezone
import discord
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
from onyx.configs.constants import ONYX_DISCORD_URL
from onyx.db.discord_bot import bulk_create_channel_configs
from onyx.db.discord_bot import get_guild_config_by_discord_id
from onyx.db.discord_bot import get_guild_config_by_internal_id
from onyx.db.discord_bot import get_guild_config_by_registration_key
from onyx.db.discord_bot import sync_channel_configs
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.utils import DiscordChannelView
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.onyxbot.discord.constants import REGISTER_COMMAND
from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND
from onyx.onyxbot.discord.exceptions import RegistrationError
from onyx.onyxbot.discord.exceptions import SyncChannelsError
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
async def handle_dm(message: discord.Message) -> None:
    """Reply to a direct message explaining that DMs are not supported."""
    # Point the user at a server channel or the official Onyx Discord instead.
    await message.channel.send(
        "**I can't respond to DMs** :sweat:\n\n"
        f"Please chat with me in a server channel, or join the official "
        f"[Onyx Discord]({ONYX_DISCORD_URL}) for help!"
    )
# -------------------------------------------------------------------------
# Helper functions for error handling
# -------------------------------------------------------------------------
async def _try_dm_author(message: discord.Message, content: str) -> bool:
    """Best-effort DM to the message author; True on success, False otherwise."""
    logger.debug(f"Responding in Discord DM with {content}")
    try:
        await message.author.send(content)
    except (discord.Forbidden, discord.HTTPException) as e:
        # The user has DMs disabled, or Discord rejected the request.
        logger.warning(f"Failed to DM author {message.author.id}: {e}")
    except Exception as e:
        logger.exception(f"Unexpected error DMing author {message.author.id}: {e}")
    else:
        return True
    return False
async def _try_delete_message(message: discord.Message) -> bool:
    """Best-effort delete of a message; True on success, False otherwise."""
    logger.debug(f"Deleting potentially sensitive message {message.id}")
    try:
        await message.delete()
    except (discord.Forbidden, discord.HTTPException) as e:
        # Bot lacks the permission to delete, or Discord rejected the request.
        logger.warning(f"Failed to delete message {message.id}: {e}")
    except Exception as e:
        logger.exception(f"Unexpected error deleting message {message.id}: {e}")
    else:
        return True
    return False
async def _try_react_x(message: discord.Message) -> bool:
    """Attempt to react to a message with ❌. Returns True if successful."""
    try:
        # add_reaction() needs the literal unicode emoji (U+274C), not ":x:".
        # The emoji character was previously missing ("" was passed), which
        # would always fail with an HTTPException.
        await message.add_reaction("❌")
        return True
    except (discord.Forbidden, discord.HTTPException) as e:
        # Bot lacks permission or other error
        logger.warning(f"Failed to react to message {message.id}: {e}")
    except Exception as e:
        logger.exception(f"Unexpected error reacting to message {message.id}: {e}")
    return False
# -------------------------------------------------------------------------
# Registration
# -------------------------------------------------------------------------
async def handle_registration_command(
    message: discord.Message,
    cache: DiscordCacheManager,
) -> bool:
    """Handle !register command. Returns True if command was handled.

    On failure the error is delivered via DM and the original message is
    deleted so the registration key is not left visible in the channel.
    """
    content = message.content.strip()
    # Check for !register command
    if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{REGISTER_COMMAND}"):
        return False
    # Must be in a server
    if not message.guild:
        await _try_dm_author(
            message, "This command can only be used in a server channel."
        )
        return True
    guild_name = message.guild.name
    logger.info(f"Registration command received: {guild_name}")
    try:
        # Parse the registration key: expected form "<invoke_char>register <key>"
        parts = content.split(maxsplit=1)
        if len(parts) < 2:
            raise RegistrationError(
                "Invalid registration key format. Please check the key and try again."
            )
        registration_key = parts[1].strip()
        if not message.author or not isinstance(message.author, discord.Member):
            raise RegistrationError(
                "You need to be a server administrator to register the bot."
            )
        # Check permissions - require admin or manage_guild
        if not message.author.guild_permissions.administrator:
            if not message.author.guild_permissions.manage_guild:
                raise RegistrationError(
                    "You need **Administrator** or **Manage Server** permissions "
                    "to register this bot."
                )
        await _register_guild(message, registration_key, cache)
        logger.info(f"Registration successful: {guild_name}")
        await message.reply(
            ":white_check_mark: **Successfully registered!**\n\n"
            "This server is now connected to Onyx. "
            "I'll respond to messages based on your server and channel settings set in Onyx."
        )
    except RegistrationError as e:
        logger.debug(f"Registration failed: {guild_name}, error={e}")
        await _try_dm_author(message, f":x: **Registration failed.**\n\n{e}")
        # Remove the message so the (possibly valid) key isn't exposed in-channel.
        await _try_delete_message(message)
    except Exception:
        logger.exception(f"Registration failed unexpectedly: {guild_name}")
        await _try_dm_author(
            message,
            ":x: **Registration failed.**\n\n"
            "An unexpected error occurred. Please try again later.",
        )
        await _try_delete_message(message)
    return True


async def _register_guild(
    message: discord.Message,
    registration_key: str,
    cache: DiscordCacheManager,
) -> None:
    """Register a guild with a registration key.

    Resolves the tenant from the key, validates the key against the DB, binds
    the guild to the config, and seeds channel configs for all text channels.

    Raises:
        RegistrationError: If the key is malformed, unknown, already used,
            or the guild is already registered to a tenant.
    """
    if not message.guild:
        # mypy, even though we already know that message.guild is not None
        raise RegistrationError("This command can only be used in a server.")
    logger.info(f"Guild '{message.guild.name}' attempting to register Discord bot")
    registration_key = registration_key.strip()
    # Parse tenant_id from registration key
    parsed = parse_discord_registration_key(registration_key)
    if parsed is None:
        raise RegistrationError(
            "Invalid registration key format. Please check the key and try again."
        )
    tenant_id = parsed
    logger.info(f"Parsed tenant_id {tenant_id} from registration key")
    # Check if this guild is already registered to any tenant
    guild_id = message.guild.id
    existing_tenant = cache.get_tenant(guild_id)
    if existing_tenant is not None:
        logger.warning(
            f"Guild {guild_id} is already registered to tenant {existing_tenant}"
        )
        raise RegistrationError(
            "This server is already registered.\n\n"
            "OnyxBot can only connect one Discord server to one Onyx workspace."
        )
    # All DB work below must run under the target tenant's context.
    context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        guild = message.guild
        guild_name = guild.name
        # Collect all text channels from the guild
        channels = get_text_channels(guild)
        logger.info(f"Found {len(channels)} text channels in guild '{guild_name}'")

        # Validate and update in database
        def _sync_register() -> int:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                # Find the guild config by registration key
                config = get_guild_config_by_registration_key(db, registration_key)
                if not config:
                    raise RegistrationError(
                        "Registration key not found.\n\n"
                        "The key may have expired or been deleted. "
                        "Please generate a new one from the Onyx admin panel."
                    )
                # Check if already used
                if config.guild_id is not None:
                    raise RegistrationError(
                        "This registration key has already been used.\n\n"
                        "Each key can only be used once. "
                        "Please generate a new key from the Onyx admin panel."
                    )
                # Update the guild config
                config.guild_id = guild_id
                config.guild_name = guild_name
                config.registered_at = datetime.now(timezone.utc)
                # Create channel configs for all text channels
                bulk_create_channel_configs(db, config.id, channels)
                db.commit()
                return config.id

        # Blocking DB work runs off the event loop.
        await asyncio.to_thread(_sync_register)
        # Refresh cache for this guild
        await cache.refresh_guild(guild_id, tenant_id)
        logger.info(
            f"Guild '{guild_name}' registered with {len(channels)} channel configs"
        )
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
def get_text_channels(guild: discord.Guild) -> list[DiscordChannelView]:
    """Collect the guild's text and forum channels as DiscordChannelView objects."""
    views: list[DiscordChannelView] = []
    for ch in guild.channels:
        # Only text channels and forum channels (where threads can be created).
        if not isinstance(ch, (discord.TextChannel, discord.ForumChannel)):
            continue
        # A channel is private when @everyone cannot view it.
        hidden = not ch.permissions_for(guild.default_role).view_channel
        logger.debug(
            f"Found channel: #{ch.name}, "
            f"type={ch.type.name}, is_private={hidden}"
        )
        views.append(
            DiscordChannelView(
                channel_id=ch.id,
                channel_name=ch.name,
                channel_type=ch.type.name,  # "text" or "forum"
                is_private=hidden,
            )
        )
    logger.debug(f"Retrieved {len(views)} channels from guild '{guild.name}'")
    return views
# -------------------------------------------------------------------------
# Sync Channels
# -------------------------------------------------------------------------
async def handle_sync_channels_command(
    message: discord.Message,
    tenant_id: str | None,
    bot: discord.Client,
) -> bool:
    """Handle !sync-channels command. Returns True if command was handled.

    On failure the error is delivered via DM and the message gets an error
    reaction via _try_react_x (the message itself is not sensitive, so it is
    not deleted).
    """
    content = message.content.strip()
    # Check for !sync-channels command
    if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{SYNC_CHANNELS_COMMAND}"):
        return False
    # Must be in a server
    if not message.guild:
        await _try_dm_author(
            message, "This command can only be used in a server channel."
        )
        return True
    guild_name = message.guild.name
    logger.info(f"Sync-channels command received: {guild_name}")
    try:
        # Must be registered
        if not tenant_id:
            raise SyncChannelsError(
                "This server is not registered. Please register it first."
            )
        # Check permissions - require admin or manage_guild
        if not message.author or not isinstance(message.author, discord.Member):
            raise SyncChannelsError(
                "You need to be a server administrator to sync channels."
            )
        if not message.author.guild_permissions.administrator:
            if not message.author.guild_permissions.manage_guild:
                raise SyncChannelsError(
                    "You need **Administrator** or **Manage Server** permissions "
                    "to sync channels."
                )

        # Get guild config ID (blocking DB read, pushed to a worker thread below)
        def _get_guild_config_id() -> int | None:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                if not message.guild:
                    raise SyncChannelsError(
                        "Server not found. This shouldn't happen. Please contact Onyx support."
                    )
                config = get_guild_config_by_discord_id(db, message.guild.id)
                return config.id if config else None

        guild_config_id = await asyncio.to_thread(_get_guild_config_id)
        if not guild_config_id:
            raise SyncChannelsError(
                "Server config not found. This shouldn't happen. Please contact Onyx support."
            )
        # Perform the sync
        added, removed, updated = await sync_guild_channels(
            guild_config_id, tenant_id, bot
        )
        logger.info(
            f"Sync-channels successful: {guild_name}, "
            f"added={added}, removed={removed}, updated={updated}"
        )
        await message.reply(
            f":white_check_mark: **Channel sync complete!**\n\n"
            f"* **{added}** new channel(s) added\n"
            f"* **{removed}** deleted channel(s) removed\n"
            f"* **{updated}** channel name(s) updated\n\n"
            "New channels are disabled by default. Enable them in the Onyx admin panel."
        )
    except SyncChannelsError as e:
        logger.debug(f"Sync-channels failed: {guild_name}, error={e}")
        await _try_dm_author(message, f":x: **Channel sync failed.**\n\n{e}")
        await _try_react_x(message)
    except Exception:
        logger.exception(f"Sync-channels failed unexpectedly: {guild_name}")
        await _try_dm_author(
            message,
            ":x: **Channel sync failed.**\n\n"
            "An unexpected error occurred. Please try again later.",
        )
        await _try_react_x(message)
    return True


async def sync_guild_channels(
    guild_config_id: int,
    tenant_id: str,
    bot: discord.Client,
) -> tuple[int, int, int]:
    """Sync channel configs with current Discord channels for a guild.

    Fetches current channels from Discord and syncs with database:
    - Creates configs for new channels (disabled by default)
    - Removes configs for deleted channels
    - Updates names for existing channels if changed

    Args:
        guild_config_id: Internal ID of the guild config
        tenant_id: Tenant ID for database access
        bot: Discord bot client

    Returns:
        (added_count, removed_count, updated_count)

    Raises:
        ValueError: If guild config not found or guild not registered
    """
    # All DB work below must run under this tenant's context.
    context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        # Get guild_id from config
        def _get_guild_id() -> int | None:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                config = get_guild_config_by_internal_id(db, guild_config_id)
                if not config:
                    return None
                return config.guild_id

        guild_id = await asyncio.to_thread(_get_guild_id)
        if guild_id is None:
            raise ValueError(
                f"Guild config {guild_config_id} not found or not registered"
            )
        # Get the guild from Discord's local cache (no API round-trip)
        guild = bot.get_guild(guild_id)
        if not guild:
            raise ValueError(f"Guild {guild_id} not found in Discord cache")
        # Get current channels from Discord
        channels = get_text_channels(guild)
        logger.info(f"Syncing {len(channels)} channels for guild '{guild.name}'")

        # Sync with database (blocking, pushed to a worker thread below)
        def _sync() -> tuple[int, int, int]:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                added, removed, updated = sync_channel_configs(
                    db, guild_config_id, channels
                )
                db.commit()
                return added, removed, updated

        added, removed, updated = await asyncio.to_thread(_sync)
        logger.info(
            f"Channel sync complete for guild '{guild.name}': "
            f"added={added}, removed={removed}, updated={updated}"
        )
        return added, removed, updated
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)

View File

@@ -0,0 +1,493 @@
"""Discord bot message handling and response logic."""
import asyncio
import discord
from pydantic import BaseModel
from onyx.chat.models import ChatFullResponse
from onyx.db.discord_bot import get_channel_config_by_discord_ids
from onyx.db.discord_bot import get_guild_config_by_discord_id
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import DiscordChannelConfig
from onyx.db.models import DiscordGuildConfig
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
from onyx.onyxbot.discord.constants import THINKING_EMOJI
from onyx.onyxbot.discord.exceptions import APIError
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Message types with actual content (excludes system notifications like "user joined")
CONTENT_MESSAGE_TYPES = (
    discord.MessageType.default,
    discord.MessageType.reply,
    discord.MessageType.thread_starter_message,
)


class ShouldRespondContext(BaseModel):
    """Context for whether the bot should respond to a message."""

    # Whether the bot should respond at all.
    should_respond: bool
    # Persona to answer with (channel override or guild default); None when
    # not responding or when no persona is configured.
    persona_id: int | None
    # If True, the response must be posted in a thread, not the channel.
    thread_only_mode: bool
# -------------------------------------------------------------------------
# Response Logic
# -------------------------------------------------------------------------
async def should_respond(
    message: discord.Message,
    tenant_id: str,
    bot_user: discord.ClientUser,
) -> ShouldRespondContext:
    """Determine if bot should respond and which persona to use.

    Requires both the guild config and the channel config to exist and be
    enabled; then applies the channel's bot-invocation requirement.
    """
    if not message.guild:
        logger.warning("Received a message that isn't in a server.")
        return ShouldRespondContext(
            should_respond=False, persona_id=None, thread_only_mode=False
        )
    guild_id = message.guild.id
    channel_id = message.channel.id
    bot_mentioned = bot_user in message.mentions

    # Blocking DB lookups run off the event loop via asyncio.to_thread below.
    def _get_configs() -> tuple[DiscordGuildConfig | None, DiscordChannelConfig | None]:
        with get_session_with_tenant(tenant_id=tenant_id) as db:
            guild_config = get_guild_config_by_discord_id(db, guild_id)
            if not guild_config or not guild_config.enabled:
                return None, None
            # For threads, use parent channel ID
            actual_channel_id = channel_id
            if isinstance(message.channel, discord.Thread) and message.channel.parent:
                actual_channel_id = message.channel.parent.id
            channel_config = get_channel_config_by_discord_ids(
                db, guild_id, actual_channel_id
            )
            return guild_config, channel_config

    guild_config, channel_config = await asyncio.to_thread(_get_configs)
    if not guild_config or not channel_config or not channel_config.enabled:
        return ShouldRespondContext(
            should_respond=False, persona_id=None, thread_only_mode=False
        )
    # Determine persona (channel override or guild default)
    persona_id = channel_config.persona_override_id or guild_config.default_persona_id
    # Check mention requirement (with exceptions for implicit invocation)
    if channel_config.require_bot_invocation and not bot_mentioned:
        if not await check_implicit_invocation(message, bot_user):
            return ShouldRespondContext(
                should_respond=False, persona_id=None, thread_only_mode=False
            )
    return ShouldRespondContext(
        should_respond=True,
        persona_id=persona_id,
        thread_only_mode=channel_config.thread_only_mode,
    )


async def check_implicit_invocation(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> bool:
    """Check if the bot should respond without explicit mention.

    Returns True if:
    1. User is replying to a bot message
    2. User is in a thread owned by the bot
    3. User is in a thread created from a bot message
    """
    # Check if replying to a bot message
    if message.reference and message.reference.message_id:
        try:
            referenced_msg = await message.channel.fetch_message(
                message.reference.message_id
            )
            if referenced_msg.author.id == bot_user.id:
                logger.debug(
                    f"Implicit invocation via reply: '{message.content[:50]}...'"
                )
                return True
        except (discord.NotFound, discord.HTTPException):
            # Referenced message deleted or unfetchable — fall through to
            # the thread-based checks below.
            pass
    # Check thread-related conditions
    if isinstance(message.channel, discord.Thread):
        thread = message.channel
        # Bot owns the thread
        if thread.owner_id == bot_user.id:
            logger.debug(
                f"Implicit invocation via bot-owned thread: "
                f"'{message.content[:50]}...' in #{thread.name}"
            )
            return True
        # Thread was created from a bot message; the starter message is
        # fetched from the parent channel using the thread's own ID.
        if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
            try:
                starter = await thread.parent.fetch_message(thread.id)
                if starter.author.id == bot_user.id:
                    logger.debug(
                        f"Implicit invocation via bot-started thread: "
                        f"'{message.content[:50]}...' in #{thread.name}"
                    )
                    return True
            except (discord.NotFound, discord.HTTPException):
                pass
    return False
# -------------------------------------------------------------------------
# Message Processing
# -------------------------------------------------------------------------
async def process_chat_message(
    message: discord.Message,
    api_key: str,
    persona_id: int | None,
    thread_only_mode: bool,
    api_client: OnyxAPIClient,
    bot_user: discord.ClientUser,
) -> None:
    """Process a message and send response.

    Adds a "thinking" reaction while the Onyx API call is in flight, builds
    conversation context, sends the answer (with citation sources appended),
    and falls back to an error response on any failure.
    """
    try:
        await message.add_reaction(THINKING_EMOJI)
    except discord.DiscordException:
        # Non-fatal: continue processing even if the reaction fails.
        logger.warning(
            f"Failed to add thinking reaction to message: '{message.content[:50]}...'"
        )
    try:
        # Build conversation context
        context = await _build_conversation_context(message, bot_user)
        # Prepare full message content
        parts = []
        if context:
            parts.append(context)
        if isinstance(message.channel, discord.Thread):
            # Forum posts carry a meaningful title; include it as context.
            if isinstance(message.channel.parent, discord.ForumChannel):
                parts.append(f"Forum post title: {message.channel.name}")
        parts.append(
            f"Current message from @{message.author.display_name}: {format_message_content(message)}"
        )
        # Send to API
        response = await api_client.send_chat_message(
            message="\n\n".join(parts),
            api_key=api_key,
            persona_id=persona_id,
        )
        # Format response with citations
        answer = response.answer or "I couldn't generate a response."
        answer = _append_citations(answer, response)
        await send_response(message, answer, thread_only_mode)
        try:
            await message.remove_reaction(THINKING_EMOJI, bot_user)
        except discord.DiscordException:
            # Reaction cleanup is best effort.
            pass
    except APIError as e:
        logger.error(f"API error processing message: {e}")
        await send_error_response(message, bot_user)
    except Exception as e:
        logger.exception(f"Error processing chat message: {e}")
        await send_error_response(message, bot_user)
async def _build_conversation_context(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> str | None:
    """Gather prior conversation text: thread history first, else the reply chain."""
    if isinstance(message.channel, discord.Thread):
        return await _build_thread_context(message, bot_user)
    if message.reference:
        return await _build_reply_chain_context(message, bot_user)
    # Plain channel message with no reply — nothing to add.
    return None
def _append_citations(answer: str, response: ChatFullResponse) -> str:
    """Append a "Sources:" section listing up to 5 cited documents."""
    if not response.citation_info or not response.top_documents:
        return answer

    # Index documents by ID; setdefault keeps the FIRST occurrence, matching
    # a linear first-match search over top_documents.
    docs_by_id: dict = {}
    for doc in response.top_documents:
        docs_by_id.setdefault(doc.document_id, doc)

    entries: list[tuple[int, str, str | None]] = []
    for citation in response.citation_info:
        matched = docs_by_id.get(citation.document_id)
        if matched:
            entries.append(
                (
                    citation.citation_number,
                    matched.semantic_identifier or "Source",
                    matched.link,
                )
            )
    if not entries:
        return answer

    entries.sort(key=lambda entry: entry[0])
    pieces = ["\n\n**Sources:**\n"]
    for num, name, link in entries[:5]:
        # URL kept wrapped in <> inside the markdown link, as in the original format.
        pieces.append(f"{num}. [{name}](<{link}>)\n" if link else f"{num}. {name}\n")
    return answer + "".join(pieces)
# -------------------------------------------------------------------------
# Context Building
# -------------------------------------------------------------------------
async def _build_reply_chain_context(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> str | None:
    """Build context by following the reply chain backwards.

    Walks parent messages up to MAX_CONTEXT_MESSAGES deep, then formats them
    oldest-first. Returns None when there is no usable chain.
    """
    if not message.reference or not message.reference.message_id:
        return None
    try:
        messages: list[discord.Message] = []
        current = message
        # Follow reply chain backwards up to MAX_CONTEXT_MESSAGES
        while (
            current.reference
            and current.reference.message_id
            and len(messages) < MAX_CONTEXT_MESSAGES
        ):
            try:
                parent = await message.channel.fetch_message(
                    current.reference.message_id
                )
                messages.append(parent)
                current = parent
            except (discord.NotFound, discord.HTTPException):
                # Parent deleted or unfetchable — use what we have so far.
                break
        if not messages:
            return None
        messages.reverse()  # Chronological order
        logger.debug(
            f"Built reply chain context: {len(messages)} messages in #{getattr(message.channel, 'name', 'unknown')}"
        )
        return _format_messages_as_context(messages, bot_user)
    except Exception as e:
        # Context is optional; never let it break the main response flow.
        logger.warning(f"Failed to build reply chain context: {e}")
        return None


async def _build_thread_context(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> str | None:
    """Build context from thread message history.

    Collects recent thread messages, the thread's starter message, and the
    starter's reply chain, capped at MAX_CONTEXT_MESSAGES, ordered oldest-first.
    """
    if not isinstance(message.channel, discord.Thread):
        return None
    try:
        thread = message.channel
        messages: list[discord.Message] = []
        # Fetch recent messages (excluding current)
        async for msg in thread.history(limit=MAX_CONTEXT_MESSAGES, oldest_first=False):
            if msg.id != message.id:
                messages.append(msg)
        # Include thread starter message and its reply chain if not already present
        if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
            try:
                # The starter message is fetched from the parent channel using
                # the thread's own ID.
                starter = await thread.parent.fetch_message(thread.id)
                if starter.id != message.id and not any(
                    m.id == starter.id for m in messages
                ):
                    messages.append(starter)
                    # Trace back through the starter's reply chain for more context
                    current = starter
                    while (
                        current.reference
                        and current.reference.message_id
                        and len(messages) < MAX_CONTEXT_MESSAGES
                    ):
                        try:
                            parent = await thread.parent.fetch_message(
                                current.reference.message_id
                            )
                            if not any(m.id == parent.id for m in messages):
                                messages.append(parent)
                            current = parent
                        except (discord.NotFound, discord.HTTPException):
                            break
            except (discord.NotFound, discord.HTTPException):
                # Starter message unavailable (e.g. deleted) — skip it.
                pass
        if not messages:
            return None
        messages.sort(key=lambda m: m.id)  # Chronological order
        logger.debug(
            f"Built thread context: {len(messages)} messages in #{thread.name}"
        )
        return _format_messages_as_context(messages, bot_user)
    except Exception as e:
        # Context is optional; never let it break the main response flow.
        logger.warning(f"Failed to build thread context: {e}")
        return None
def _format_messages_as_context(
    messages: list[discord.Message],
    bot_user: discord.ClientUser,
) -> str | None:
    """Format a list of messages into a conversation context string.

    Returns None when no message carries user-visible content.
    """
    formatted = []
    for msg in messages:
        # Skip system notifications (joins, pins, etc.) — no conversational content.
        if msg.type not in CONTENT_MESSAGE_TYPES:
            continue
        sender = (
            "OnyxBot" if msg.author.id == bot_user.id else f"@{msg.author.display_name}"
        )
        formatted.append(f"{sender}: {format_message_content(msg)}")
    if not formatted:
        return None
    # Fix: a "\n" was missing after the [user] sentence, so implicit string
    # concatenation fused it with "Conversation history:".
    return (
        "You are a Discord bot named OnyxBot.\n"
        'Always assume that [user] is the same as the "Current message" author.\n'
        "Conversation history:\n"
        "---\n" + "\n".join(formatted) + "\n---"
    )
# -------------------------------------------------------------------------
# Message Formatting
# -------------------------------------------------------------------------
def format_message_content(message: discord.Message) -> str:
    """Replace raw mention tokens with human-readable @/# names."""
    text = message.content
    # Users can be referenced as <@id> or the legacy nickname form <@!id>.
    for user in message.mentions:
        for token in (f"<@{user.id}>", f"<@!{user.id}>"):
            text = text.replace(token, f"@{user.display_name}")
    for role in message.role_mentions:
        text = text.replace(f"<@&{role.id}>", f"@{role.name}")
    for channel in message.channel_mentions:
        text = text.replace(f"<#{channel.id}>", f"#{channel.name}")
    return text
# -------------------------------------------------------------------------
# Response Sending
# -------------------------------------------------------------------------
async def send_response(
    message: discord.Message,
    content: str,
    thread_only_mode: bool,
) -> None:
    """Deliver the response, honoring thread_only_mode for channel messages."""
    chunks = _split_message(content)

    # Already inside a thread — post every chunk there.
    if isinstance(message.channel, discord.Thread):
        for chunk in chunks:
            await message.channel.send(chunk)
        return

    # Channel message with thread-only mode: spin up a dedicated thread.
    if thread_only_mode:
        thread = await message.create_thread(
            name=f"OnyxBot <> {message.author.display_name}"[:100]
        )
        for chunk in chunks:
            await thread.send(chunk)
        return

    # Otherwise reply inline; overflow chunks go as plain channel messages.
    replied = False
    for chunk in chunks:
        if not replied:
            await message.reply(chunk)
            replied = True
        else:
            await message.channel.send(chunk)
def _split_message(content: str) -> list[str]:
"""Split content into chunks that fit Discord's message limit."""
chunks = []
while content:
if len(content) <= MAX_MESSAGE_LENGTH:
chunks.append(content)
break
# Find a good split point
split_at = MAX_MESSAGE_LENGTH
for sep in ["\n\n", "\n", ". ", " "]:
idx = content.rfind(sep, 0, MAX_MESSAGE_LENGTH)
if idx > MAX_MESSAGE_LENGTH // 2:
split_at = idx + len(sep)
break
chunks.append(content[:split_at])
content = content[split_at:]
return chunks
async def send_error_response(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> None:
    """Send a generic error reply and clean up the thinking reaction (best effort)."""
    try:
        await message.remove_reaction(THINKING_EMOJI, bot_user)
    except discord.DiscordException:
        # Reaction may already be gone or we lack permission — ignore.
        pass
    error_msg = "Sorry, I encountered an error processing your message. You may want to contact Onyx for support :sweat_smile:"
    try:
        if isinstance(message.channel, discord.Thread):
            target = message.channel
        else:
            # Keep error chatter out of the channel by opening a thread.
            target = await message.create_thread(
                name=f"Response to {message.author.display_name}"[:100]
            )
        await target.send(error_msg)
    except discord.DiscordException:
        # Nothing more we can do if even the error message fails to send.
        pass

View File

@@ -0,0 +1,39 @@
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISCORD_BOT_TOKEN
from onyx.configs.constants import AuthType
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
def get_bot_token() -> str | None:
    """Get Discord bot token from env var or database.

    Priority:
        1. DISCORD_BOT_TOKEN env var (always takes precedence)
        2. For self-hosted: DiscordBotConfig in database (default tenant)
        3. For Cloud: should always have env var set

    Returns:
        Bot token string, or None if not configured.
    """
    # The environment variable wins unconditionally.
    if DISCORD_BOT_TOKEN:
        return DISCORD_BOT_TOKEN
    # Cloud deployments are expected to set the env var; no DB fallback there.
    if AUTH_TYPE == AuthType.CLOUD:
        logger.warning("Cloud deployment missing DISCORD_BOT_TOKEN env var")
        return None
    # Self-hosted: fall back to the bot config stored in the default tenant.
    try:
        with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db:
            stored = get_discord_bot_config(db)
    except Exception as e:
        logger.error(f"Failed to get bot token from database: {e}")
        return None
    return stored.bot_token if stored else None

View File

@@ -592,11 +592,8 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,12 +1,149 @@
from mistune import Markdown # type: ignore[import-untyped]
from mistune import Renderer
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
    r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
    re.IGNORECASE,
)
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
# (negative lookahead on the scheme right after "<").
_HTML_TAG_PATTERN = re.compile(
    r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them.
# Non-greedy so adjacent blocks are matched separately.
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
"""Strip HTML tags from a text fragment.
Block-level closing tags and <br> are converted to newlines.
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
"""
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
text = _HTML_TAG_PATTERN.sub("", text)
return text
def _transform_outside_code_blocks(
message: str, transform: Callable[[str], str]
) -> str:
"""Apply *transform* only to text outside fenced code blocks."""
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
result: list[str] = []
for i, part in enumerate(parts):
result.append(transform(part))
if i < len(code_blocks):
result.append(code_blocks[i])
return "".join(result)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
Markdown link syntax [text](url) breaks when the URL contains unescaped
parentheses, spaces, or other special characters. Wrapping the URL in angle
brackets — [text](<url>) — tells the parser to treat everything inside as
a literal URL. This applies to all links, not just citations.
"""
if "](" not in message:
return message
normalized_parts: list[str] = []
cursor = 0
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
normalized_parts.append(message[cursor : match.end()])
destination_start = match.end()
destination, end_idx = _extract_link_destination(message, destination_start)
if end_idx is None:
normalized_parts.append(message[destination_start:])
return "".join(normalized_parts)
already_wrapped = destination.startswith("<") and destination.endswith(">")
if destination and not already_wrapped:
destination = f"<{destination}>"
normalized_parts.append(destination)
normalized_parts.append(")")
cursor = end_idx + 1
normalized_parts.append(message[cursor:])
return "".join(normalized_parts)
def _convert_slack_links_to_markdown(message: str) -> str:
"""Convert Slack-style <url|text> links to standard markdown [text](url).
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
recognise it, so the angle brackets would be escaped by text() and Slack
would render the link as literal text instead of a clickable link.
"""
return _transform_outside_code_blocks(
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
)
def format_slack_message(message: str | None) -> str:
return Markdown(renderer=SlackRenderer()).render(message)
if message is None:
return ""
message = _transform_outside_code_blocks(message, _sanitize_html)
message = _convert_slack_links_to_markdown(message)
normalized_message = _normalize_link_destinations(message)
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
result = md(normalized_message)
# With HTMLRenderer, result is always str (not AST list)
assert isinstance(result, str)
return result.rstrip("\n")
class SlackRenderer(Renderer):
class SlackRenderer(HTMLRenderer):
"""Renders markdown as Slack mrkdwn format instead of HTML.
Overrides all HTMLRenderer methods that produce HTML tags to ensure
no raw HTML ever appears in Slack messages.
"""
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def escape_special(self, text: str) -> str:
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
text = text.replace(special, replacement)
return text
def header(self, text: str, level: int, raw: str | None = None) -> str:
return f"*{text}*\n"
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
return f"*{text}*\n\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
def double_emphasis(self, text: str) -> str:
def strong(self, text: str) -> str:
return f"*{text}*"
def strikethrough(self, text: str) -> str:
return f"~{text}~"
def list(self, body: str, ordered: bool = True) -> str:
lines = body.split("\n")
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
lines = text.split("\n")
count = 0
for i, line in enumerate(lines):
if line.startswith("li: "):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
return "\n".join(lines) + "\n"
def list_item(self, text: str) -> str:
return f"li: {text}\n"
def link(self, link: str, title: str | None, content: str | None) -> str:
escaped_link = self.escape_special(link)
if content:
return f"<{escaped_link}|{content}>"
def link(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
if text:
return f"<{escaped_url}|{text}>"
if title:
return f"<{escaped_link}|{title}>"
return f"<{escaped_link}>"
return f"<{escaped_url}|{title}>"
return f"<{escaped_url}>"
def image(self, src: str, title: str | None, text: str | None) -> str:
escaped_src = self.escape_special(src)
def image(self, text: str, url: str, title: str | None = None) -> str:
escaped_url = self.escape_special(url)
display_text = title or text
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
def codespan(self, text: str) -> str:
return f"`{text}`"
def block_code(self, text: str, lang: str | None) -> str:
return f"```\n{text}\n```\n"
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
return "\n"
def thematic_break(self) -> str:
return "---\n\n"
def block_quote(self, text: str) -> str:
lines = text.strip().split("\n")
quoted = "\n".join(f">{line}" for line in lines)
return quoted + "\n\n"
def block_html(self, html: str) -> str:
return _sanitize_html(html) + "\n\n"
def block_error(self, text: str) -> str:
return f"```\n{text}\n```\n\n"
def text(self, text: str) -> str:
# Only escape the three entities Slack recognizes: & < >
# HTMLRenderer.text() also escapes " to &quot; which Slack renders
# as literal &quot; text since Slack doesn't recognize that entity.
return self.escape_special(text)
def paragraph(self, text: str) -> str:
return f"{text}\n"
def autolink(self, link: str, is_email: bool) -> str:
return link if is_email else self.link(link, None, None)
return f"{text}\n\n"

View File

@@ -0,0 +1,38 @@
# ruff: noqa: E501, W605 start
# Note that the user_basic_information is only included if we have at least 1 of the following: user_name, user_email, user_role
# This is included because sometimes we need to know the user's name or basic info to best generate the memory.
FULL_MEMORY_UPDATE_PROMPT = """
You are a memory update agent that helps the user add or update memories. You are given a list of existing memories and a new memory to add. \
Just as context, you are also given the last few user messages from the conversation which generated the new memory. You must determine if the memory is brand new or if it is related to an existing memory. \
If the new memory is an update to an existing memory or contradicts an existing memory, it should be treated as an update and you should reference the existing memory by memory_id (see below). \
The memory should omit the user's name and direct reference to the user - for example, a memory like "Yuhong prefers dark mode." should be modified to "Prefers dark mode." (if the user's name is Yuhong).
# Truncated chat history
{chat_history}{user_basic_information}
# User's existing memories
{existing_memories}
# New memory the user wants to insert
{new_memory}
# Response Style
You MUST respond in a json which follows the following format and keys:
```json
{{
"operation": "add or update",
"memory_id": "if the operation is update, the id of the memory to update, otherwise null",
"memory_text": "the text of the memory to add or update"
}}
```
""".strip()
# ruff: noqa: E501, W605 end
MEMORY_USER_BASIC_INFORMATION_PROMPT = """
# User Basic Information
User name: {user_name}
User email: {user_email}
User role: {user_role}
"""

View File

@@ -1,30 +1,39 @@
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
SLACK_QUERY_EXPANSION_PROMPT = f"""
Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
keyword-only queries, so that Slack's keyword search yields the best matches.
Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.
Keep in mind the Slack's search behavior:
- Pure keyword AND search (no semantics).
- Word order matters.
- More words = fewer matches, so keep each query concise.
- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.
Slack search behavior:
- Pure keyword AND search (no semantics)
- More words = fewer matches, so keep queries concise (1-3 words)
Critical: Extract ONLY keywords that would actually appear in Slack message content.
ALWAYS include:
- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
- Project/product names, technical terms, proper nouns
- Actual content words: "performance", "bug", "deployment", "API", "error"
DO NOT include:
- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
- Channels/Users: "general", "eng-general", "engineering", "@username"
DO include:
- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"
- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
- Temporal: "today", "yesterday", "week", "month", "recent", "last"
- Channel names: "general", "eng-general", "random"
Examples:
Query: "what are the big topics in eng-general this week?"
Output:
Query: "messages with Sarah about the deployment"
Output:
Sarah deployment
Sarah
deployment
Query: "what did Mike say about the budget?"
Output:
Mike budget
Mike
budget
Query: "performance issues in eng-general"
Output:
performance issues
@@ -41,7 +50,7 @@ Now process this query:
{{query}}
Output:
Output (keywords only, one per line, NO explanations or commentary):
"""
SLACK_DATE_EXTRACTION_PROMPT = """

View File

@@ -71,6 +71,14 @@ GENERATE_IMAGE_GUIDANCE = """
NEVER use generate_image unless the user specifically requests an image.
"""
MEMORY_GUIDANCE = """
## add_memory
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
Only add memories that are specific, likely to remain true, and likely to be useful later. \
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
"""
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.
""".strip()

View File

@@ -0,0 +1,40 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n\n# User Information\n"
BASIC_INFORMATION_PROMPT = """
## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
"""
# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
"""
# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """
## Team Information
{team_information}
"""
# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """
## User Preferences
{user_preferences}
"""
# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """
## User Memories
{user_memories}
"""
# ruff: noqa: E501, W605 end

View File

@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -0,0 +1,184 @@
from onyx.configs.constants import MessageType
from onyx.llm.interfaces import LLM
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.prompts.basic_memory import FULL_MEMORY_UPDATE_PROMPT
from onyx.tools.models import ChatMinimalTextMessage
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import parse_llm_json_response
logger = setup_logger()
# Maximum number of user messages to include
MAX_USER_MESSAGES = 3
MAX_CHARS_PER_MESSAGE = 500
def _format_chat_history(chat_history: list[ChatMinimalTextMessage]) -> str:
user_messages = [
msg for msg in chat_history if msg.message_type == MessageType.USER
]
if not user_messages:
return "No chat history available."
# Take the last N user messages
recent_user_messages = user_messages[-MAX_USER_MESSAGES:]
formatted_parts = []
for i, msg in enumerate(recent_user_messages, start=1):
if len(msg.message) > MAX_CHARS_PER_MESSAGE:
truncated_message = msg.message[:MAX_CHARS_PER_MESSAGE] + "[...truncated]"
else:
truncated_message = msg.message
formatted_parts.append(f"\nUser message:\n{truncated_message}\n")
return "".join(formatted_parts).strip()
def _format_existing_memories(existing_memories: list[str]) -> str:
"""Format existing memories as a numbered list (1-indexed for readability)."""
if not existing_memories:
return "No existing memories."
formatted_lines = []
for i, memory in enumerate(existing_memories, start=1):
formatted_lines.append(f"{i}. {memory}")
return "\n".join(formatted_lines)
def _format_user_basic_information(
user_name: str | None,
user_email: str | None,
user_role: str | None,
) -> str:
"""Format user basic information, only including fields that have values."""
lines = []
if user_name:
lines.append(f"User name: {user_name}")
if user_email:
lines.append(f"User email: {user_email}")
if user_role:
lines.append(f"User role: {user_role}")
if not lines:
return ""
return "\n\n# User Basic Information\n" + "\n".join(lines)
def process_memory_update(
new_memory: str,
existing_memories: list[str],
chat_history: list[ChatMinimalTextMessage],
llm: LLM,
user_name: str | None = None,
user_email: str | None = None,
user_role: str | None = None,
) -> tuple[str, int | None]:
"""
Determine if a memory should be added or updated.
Uses the LLM to analyze the new memory against existing memories and
determine whether to add it as new or update an existing memory.
Args:
new_memory: The new memory text from the memory tool
existing_memories: List of existing memory strings
chat_history: Recent chat history for context
llm: LLM instance to use for the decision
user_name: Optional user name for context
user_email: Optional user email for context
user_role: Optional user role for context
Returns:
Tuple of (memory_text, index_to_replace)
- memory_text: The final memory text to store
- index_to_replace: Index in existing_memories to replace, or None if adding new
"""
# Format inputs for the prompt
formatted_chat_history = _format_chat_history(chat_history)
formatted_memories = _format_existing_memories(existing_memories)
formatted_user_info = _format_user_basic_information(
user_name, user_email, user_role
)
# Build the prompt
prompt = FULL_MEMORY_UPDATE_PROMPT.format(
chat_history=formatted_chat_history,
user_basic_information=formatted_user_info,
existing_memories=formatted_memories,
new_memory=new_memory,
)
# Call LLM
try:
response = llm.invoke(
prompt=UserMessage(content=prompt), reasoning_effort=ReasoningEffort.OFF
)
content = response.choice.message.content
except Exception as e:
logger.warning(f"LLM invocation failed for memory update: {e}")
return (new_memory, None)
# Handle empty response
if not content:
logger.warning(
"LLM returned empty response for memory update, defaulting to add"
)
return (new_memory, None)
# Parse JSON response
parsed_response = parse_llm_json_response(content)
if not parsed_response:
logger.warning(
f"Failed to parse JSON from LLM response: {content[:200]}..., defaulting to add"
)
return (new_memory, None)
# Extract fields from response
operation = parsed_response.get("operation", "add").lower()
memory_id = parsed_response.get("memory_id")
memory_text = parsed_response.get("memory_text", new_memory)
# Ensure memory_text is valid
if not memory_text or not isinstance(memory_text, str):
memory_text = new_memory
# Handle add operation
if operation == "add":
logger.debug("Memory update operation: add")
return (memory_text, None)
# Handle update operation
if operation == "update":
# Validate memory_id
if memory_id is None:
logger.warning("Update operation specified but no memory_id provided")
return (memory_text, None)
# Convert memory_id to integer if it's a string
try:
memory_id_int = int(memory_id)
except (ValueError, TypeError):
logger.warning(f"Invalid memory_id format: {memory_id}")
return (memory_text, None)
# Convert from 1-indexed (LLM response) to 0-indexed (internal)
index_to_replace = memory_id_int - 1
# Validate index is in range
if index_to_replace < 0 or index_to_replace >= len(existing_memories):
logger.warning(
f"memory_id {memory_id_int} out of range (1-{len(existing_memories)}), defaulting to add"
)
return (memory_text, None)
logger.debug(f"Memory update operation: update at index {index_to_replace}")
return (memory_text, index_to_replace)
# Unknown operation, default to add
logger.warning(f"Unknown operation '{operation}', defaulting to add")
return (memory_text, None)

View File

@@ -0,0 +1,294 @@
"""Discord bot admin API endpoints."""
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISCORD_BOT_TOKEN
from onyx.configs.constants import AuthType
from onyx.db.discord_bot import create_discord_bot_config
from onyx.db.discord_bot import create_guild_config
from onyx.db.discord_bot import delete_discord_bot_config
from onyx.db.discord_bot import delete_discord_service_api_key
from onyx.db.discord_bot import delete_guild_config
from onyx.db.discord_bot import get_channel_config_by_internal_ids
from onyx.db.discord_bot import get_channel_configs
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.discord_bot import get_guild_config_by_internal_id
from onyx.db.discord_bot import get_guild_configs
from onyx.db.discord_bot import update_discord_channel_config
from onyx.db.discord_bot import update_guild_config
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.server.manage.discord_bot.models import DiscordBotConfigCreateRequest
from onyx.server.manage.discord_bot.models import DiscordBotConfigResponse
from onyx.server.manage.discord_bot.models import DiscordChannelConfigResponse
from onyx.server.manage.discord_bot.models import DiscordChannelConfigUpdateRequest
from onyx.server.manage.discord_bot.models import DiscordGuildConfigCreateResponse
from onyx.server.manage.discord_bot.models import DiscordGuildConfigResponse
from onyx.server.manage.discord_bot.models import DiscordGuildConfigUpdateRequest
from onyx.server.manage.discord_bot.utils import (
generate_discord_registration_key,
)
from shared_configs.contextvars import get_current_tenant_id
router = APIRouter(prefix="/manage/admin/discord-bot")
def _check_bot_config_api_access() -> None:
"""Raise 403 if bot config cannot be managed via API.
Bot config endpoints are disabled:
- On Cloud (managed by Onyx)
- When DISCORD_BOT_TOKEN env var is set (managed via env)
"""
if AUTH_TYPE == AuthType.CLOUD:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Discord bot configuration is managed by Onyx on Cloud.",
)
if DISCORD_BOT_TOKEN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Discord bot is configured via environment variables. API access disabled.",
)
# === Bot Config ===
@router.get("/config", response_model=DiscordBotConfigResponse)
def get_bot_config(
_: None = Depends(_check_bot_config_api_access),
__: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> DiscordBotConfigResponse:
"""Get Discord bot config. Returns 403 on Cloud or if env vars set."""
config = get_discord_bot_config(db_session)
if not config:
return DiscordBotConfigResponse(configured=False)
return DiscordBotConfigResponse(
configured=True,
created_at=config.created_at,
)
@router.post("/config", response_model=DiscordBotConfigResponse)
def create_bot_request(
request: DiscordBotConfigCreateRequest,
_: None = Depends(_check_bot_config_api_access),
__: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> DiscordBotConfigResponse:
"""Create Discord bot config. Returns 403 on Cloud or if env vars set."""
try:
config = create_discord_bot_config(
db_session,
bot_token=request.bot_token,
)
except ValueError:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Discord bot config already exists. Delete it first to create a new one.",
)
db_session.commit()
return DiscordBotConfigResponse(
configured=True,
created_at=config.created_at,
)
@router.delete("/config")
def delete_bot_config_endpoint(
_: None = Depends(_check_bot_config_api_access),
__: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict:
"""Delete Discord bot config.
Also deletes the Discord service API key since the bot is being removed.
"""
deleted = delete_discord_bot_config(db_session)
if not deleted:
raise HTTPException(status_code=404, detail="Bot config not found")
# Also delete the service API key used by the Discord bot
delete_discord_service_api_key(db_session)
db_session.commit()
return {"deleted": True}
# === Service API Key ===
@router.delete("/service-api-key")
def delete_service_api_key_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict:
"""Delete the Discord service API key.
This endpoint allows manual deletion of the service API key used by the
Discord bot to authenticate with the Onyx API. The key is also automatically
deleted when:
- Bot config is deleted (self-hosted)
- All guild configs are deleted (Cloud)
"""
deleted = delete_discord_service_api_key(db_session)
if not deleted:
raise HTTPException(status_code=404, detail="Service API key not found")
db_session.commit()
return {"deleted": True}
# === Guild Config ===
@router.get("/guilds", response_model=list[DiscordGuildConfigResponse])
def list_guild_configs(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[DiscordGuildConfigResponse]:
"""List all guild configs (pending and registered)."""
configs = get_guild_configs(db_session)
return [DiscordGuildConfigResponse.model_validate(c) for c in configs]
@router.post("/guilds", response_model=DiscordGuildConfigCreateResponse)
def create_guild_request(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> DiscordGuildConfigCreateResponse:
"""Create new guild config with registration key. Key shown once."""
tenant_id = get_current_tenant_id()
registration_key = generate_discord_registration_key(tenant_id)
config = create_guild_config(db_session, registration_key)
db_session.commit()
return DiscordGuildConfigCreateResponse(
id=config.id,
registration_key=registration_key, # Shown once!
)
@router.get("/guilds/{config_id}", response_model=DiscordGuildConfigResponse)
def get_guild_config(
config_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> DiscordGuildConfigResponse:
"""Get specific guild config."""
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
if not config:
raise HTTPException(status_code=404, detail="Guild config not found")
return DiscordGuildConfigResponse.model_validate(config)
@router.patch("/guilds/{config_id}", response_model=DiscordGuildConfigResponse)
def update_guild_request(
config_id: int,
request: DiscordGuildConfigUpdateRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> DiscordGuildConfigResponse:
"""Update guild config."""
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
if not config:
raise HTTPException(status_code=404, detail="Guild config not found")
config = update_guild_config(
db_session,
config,
enabled=request.enabled,
default_persona_id=request.default_persona_id,
)
db_session.commit()
return DiscordGuildConfigResponse.model_validate(config)
@router.delete("/guilds/{config_id}")
def delete_guild_request(
config_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict:
"""Delete guild config (invalidates registration key).
On Cloud, if this was the last guild config, also deletes the service API key.
"""
deleted = delete_guild_config(db_session, config_id)
if not deleted:
raise HTTPException(status_code=404, detail="Guild config not found")
# On Cloud, delete service API key when all guilds are removed
if AUTH_TYPE == AuthType.CLOUD:
remaining_guilds = get_guild_configs(db_session)
if not remaining_guilds:
delete_discord_service_api_key(db_session)
db_session.commit()
return {"deleted": True}
# === Channel Config ===
@router.get(
"/guilds/{config_id}/channels", response_model=list[DiscordChannelConfigResponse]
)
def list_channel_configs(
config_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[DiscordChannelConfigResponse]:
"""List whitelisted channels for a guild."""
guild_config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
if not guild_config:
raise HTTPException(status_code=404, detail="Guild config not found")
if not guild_config.guild_id:
raise HTTPException(status_code=400, detail="Guild not yet registered")
configs = get_channel_configs(db_session, config_id)
return [DiscordChannelConfigResponse.model_validate(c) for c in configs]
@router.patch(
"/guilds/{guild_config_id}/channels/{channel_config_id}",
response_model=DiscordChannelConfigResponse,
)
def update_channel_request(
guild_config_id: int,
channel_config_id: int,
request: DiscordChannelConfigUpdateRequest,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> DiscordChannelConfigResponse:
"""Update channel config."""
config = get_channel_config_by_internal_ids(
db_session, guild_config_id, channel_config_id
)
if not config:
raise HTTPException(status_code=404, detail="Channel config not found")
config = update_discord_channel_config(
db_session,
config,
channel_name=config.channel_name, # Keep existing name, only Discord can update
thread_only_mode=request.thread_only_mode,
require_bot_invocation=request.require_bot_invocation,
persona_override_id=request.persona_override_id,
enabled=request.enabled,
)
db_session.commit()
return DiscordChannelConfigResponse.model_validate(config)

View File

@@ -0,0 +1,71 @@
"""Pydantic models for Discord bot API."""
from datetime import datetime
from pydantic import BaseModel
# === Bot Config ===
class DiscordBotConfigResponse(BaseModel):
configured: bool
created_at: datetime | None = None
class Config:
from_attributes = True
class DiscordBotConfigCreateRequest(BaseModel):
bot_token: str
# === Guild Config ===
class DiscordGuildConfigResponse(BaseModel):
id: int
guild_id: int | None
guild_name: str | None
registered_at: datetime | None
default_persona_id: int | None
enabled: bool
class Config:
from_attributes = True
class DiscordGuildConfigCreateResponse(BaseModel):
id: int
registration_key: str # Shown once!
class DiscordGuildConfigUpdateRequest(BaseModel):
enabled: bool
default_persona_id: int | None
# === Channel Config ===
class DiscordChannelConfigResponse(BaseModel):
id: int
guild_config_id: int
channel_id: int
channel_name: str
channel_type: str
is_private: bool
require_bot_invocation: bool
thread_only_mode: bool
persona_override_id: int | None
enabled: bool
class Config:
from_attributes = True
class DiscordChannelConfigUpdateRequest(BaseModel):
require_bot_invocation: bool
persona_override_id: int | None
enabled: bool
thread_only_mode: bool

View File

@@ -0,0 +1,46 @@
"""Discord registration key generation and parsing."""
import secrets
from urllib.parse import quote
from urllib.parse import unquote
from onyx.utils.logger import setup_logger
logger = setup_logger()
REGISTRATION_KEY_PREFIX: str = "discord_"
def generate_discord_registration_key(tenant_id: str) -> str:
"""Generate a one-time registration key with embedded tenant_id.
Format: discord_<url_encoded_tenant_id>.<random_token>
Follows the same pattern as API keys for consistency.
"""
encoded_tenant = quote(tenant_id)
random_token = secrets.token_urlsafe(16)
logger.info(f"Generated Discord registration key for tenant {tenant_id}")
return f"{REGISTRATION_KEY_PREFIX}{encoded_tenant}.{random_token}"
def parse_discord_registration_key(key: str) -> str | None:
"""Parse registration key to extract tenant_id.
Returns tenant_id or None if invalid format.
"""
if not key.startswith(REGISTRATION_KEY_PREFIX):
return None
try:
key_body = key.removeprefix(REGISTRATION_KEY_PREFIX)
parts = key_body.split(".", 1)
if len(parts) != 2:
return None
encoded_tenant = parts[0]
tenant_id = unquote(encoded_tenant)
return tenant_id
except Exception:
return None

View File

@@ -410,26 +410,20 @@ def list_llm_provider_basics(
all_providers = fetch_existing_llm_providers(db_session)
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
is_admin = user and user.role == UserRole.ADMIN
is_admin = user is not None and user.role == UserRole.ADMIN
accessible_providers = []
for provider in all_providers:
# Include all public providers
if provider.is_public:
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
continue
# Include restricted providers user has access to via groups
if is_admin:
# Admins see all providers
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif provider.groups:
# User must be in at least one of the provider's groups
if user_group_ids.intersection({g.id for g in provider.groups}):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif not provider.personas:
# No restrictions = accessible
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes all public providers
# - Includes providers user can access via group membership
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
end_time = datetime.now(timezone.utc)

View File

@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import remove_chat_message_feedback
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
@@ -266,7 +267,35 @@ def get_chat_session(
include_deleted=include_deleted,
)
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
try:
# If we failed to get a chat session, try to retrieve the session with
# less restrictive filters in order to identify what exactly mismatched
# so we can bubble up an accurate error code andmessage.
existing_chat_session = get_chat_session_by_id(
chat_session_id=session_id,
user_id=None,
db_session=db_session,
is_shared=False,
include_deleted=True,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
if not include_deleted and existing_chat_session.deleted:
raise HTTPException(status_code=404, detail="Chat session has been deleted")
if is_shared:
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
raise HTTPException(
status_code=403, detail="Chat session is not shared"
)
elif user_id is not None and existing_chat_session.user_id not in (
user_id,
None,
):
raise HTTPException(status_code=403, detail="Access denied")
raise HTTPException(status_code=404, detail="Chat session not found")
# for chat-seeding: if the session is unassigned, assign it now. This is done here
# to avoid another back and forth between FE -> BE before starting the first

View File

@@ -35,6 +35,8 @@ class MessageOrigin(str, Enum):
CHROME_EXTENSION = "chrome_extension"
API = "api"
SLACKBOT = "slackbot"
WIDGET = "widget"
DISCORDBOT = "discordbot"
UNKNOWN = "unknown"
UNSET = "unset"

View File

@@ -23,6 +23,9 @@ from onyx.server.settings.models import UserSettings
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
logger = setup_logger()
@@ -37,6 +40,11 @@ def admin_put_settings(
store_settings(settings)
def apply_license_status_to_settings(settings: Settings) -> Settings:
"""MIT version: no-op, returns settings unchanged."""
return settings
@basic_router.get("")
def fetch_settings(
user: User | None = Depends(current_user),
@@ -53,6 +61,13 @@ def fetch_settings(
except KvKeyNotFoundError:
needs_reindexing = False
apply_fn = fetch_versioned_implementation_with_fallback(
"onyx.server.settings.api",
"apply_license_status_to_settings",
apply_license_status_to_settings,
)
general_settings = apply_fn(general_settings)
return UserSettings(
**general_settings.model_dump(),
notifications=settings_notifications,

View File

@@ -80,6 +80,8 @@ class ToolResponse(BaseModel):
# | WebContentResponse
# This comes from custom tools, tool result needs to be saved
| CustomToolCallSummary
# If the rich response is a string, this is what's saved to the tool call in the DB
| str
| None # If nothing needs to be persisted outside of the string value passed to the LLM
)
# This is the final string that needs to be wrapped in a tool call response message and concatenated to the history

View File

@@ -0,0 +1,135 @@
"""
Memory Tool for storing user-specific information.
This tool allows the LLM to save memories about the user for future conversations.
The memories are passed in via override_kwargs which contains the current list of
memories that exist for the user.
"""
from typing import Any
from typing import cast
from pydantic import BaseModel
from typing_extensions import override
from onyx.chat.emitter import Emitter
from onyx.llm.interfaces import LLM
from onyx.secondary_llm_flows.memory_update import process_memory_update
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.interface import Tool
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tools.models import ToolResponse
from onyx.utils.logger import setup_logger
logger = setup_logger()
MEMORY_FIELD = "memory"
class MemoryToolOverrideKwargs(BaseModel):
    """Runtime context supplied to MemoryTool outside the LLM-facing arguments.

    Team information and user preferences are deliberately left out: they are
    unlikely to shape how a memory is written. The user's identity fields do
    matter, so the LLM can produce e.g. "Dave prefers light mode." instead of
    the generic "User prefers light mode."
    """

    # Identity of the user the memory belongs to (each may be absent).
    user_name: str | None
    user_email: str | None
    user_role: str | None
    # Memories already stored for this user.
    existing_memories: list[str]
    # Recent conversation turns giving the proposed memory its context.
    chat_history: list[ChatMinimalTextMessage]
class MemoryTool(Tool[MemoryToolOverrideKwargs]):
    """Tool that lets the LLM record a memory about the current user.

    The tool performs no persistence itself: ``run`` delegates to
    ``process_memory_update`` to reconcile the proposed text against the
    user's existing memories and hands the result back for the caller to save.
    """

    NAME = "add_memory"
    DISPLAY_NAME = "Add Memory"
    DESCRIPTION = "Save memories about the user for future conversations."

    def __init__(
        self,
        tool_id: int,
        emitter: Emitter,
        llm: LLM,
    ) -> None:
        """Keep the tool's DB id and the LLM used for memory reconciliation."""
        super().__init__(emitter=emitter)
        self._id = tool_id
        self.llm = llm

    @property
    def id(self) -> int:
        return self._id

    @property
    def name(self) -> str:
        return self.NAME

    @property
    def description(self) -> str:
        return self.DESCRIPTION

    @property
    def display_name(self) -> str:
        return self.DISPLAY_NAME

    @override
    def tool_definition(self) -> dict:
        """OpenAI-style function schema with a single required string field."""
        memory_schema = {
            "type": "string",
            "description": (
                "The text of the memory to add or update. "
                "Should be a concise, standalone statement that "
                "captures the key information. For example: "
                "'User prefers dark mode' or 'User's favorite frontend framework is React'."
            ),
        }
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {MEMORY_FIELD: memory_schema},
                    "required": [MEMORY_FIELD],
                },
            },
        }

    @override
    def emit_start(self, placement: Placement) -> None:
        # TODO
        pass

    @override
    def run(
        self,
        placement: Placement,
        override_kwargs: MemoryToolOverrideKwargs,
        **llm_kwargs: Any,
    ) -> ToolResponse:
        """Reconcile the LLM-proposed memory against the user's existing ones.

        Returns a ToolResponse whose rich_response carries the final memory
        text; writing it to the DB is the caller's responsibility.
        """
        proposed_memory = cast(str, llm_kwargs[MEMORY_FIELD])

        # Decide whether this is a new memory or a rewrite of an existing one.
        # index_to_replace is produced but not consumed here.
        memory_text, index_to_replace = process_memory_update(
            new_memory=proposed_memory,
            existing_memories=override_kwargs.existing_memories,
            chat_history=override_kwargs.chat_history,
            llm=self.llm,
            user_name=override_kwargs.user_name,
            user_email=override_kwargs.user_email,
            user_role=override_kwargs.user_role,
        )

        # TODO: the data should be returned and processed outside of the tool.
        # The caller persists the memory to the DB for future conversations;
        # this tool just returns the memory to be saved.
        logger.info(f"New memory to be added: {memory_text}")
        return ToolResponse(
            rich_response=memory_text,
            llm_facing_response=f"New memory added: {memory_text}",
        )

View File

@@ -17,6 +17,8 @@ from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.models import WebSearchToolOverrideKwargs
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -297,6 +299,7 @@ def run_tool_calls(
SearchToolOverrideKwargs
| WebSearchToolOverrideKwargs
| OpenURLToolOverrideKwargs
| MemoryToolOverrideKwargs
| None
) = None
@@ -330,6 +333,9 @@ def run_tool_calls(
)
starting_citation_num += 100
elif isinstance(tool, MemoryTool):
raise NotImplementedError("MemoryTool is not implemented")
tool_run_params.append((tool, tool_call, override_kwargs))
# Run all tools in parallel

View File

@@ -21,6 +21,28 @@ ESCAPE_SEQUENCE_RE = re.compile(
re.UNICODE | re.VERBOSE,
)
_INITIAL_FILTER = re.compile(
"["
"\U0000fff0-\U0000ffff" # Specials
"\U0001f000-\U0001f9ff" # Emoticons
"\U00002000-\U0000206f" # General Punctuation
"\U00002190-\U000021ff" # Arrows
"\U00002700-\U000027bf" # Dingbats
"]+",
flags=re.UNICODE,
)
# Regex to match invalid Unicode characters that cause UTF-8 encoding errors:
# - \x00-\x08: Control characters (except tab \x09)
# - \x0b-\x0c: Vertical tab and form feed
# - \x0e-\x1f: More control characters (except newline \x0a, carriage return \x0d)
# - \ud800-\udfff: Surrogate pairs (invalid when unpaired, causes "surrogates not allowed" errors)
# - \ufdd0-\ufdef: Non-characters
# - \ufffe-\uffff: Non-characters
_INVALID_UNICODE_CHARS_RE = re.compile(
"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufdd0-\ufdef\ufffe\uffff]"
)
def decode_escapes(s: str) -> str:
def decode_match(match: re.Match) -> str:
@@ -76,27 +98,98 @@ def escape_quotes(original_json_str: str) -> str:
return "".join(result)
def extract_embedded_json(s: str) -> dict:
first_brace_index = s.find("{")
last_brace_index = s.rfind("}")
def find_all_json_objects(text: str) -> list[dict]:
    """Locate every JSON object embedded in ``text`` via balanced-brace scanning.

    Each '{' is treated as a potential object start; the scanner walks forward
    counting brace depth until the matching '}' and tries to parse the enclosed
    substring as JSON. Because scanning resumes one character after every
    opening brace, objects nested inside other objects are reported as well.

    Use case: parsing LLM output that may contain multiple JSON objects, or
    serving layers that emit function calls in non-standard formats
    (e.g. OpenAI's function.open_url style).

    Args:
        text: The text to search for JSON objects.

    Returns:
        All substrings that parsed as JSON dicts, in order of their opening
        brace. Non-dict JSON values are ignored.
    """
    found: list[dict] = []
    for start, ch in enumerate(text):
        if ch != "{":
            continue
        depth = 0
        for end in range(start, len(text)):
            if text[end] == "{":
                depth += 1
            elif text[end] == "}":
                depth -= 1
                if depth == 0:
                    # Balanced span located; keep it only if it is valid JSON
                    # and the parsed value is a dict.
                    try:
                        parsed = json.loads(text[start : end + 1])
                    except json.JSONDecodeError:
                        pass
                    else:
                        if isinstance(parsed, dict):
                            found.append(parsed)
                    break
    return found
def parse_llm_json_response(content: str) -> dict | None:
"""Parse a single JSON object from LLM output, handling markdown code blocks.
Designed for LLM responses that typically contain exactly one JSON object,
possibly wrapped in markdown formatting.
Tries extraction in order:
1. JSON inside markdown code block (```json ... ``` or ``` ... ```)
2. Entire content as raw JSON
3. First '{' to last '}' in content (greedy match)
Args:
content: The LLM response text to parse.
Returns:
The parsed JSON dict if found, None otherwise.
"""
# Try to find JSON in markdown code block first
# Use greedy .* (not .*?) to match nested objects correctly within code block bounds
json_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", content, re.DOTALL)
if json_match:
try:
return json.loads(escape_quotes(json_str), strict=False)
except json.JSONDecodeError as e:
raise ValueError("Failed to parse JSON, even after escaping quotes") from e
result = json.loads(json_match.group(1))
if isinstance(result, dict):
return result
except json.JSONDecodeError:
pass
# Try to parse the entire content as JSON
try:
result = json.loads(content)
if isinstance(result, dict):
return result
except json.JSONDecodeError:
pass
def clean_up_code_blocks(model_out_raw: str) -> str:
return model_out_raw.strip().strip("```").strip().replace("\\xa0", "")
# Try to find any JSON object in the content
json_match = re.search(r"\{.*\}", content, re.DOTALL)
if json_match:
try:
result = json.loads(json_match.group(0))
if isinstance(result, dict):
return result
except json.JSONDecodeError:
pass
return None
def clean_model_quote(quote: str, trim_length: int) -> str:
@@ -126,18 +219,6 @@ def shared_precompare_cleanup(text: str) -> str:
return text
_INITIAL_FILTER = re.compile(
"["
"\U0000fff0-\U0000ffff" # Specials
"\U0001f000-\U0001f9ff" # Emoticons
"\U00002000-\U0000206f" # General Punctuation
"\U00002190-\U000021ff" # Arrows
"\U00002700-\U000027bf" # Dingbats
"]+",
flags=re.UNICODE,
)
def clean_text(text: str) -> str:
# Remove specific Unicode ranges that might cause issues
cleaned = _INITIAL_FILTER.sub("", text)
@@ -167,18 +248,6 @@ def remove_markdown_image_references(text: str) -> str:
return re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", text)
# Regex to match invalid Unicode characters that cause UTF-8 encoding errors:
# - \x00-\x08: Control characters (except tab \x09)
# - \x0b-\x0c: Vertical tab and form feed
# - \x0e-\x1f: More control characters (except newline \x0a, carriage return \x0d)
# - \ud800-\udfff: Surrogate pairs (invalid when unpaired, causes "surrogates not allowed" errors)
# - \ufdd0-\ufdef: Non-characters
# - \ufffe-\uffff: Non-characters
_INVALID_UNICODE_CHARS_RE = re.compile(
"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufdd0-\ufdef\ufffe\uffff]"
)
def remove_invalid_unicode_chars(text: str) -> str:
"""Remove Unicode characters that are invalid in UTF-8 or cause encoding issues.

View File

@@ -573,7 +573,7 @@ mcp==1.25.0
# onyx
mdurl==0.1.2
# via markdown-it-py
mistune==0.8.4
mistune==3.2.0
# via onyx
more-itertools==10.8.0
# via

View File

@@ -11,6 +11,7 @@ aiohappyeyeballs==2.6.1
aiohttp==3.13.3
# via
# aiobotocore
# discord-py
# litellm
# voyageai
aioitertools==0.13.0
@@ -94,6 +95,8 @@ decorator==5.2.1
# via
# ipython
# retry
discord-py==2.4.0
# via onyx
distlib==0.4.0
# via virtualenv
distro==1.9.0
@@ -295,7 +298,7 @@ numpy==2.4.1
# pandas-stubs
# shapely
# voyageai
onyx-devtools==0.3.2
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via

View File

@@ -11,6 +11,7 @@ aiohappyeyeballs==2.6.1
aiohttp==3.13.3
# via
# aiobotocore
# discord-py
# litellm
# voyageai
aioitertools==0.13.0
@@ -67,6 +68,8 @@ colorama==0.4.6 ; sys_platform == 'win32'
# tqdm
decorator==5.2.1
# via retry
discord-py==2.4.0
# via onyx
distro==1.9.0
# via openai
docstring-parser==0.17.0

View File

@@ -13,6 +13,7 @@ aiohappyeyeballs==2.6.1
aiohttp==3.13.3
# via
# aiobotocore
# discord-py
# litellm
# voyageai
aioitertools==0.13.0
@@ -83,6 +84,8 @@ colorama==0.4.6 ; sys_platform == 'win32'
# tqdm
decorator==5.2.1
# via retry
discord-py==2.4.0
# via onyx
distro==1.9.0
# via openai
docstring-parser==0.17.0

View File

@@ -191,6 +191,18 @@ autorestart=true
startretries=5
startsecs=60
# Listens for Discord messages and responds with answers
# for all guilds/channels that the OnyxBot has been added to.
# If not configured, will continue to probe every 3 minutes for a Discord bot token.
[program:discord_bot]
command=python onyx/onyxbot/discord/client.py
stdout_logfile=/var/log/discord_bot.log
stdout_logfile_maxbytes=16MB
redirect_stderr=true
autorestart=true
startretries=5
startsecs=60
# Pushes all logs from the above programs to stdout
# No log rotation here, since it's stdout it's handled by the Docker container logging
[program:log-redirect-handler]
@@ -206,6 +218,7 @@ command=tail -qF
/var/log/celery_worker_user_file_processing.log
/var/log/celery_worker_docfetching.log
/var/log/slack_bot.log
/var/log/discord_bot.log
/var/log/supervisord_watchdog_celery_beat.log
/var/log/mcp_server.log
/var/log/mcp_server.err.log

View File

@@ -0,0 +1,281 @@
"""
External dependency unit tests for user file processing queue protections.
Verifies that the three mechanisms added to check_user_file_processing work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in PROCESSING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that process_single_user_file clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_user_file_processing,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file,
)
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
    """Insert and return a UserFile row in PROCESSING status.

    file_id/name carry random suffixes so repeated test runs do not collide.
    """
    user_file = UserFile(
        id=uuid4(),
        user_id=user_id,
        file_id=f"test_file_{uuid4().hex[:8]}",
        name=f"test_{uuid4().hex[:8]}.txt",
        file_type="text/plain",
        status=UserFileStatus.PROCESSING,
    )
    db_session.add(user_file)
    db_session.commit()
    # Re-load attributes (e.g. any server-side defaults) after the commit.
    db_session.refresh(user_file)
    return user_file
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestQueueDepthBackpressure:
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
def test_no_tasks_enqueued_when_queue_over_limit(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When the queue depth exceeds the limit the beat cycle is skipped."""
user = create_test_user(db_session, "bp_user")
_create_processing_user_file(db_session, user.id)
mock_app = MagicMock()
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
mock_app.send_task.assert_not_called()
class TestPerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "guard_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "guard_set_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
f"Guard key TTL {ttl}s is outside the expected range "
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
)
finally:
redis_client.delete(guard_key)
class TestTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "expires_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires")
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), (
"Task must be submitted with the correct expires value to prevent "
"stale task accumulation"
)
finally:
redis_client.delete(guard_key)
class TestWorkerClearsGuardKey:
"""process_single_user_file removes the guard key when it picks up a task."""
def test_guard_key_deleted_on_pickup(
self,
tenant_context: None, # noqa: ARG002
) -> None:
"""The guard key is deleted before the worker does any real work.
We simulate an already-locked file so process_single_user_file returns
early — but, crucially, only after the guard key has been deleted.
"""
user_file_id = str(uuid4())
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(user_file_id)
# Simulate the guard key set when the beat enqueued the task
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
# Hold the per-file processing lock so the worker exits early without
# touching the database or file store.
lock_key = _user_file_lock_key(user_file_id)
processing_lock = redis_client.lock(lock_key, timeout=10)
acquired = processing_lock.acquire(blocking=False)
assert acquired, "Should be able to acquire the processing lock for this test"
try:
process_single_user_file.run(
user_file_id=user_file_id,
tenant_id=TEST_TENANT_ID,
)
finally:
if processing_lock.owned():
processing_lock.release()
assert not redis_client.exists(
guard_key
), "Guard key should be deleted when the worker picks up the task"

View File

@@ -90,12 +90,12 @@ _EXPECTED_CONFLUENCE_GROUPS = [
),
ExternalUserGroupSet(
id="bitbucket-users-onyxai",
user_emails={"oauth@onyx.app"},
user_emails={"founders@onyx.app", "oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="bitbucket-admins-onyxai",
user_emails={"oauth@onyx.app"},
user_emails={"founders@onyx.app", "oauth@onyx.app"},
gives_anyone_access=False,
),
ExternalUserGroupSet(

View File

@@ -77,6 +77,16 @@ _EXPECTED_JIRA_GROUPS = [
},
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="bitbucket-admins-onyxai",
user_emails={"founders@onyx.app"}, # no Oauth, we skip "app" account in jira
gives_anyone_access=False,
),
ExternalUserGroupSet(
id="bitbucket-users-onyxai",
user_emails={"founders@onyx.app"}, # no Oauth, we skip "app" account in jira
gives_anyone_access=False,
),
]

View File

@@ -0,0 +1,162 @@
"""Fixtures for Discord bot external dependency tests."""
from collections.abc import Generator
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import discord
import pytest
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
TEST_TENANT_ID: str = "public"
@pytest.fixture(scope="function")
def db_session() -> Generator[Session, None, None]:
"""Create a database session for testing."""
SqlEngine.init_engine(pool_size=10, max_overflow=5)
with get_session_with_current_tenant() as session:
yield session
@pytest.fixture(scope="function")
def tenant_context() -> Generator[None, None, None]:
"""Set up tenant context for testing."""
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture
def mock_cache_manager() -> MagicMock:
"""Mock DiscordCacheManager."""
cache = MagicMock()
cache.get_tenant.return_value = TEST_TENANT_ID
cache.get_api_key.return_value = "test_api_key"
cache.refresh_all = AsyncMock()
cache.refresh_guild = AsyncMock()
cache.is_initialized = True
return cache
@pytest.fixture
def mock_api_client() -> MagicMock:
"""Mock OnyxAPIClient."""
client = MagicMock()
client.initialize = AsyncMock()
client.close = AsyncMock()
client.is_initialized = True
# Mock successful response
mock_response = MagicMock()
mock_response.answer = "Test response from bot"
mock_response.citation_info = None
mock_response.top_documents = None
mock_response.error_msg = None
client.send_chat_message = AsyncMock(return_value=mock_response)
client.health_check = AsyncMock(return_value=True)
return client
@pytest.fixture
def mock_discord_guild() -> MagicMock:
"""Mock Discord guild with channels."""
guild = MagicMock(spec=discord.Guild)
guild.id = 123456789
guild.name = "Test Server"
guild.default_role = MagicMock()
# Create some mock channels
text_channel = MagicMock(spec=discord.TextChannel)
text_channel.id = 111111111
text_channel.name = "general"
text_channel.type = discord.ChannelType.text
perms = MagicMock()
perms.view_channel = True
text_channel.permissions_for.return_value = perms
forum_channel = MagicMock(spec=discord.ForumChannel)
forum_channel.id = 222222222
forum_channel.name = "forum"
forum_channel.type = discord.ChannelType.forum
forum_channel.permissions_for.return_value = perms
private_channel = MagicMock(spec=discord.TextChannel)
private_channel.id = 333333333
private_channel.name = "private"
private_channel.type = discord.ChannelType.text
private_perms = MagicMock()
private_perms.view_channel = False
private_channel.permissions_for.return_value = private_perms
guild.channels = [text_channel, forum_channel, private_channel]
guild.text_channels = [text_channel, private_channel]
guild.forum_channels = [forum_channel]
return guild
@pytest.fixture
def mock_discord_message(mock_discord_guild: MagicMock) -> MagicMock:
"""Mock Discord message for testing."""
msg = MagicMock(spec=discord.Message)
msg.id = 555555555
msg.author = MagicMock(spec=discord.Member)
msg.author.id = 444444444
msg.author.bot = False
msg.author.display_name = "TestUser"
msg.author.guild_permissions = MagicMock()
msg.author.guild_permissions.administrator = True
msg.author.guild_permissions.manage_guild = True
msg.content = "Hello bot"
msg.guild = mock_discord_guild
msg.channel = MagicMock()
msg.channel.id = 111111111
msg.channel.name = "general"
msg.channel.send = AsyncMock()
msg.type = discord.MessageType.default
msg.mentions = []
msg.role_mentions = []
msg.channel_mentions = []
msg.reference = None
msg.add_reaction = AsyncMock()
msg.remove_reaction = AsyncMock()
msg.reply = AsyncMock()
msg.create_thread = AsyncMock()
return msg
@pytest.fixture
def mock_bot_user() -> MagicMock:
"""Mock Discord bot user."""
user = MagicMock(spec=discord.ClientUser)
user.id = 987654321
user.display_name = "OnyxBot"
user.bot = True
return user
@pytest.fixture
def mock_discord_bot(
mock_cache_manager: MagicMock,
mock_api_client: MagicMock,
mock_bot_user: MagicMock,
) -> MagicMock:
"""Mock OnyxDiscordClient."""
bot = MagicMock()
bot.user = mock_bot_user
bot.cache = mock_cache_manager
bot.api_client = mock_api_client
bot.ready = True
bot.loop = MagicMock()
bot.is_closed.return_value = False
bot.guilds = []
return bot

View File

@@ -0,0 +1,616 @@
"""Tests for Discord bot event handling with mocked Discord API.
These tests mock the Discord API to test event handling logic.
"""
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import discord
import pytest
from onyx.onyxbot.discord.handle_commands import get_text_channels
from onyx.onyxbot.discord.handle_commands import handle_dm
from onyx.onyxbot.discord.handle_commands import handle_registration_command
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
from onyx.onyxbot.discord.handle_message import process_chat_message
from onyx.onyxbot.discord.handle_message import send_error_response
from onyx.onyxbot.discord.handle_message import send_response
class TestGuildRegistrationCommand:
"""Tests for !register command handling."""
@pytest.mark.asyncio
async def test_register_guild_success(
self,
mock_discord_message: MagicMock,
mock_cache_manager: MagicMock,
) -> None:
"""Valid registration key with admin perms succeeds."""
mock_discord_message.content = "!register discord_public.valid_token"
with (
patch(
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
return_value="public",
),
patch(
"onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
) as mock_session,
patch(
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_registration_key"
) as mock_get_config,
patch("onyx.onyxbot.discord.handle_commands.bulk_create_channel_configs"),
):
mock_db = MagicMock()
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
mock_session.return_value.__exit__ = MagicMock()
mock_config = MagicMock()
mock_config.id = 1
mock_config.guild_id = None # Not yet registered
mock_get_config.return_value = mock_config
mock_cache_manager.get_tenant.return_value = None # Not in cache yet
result = await handle_registration_command(
mock_discord_message, mock_cache_manager
)
assert result is True
mock_discord_message.reply.assert_called()
# Check that success message was sent
call_args = mock_discord_message.reply.call_args
assert "Successfully registered" in str(call_args)
@pytest.mark.asyncio
async def test_register_invalid_key_format(
self,
mock_discord_message: MagicMock,
mock_cache_manager: MagicMock,
) -> None:
"""Malformed key DMs user and deletes message."""
mock_discord_message.content = "!register abc" # Malformed
with patch(
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
return_value=None, # Invalid format
):
result = await handle_registration_command(
mock_discord_message, mock_cache_manager
)
assert result is True
# On failure: DM the author and delete the message
mock_discord_message.author.send.assert_called()
call_args = mock_discord_message.author.send.call_args
assert "Invalid" in str(call_args)
mock_discord_message.delete.assert_called()
@pytest.mark.asyncio
async def test_register_key_not_found(
self,
mock_discord_message: MagicMock,
mock_cache_manager: MagicMock,
) -> None:
"""Key not in database DMs user and deletes message."""
mock_discord_message.content = "!register discord_public.notexist"
with (
patch(
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
return_value="public",
),
patch(
"onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
) as mock_session,
patch(
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_registration_key",
return_value=None, # Not found
),
):
mock_db = MagicMock()
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
# Must return False so exceptions are not suppressed
mock_session.return_value.__exit__ = MagicMock(return_value=False)
mock_cache_manager.get_tenant.return_value = None
result = await handle_registration_command(
mock_discord_message, mock_cache_manager
)
assert result is True
# On failure: DM the author and delete the message
mock_discord_message.author.send.assert_called()
call_args = mock_discord_message.author.send.call_args
assert "not found" in str(call_args).lower()
mock_discord_message.delete.assert_called()
@pytest.mark.asyncio
async def test_register_key_already_used(
    self,
    mock_discord_message: MagicMock,
    mock_cache_manager: MagicMock,
) -> None:
    """Previously used key DMs user and deletes message."""
    mock_discord_message.content = "!register discord_public.used_key"
    with (
        patch(
            "onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
            return_value="public",
        ),
        patch(
            "onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
        ) as mock_session,
        patch(
            "onyx.onyxbot.discord.handle_commands.get_guild_config_by_registration_key"
        ) as mock_get_config,
    ):
        # Fake DB session usable as a context manager.
        mock_db = MagicMock()
        mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
        # Must return False so exceptions are not suppressed
        mock_session.return_value.__exit__ = MagicMock(return_value=False)
        # Config exists but already has a guild_id bound to it.
        mock_config = MagicMock()
        mock_config.guild_id = 999999  # Already registered!
        mock_get_config.return_value = mock_config
        mock_cache_manager.get_tenant.return_value = None
        result = await handle_registration_command(
            mock_discord_message, mock_cache_manager
        )
        assert result is True
        # On failure: DM the author and delete the message
        mock_discord_message.author.send.assert_called()
        call_args = mock_discord_message.author.send.call_args
        assert "already" in str(call_args).lower()
        mock_discord_message.delete.assert_called()
@pytest.mark.asyncio
async def test_register_guild_already_registered(
    self,
    mock_discord_message: MagicMock,
    mock_cache_manager: MagicMock,
) -> None:
    """Guild already in cache DMs user and deletes message."""
    mock_discord_message.content = "!register discord_public.valid_token"
    with patch(
        "onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
        return_value="public",
    ):
        # Guild already in cache
        mock_cache_manager.get_tenant.return_value = "existing_tenant"
        result = await handle_registration_command(
            mock_discord_message, mock_cache_manager
        )
        assert result is True
        # On failure: DM the author and delete the message
        mock_discord_message.author.send.assert_called()
        call_args = mock_discord_message.author.send.call_args
        assert "already registered" in str(call_args).lower()
        mock_discord_message.delete.assert_called()
@pytest.mark.asyncio
async def test_register_no_permission(
    self,
    mock_discord_message: MagicMock,
    mock_cache_manager: MagicMock,
) -> None:
    """User without admin perms gets DM and message deleted."""
    mock_discord_message.content = "!register discord_public.valid_token"
    # Strip both permissions the handler accepts for registration.
    mock_discord_message.author.guild_permissions.administrator = False
    mock_discord_message.author.guild_permissions.manage_guild = False
    result = await handle_registration_command(
        mock_discord_message, mock_cache_manager
    )
    assert result is True
    # On failure: DM the author and delete the message
    mock_discord_message.author.send.assert_called()
    call_args = mock_discord_message.author.send.call_args
    assert "permission" in str(call_args).lower()
    mock_discord_message.delete.assert_called()
@pytest.mark.asyncio
async def test_register_in_dm(
    self,
    mock_cache_manager: MagicMock,
) -> None:
    """Registration in DM sends DM and returns True."""
    # Build a minimal DM message: no guild attached.
    msg = MagicMock(spec=discord.Message)
    msg.guild = None  # DM
    msg.content = "!register discord_public.token"
    msg.author = MagicMock()
    msg.author.send = AsyncMock()
    result = await handle_registration_command(msg, mock_cache_manager)
    assert result is True
    # The bot should tell the user to run the command in a server.
    msg.author.send.assert_called()
    call_args = msg.author.send.call_args
    assert "server" in str(call_args).lower()
@pytest.mark.asyncio
async def test_register_syncs_forum_channels(
    self,
    # NOTE(review): mock_discord_message appears unused here — confirm the
    # fixture is needed (e.g. for setup side effects) or drop it.
    mock_discord_message: MagicMock,
    mock_discord_guild: MagicMock,
) -> None:
    """Forum channels are included in sync."""
    channels = get_text_channels(mock_discord_guild)
    channel_types = [c.channel_type for c in channels]
    assert "forum" in channel_types
@pytest.mark.asyncio
async def test_register_private_channel_detection(
    self,
    # NOTE(review): mock_discord_message appears unused here — confirm it
    # is needed or drop it.
    mock_discord_message: MagicMock,
    mock_discord_guild: MagicMock,
) -> None:
    """Private channels are marked correctly."""
    channels = get_text_channels(mock_discord_guild)
    private_channels = [c for c in channels if c.is_private]
    assert len(private_channels) >= 1
class TestSyncChannelsCommand:
    """Tests for !sync-channels command handling."""

    @pytest.mark.asyncio
    async def test_sync_channels_adds_new(
        self,
        mock_discord_message: MagicMock,
        mock_discord_bot: MagicMock,
    ) -> None:
        """New channel in Discord creates channel config."""
        mock_discord_message.content = "!sync-channels"
        with (
            patch(
                "onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
            ) as mock_session,
            patch(
                "onyx.onyxbot.discord.handle_commands.get_guild_config_by_discord_id"
            ) as mock_get_guild,
            patch(
                "onyx.onyxbot.discord.handle_commands.get_guild_config_by_internal_id"
            ) as mock_get_guild_internal,
            patch(
                "onyx.onyxbot.discord.handle_commands.sync_channel_configs"
            ) as mock_sync,
        ):
            # Fake DB session usable as a context manager.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            # Must return False so exceptions are not suppressed (consistent
            # with the other session mocks in this file).
            mock_session.return_value.__exit__ = MagicMock(return_value=False)
            mock_config = MagicMock()
            mock_config.id = 1
            mock_config.guild_id = 123456789
            mock_get_guild.return_value = mock_config
            mock_get_guild_internal.return_value = mock_config
            mock_sync.return_value = (1, 0, 0)  # 1 added, 0 removed, 0 updated
            mock_discord_bot.get_guild.return_value = mock_discord_message.guild
            result = await handle_sync_channels_command(
                mock_discord_message, "public", mock_discord_bot
            )
            assert result is True
            # Success is reported back in-channel via a reply.
            mock_discord_message.reply.assert_called()

    @pytest.mark.asyncio
    async def test_sync_channels_no_permission(
        self,
        mock_discord_message: MagicMock,
        mock_discord_bot: MagicMock,
    ) -> None:
        """User without admin perms gets DM and reaction."""
        mock_discord_message.content = "!sync-channels"
        mock_discord_message.author.guild_permissions.administrator = False
        mock_discord_message.author.guild_permissions.manage_guild = False
        result = await handle_sync_channels_command(
            mock_discord_message, "public", mock_discord_bot
        )
        assert result is True
        # On failure: DM the author and react with ❌
        mock_discord_message.author.send.assert_called()
        call_args = mock_discord_message.author.send.call_args
        assert "permission" in str(call_args).lower()
        mock_discord_message.add_reaction.assert_called_with("❌")

    @pytest.mark.asyncio
    async def test_sync_channels_unregistered_guild(
        self,
        mock_discord_message: MagicMock,
        mock_discord_bot: MagicMock,
    ) -> None:
        """Sync in unregistered guild gets DM and reaction."""
        mock_discord_message.content = "!sync-channels"
        # tenant_id is None = not registered
        result = await handle_sync_channels_command(
            mock_discord_message, None, mock_discord_bot
        )
        assert result is True
        # On failure: DM the author and react with ❌
        mock_discord_message.author.send.assert_called()
        call_args = mock_discord_message.author.send.call_args
        assert "not registered" in str(call_args).lower()
        mock_discord_message.add_reaction.assert_called_with("❌")
class TestMessageHandling:
    """Tests for message handling behavior."""

    @pytest.mark.asyncio
    async def test_message_adds_thinking_emoji(
        self,
        mock_discord_message: MagicMock,
        mock_api_client: MagicMock,
        mock_bot_user: MagicMock,
    ) -> None:
        """Thinking emoji is added during processing."""
        await process_chat_message(
            message=mock_discord_message,
            api_key="test_key",
            persona_id=None,
            thread_only_mode=False,
            api_client=mock_api_client,
            bot_user=mock_bot_user,
        )
        # The reaction signals the bot is working on a response.
        mock_discord_message.add_reaction.assert_called()

    @pytest.mark.asyncio
    async def test_message_removes_thinking_emoji(
        self,
        mock_discord_message: MagicMock,
        mock_api_client: MagicMock,
        mock_bot_user: MagicMock,
    ) -> None:
        """Thinking emoji is removed after response."""
        await process_chat_message(
            message=mock_discord_message,
            api_key="test_key",
            persona_id=None,
            thread_only_mode=False,
            api_client=mock_api_client,
            bot_user=mock_bot_user,
        )
        mock_discord_message.remove_reaction.assert_called()

    @pytest.mark.asyncio
    async def test_message_reaction_failure_non_blocking(
        self,
        mock_discord_message: MagicMock,
        mock_api_client: MagicMock,
        mock_bot_user: MagicMock,
    ) -> None:
        """add_reaction failure doesn't block processing."""
        mock_discord_message.add_reaction = AsyncMock(
            side_effect=discord.DiscordException("Cannot add reaction")
        )
        # Should not raise - just log warning and continue
        await process_chat_message(
            message=mock_discord_message,
            api_key="test_key",
            persona_id=None,
            thread_only_mode=False,
            api_client=mock_api_client,
            bot_user=mock_bot_user,
        )
        # Should still complete and send reply
        mock_discord_message.reply.assert_called()

    @pytest.mark.asyncio
    async def test_dm_response(self) -> None:
        """DM to bot sends redirect message."""
        # Minimal DM: the channel is a DMChannel.
        msg = MagicMock(spec=discord.Message)
        msg.channel = MagicMock(spec=discord.DMChannel)
        msg.channel.send = AsyncMock()
        await handle_dm(msg)
        msg.channel.send.assert_called_once()
        call_args = msg.channel.send.call_args
        assert "DM" in str(call_args) or "server" in str(call_args).lower()
class TestThreadCreationAndResponseRouting:
    """Tests for thread creation and response routing."""

    @pytest.mark.asyncio
    async def test_response_in_existing_thread(
        self,
        mock_bot_user: MagicMock,
    ) -> None:
        """Message in thread - response appended to thread."""
        thread = MagicMock(spec=discord.Thread)
        thread.send = AsyncMock()
        msg = MagicMock(spec=discord.Message)
        msg.channel = thread
        msg.reply = AsyncMock()
        msg.create_thread = AsyncMock()
        await send_response(msg, "Test response", thread_only_mode=False)
        # Should send to thread, not create new thread
        thread.send.assert_called()
        msg.create_thread.assert_not_called()

    @pytest.mark.asyncio
    async def test_response_creates_thread_thread_only_mode(
        self,
        mock_discord_message: MagicMock,
        mock_bot_user: MagicMock,
    ) -> None:
        """thread_only_mode=true creates new thread for response."""
        mock_thread = MagicMock()
        mock_thread.send = AsyncMock()
        mock_discord_message.create_thread = AsyncMock(return_value=mock_thread)
        # Make sure it's not a thread
        mock_discord_message.channel = MagicMock(spec=discord.TextChannel)
        await send_response(
            mock_discord_message, "Test response", thread_only_mode=True
        )
        mock_discord_message.create_thread.assert_called()
        mock_thread.send.assert_called()

    @pytest.mark.asyncio
    async def test_response_replies_inline(
        self,
        mock_discord_message: MagicMock,
        mock_bot_user: MagicMock,
    ) -> None:
        """thread_only_mode=false uses message.reply()."""
        # Make sure it's not a thread
        mock_discord_message.channel = MagicMock(spec=discord.TextChannel)
        await send_response(
            mock_discord_message, "Test response", thread_only_mode=False
        )
        mock_discord_message.reply.assert_called()

    @pytest.mark.asyncio
    async def test_thread_name_truncation(
        self,
        mock_bot_user: MagicMock,
    ) -> None:
        """Thread name is truncated to 100 chars."""
        msg = MagicMock(spec=discord.Message)
        msg.channel = MagicMock(spec=discord.TextChannel)
        msg.author = MagicMock()
        msg.author.display_name = "A" * 200  # Very long name
        mock_thread = MagicMock()
        mock_thread.send = AsyncMock()
        msg.create_thread = AsyncMock(return_value=mock_thread)
        await send_response(msg, "Test", thread_only_mode=True)
        call_args = msg.create_thread.call_args
        # Accept the thread name whether it was passed as kwarg or positional
        # call_args mapping.
        thread_name = call_args.kwargs.get("name") or call_args[1].get("name")
        # Discord's thread-name limit is 100 characters.
        assert len(thread_name) <= 100

    @pytest.mark.asyncio
    async def test_error_response_creates_thread(
        self,
        mock_discord_message: MagicMock,
        mock_bot_user: MagicMock,
    ) -> None:
        """Error response in channel creates thread."""
        mock_discord_message.channel = MagicMock(spec=discord.TextChannel)
        mock_thread = MagicMock()
        mock_thread.send = AsyncMock()
        mock_discord_message.create_thread = AsyncMock(return_value=mock_thread)
        await send_error_response(mock_discord_message, mock_bot_user)
        mock_discord_message.create_thread.assert_called()
class TestBotLifecycle:
    """Tests for bot lifecycle management."""

    @pytest.mark.asyncio
    async def test_setup_hook_initializes_cache(
        self,
        mock_cache_manager: MagicMock,
        mock_api_client: MagicMock,
    ) -> None:
        """setup_hook calls cache.refresh_all()."""
        from onyx.onyxbot.discord.client import OnyxDiscordClient

        with (
            # Bypass the real constructor so no Discord connection is made.
            patch.object(OnyxDiscordClient, "__init__", lambda self: None),
            patch(
                "onyx.onyxbot.discord.client.DiscordCacheManager",
                return_value=mock_cache_manager,
            ),
            patch(
                "onyx.onyxbot.discord.client.OnyxAPIClient",
                return_value=mock_api_client,
            ),
        ):
            bot = OnyxDiscordClient()
            # Manually wire the attributes the skipped __init__ would set.
            bot.cache = mock_cache_manager
            bot.api_client = mock_api_client
            bot.loop = MagicMock()
            bot.loop.create_task = MagicMock()
            await bot.setup_hook()
            mock_cache_manager.refresh_all.assert_called()

    @pytest.mark.asyncio
    async def test_setup_hook_initializes_api_client(
        self,
        mock_cache_manager: MagicMock,
        mock_api_client: MagicMock,
    ) -> None:
        """setup_hook calls api_client.initialize()."""
        from onyx.onyxbot.discord.client import OnyxDiscordClient

        with (patch.object(OnyxDiscordClient, "__init__", lambda self: None),):
            bot = OnyxDiscordClient()
            bot.cache = mock_cache_manager
            bot.api_client = mock_api_client
            bot.loop = MagicMock()
            bot.loop.create_task = MagicMock()
            await bot.setup_hook()
            mock_api_client.initialize.assert_called()

    @pytest.mark.asyncio
    async def test_close_closes_api_client(
        self,
        mock_cache_manager: MagicMock,
        mock_api_client: MagicMock,
    ) -> None:
        """close() calls api_client.close()."""
        from onyx.onyxbot.discord.client import OnyxDiscordClient

        with (
            patch.object(OnyxDiscordClient, "__init__", lambda self: None),
            patch.object(OnyxDiscordClient, "is_closed", return_value=True),
        ):
            bot = OnyxDiscordClient()
            bot.cache = mock_cache_manager
            bot.api_client = mock_api_client
            bot._cache_refresh_task = None
            bot.ready = True

            # Mock parent close
            async def mock_super_close() -> None:
                pass

            with patch("discord.ext.commands.Bot.close", mock_super_close):
                await bot.close()
            mock_api_client.close.assert_called()
            mock_cache_manager.clear.assert_called()

View File

@@ -476,8 +476,8 @@ class ChatSessionManager:
else GENERAL_HEADERS
),
)
# Chat session should return 400 if it doesn't exist
return response.status_code == 400
# Chat session should return 404 if it doesn't exist or is deleted
return response.status_code == 404
@staticmethod
def verify_soft_deleted(

View File

@@ -0,0 +1,310 @@
"""Manager for Discord bot API integration tests."""
import requests
from onyx.db.discord_bot import create_channel_config
from onyx.db.discord_bot import create_guild_config
from onyx.db.discord_bot import register_guild
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.utils import DiscordChannelView
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
from shared_configs.contextvars import get_current_tenant_id
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.test_models import DATestDiscordChannelConfig
from tests.integration.common_utils.test_models import DATestDiscordGuildConfig
from tests.integration.common_utils.test_models import DATestUser
DISCORD_BOT_API_URL = f"{API_SERVER_URL}/manage/admin/discord-bot"
class DiscordBotManager:
    """Manager for Discord bot API operations.

    Thin wrapper over the admin Discord-bot REST endpoints plus a few
    direct-DB helpers for seeding test data. All methods raise
    ``requests.HTTPError`` on non-2xx responses unless noted otherwise.
    """

    # === Bot Config ===
    @staticmethod
    def get_bot_config(
        user_performing_action: DATestUser,
    ) -> dict:
        """Get Discord bot config."""
        response = requests.get(
            url=f"{DISCORD_BOT_API_URL}/config",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def create_bot_config(
        bot_token: str,
        user_performing_action: DATestUser,
    ) -> dict:
        """Create Discord bot config."""
        response = requests.post(
            url=f"{DISCORD_BOT_API_URL}/config",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
            json={"bot_token": bot_token},
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def delete_bot_config(
        user_performing_action: DATestUser,
    ) -> dict:
        """Delete Discord bot config."""
        response = requests.delete(
            url=f"{DISCORD_BOT_API_URL}/config",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        return response.json()

    # === Guild Config ===
    @staticmethod
    def list_guilds(
        user_performing_action: DATestUser,
    ) -> list[dict]:
        """List all guild configs."""
        response = requests.get(
            url=f"{DISCORD_BOT_API_URL}/guilds",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def create_guild(
        user_performing_action: DATestUser,
    ) -> DATestDiscordGuildConfig:
        """Create a new guild config with registration key."""
        response = requests.post(
            url=f"{DISCORD_BOT_API_URL}/guilds",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        data = response.json()
        # Only id and registration_key are returned on creation.
        return DATestDiscordGuildConfig(
            id=data["id"],
            registration_key=data["registration_key"],
        )

    @staticmethod
    def get_guild(
        config_id: int,
        user_performing_action: DATestUser,
    ) -> dict:
        """Get a specific guild config."""
        response = requests.get(
            url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def update_guild(
        config_id: int,
        user_performing_action: DATestUser,
        enabled: bool | None = None,
        default_persona_id: int | None = None,
    ) -> dict:
        """Update a guild config.

        Fields left as None keep their current server-side value (the
        current config is fetched first because the PATCH body requires
        all fields).
        """
        # Fetch current guild config to get existing values
        current_guild = DiscordBotManager.get_guild(config_id, user_performing_action)
        # Build request body with required fields
        body: dict = {
            "enabled": enabled if enabled is not None else current_guild["enabled"],
            "default_persona_id": (
                default_persona_id
                if default_persona_id is not None
                else current_guild.get("default_persona_id")
            ),
        }
        response = requests.patch(
            url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
            json=body,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def delete_guild(
        config_id: int,
        user_performing_action: DATestUser,
    ) -> dict:
        """Delete a guild config."""
        response = requests.delete(
            url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        return response.json()

    # === Channel Config ===
    @staticmethod
    def list_channels(
        guild_config_id: int,
        user_performing_action: DATestUser,
    ) -> list[DATestDiscordChannelConfig]:
        """List all channel configs for a guild."""
        response = requests.get(
            url=f"{DISCORD_BOT_API_URL}/guilds/{guild_config_id}/channels",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        response.raise_for_status()
        return [DATestDiscordChannelConfig(**c) for c in response.json()]

    @staticmethod
    def update_channel(
        guild_config_id: int,
        channel_config_id: int,
        user_performing_action: DATestUser,
        enabled: bool = False,
        thread_only_mode: bool = False,
        require_bot_invocation: bool = True,
        persona_override_id: int | None = None,
    ) -> DATestDiscordChannelConfig:
        """Update a channel config.

        All fields are required by the API. Default values match the channel
        config defaults from create_channel_config.
        """
        body: dict = {
            "enabled": enabled,
            "thread_only_mode": thread_only_mode,
            "require_bot_invocation": require_bot_invocation,
            "persona_override_id": persona_override_id,
        }
        response = requests.patch(
            url=f"{DISCORD_BOT_API_URL}/guilds/{guild_config_id}/channels/{channel_config_id}",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
            json=body,
        )
        response.raise_for_status()
        return DATestDiscordChannelConfig(**response.json())

    # === Utility methods for testing ===
    @staticmethod
    def create_registered_guild_in_db(
        guild_id: int,
        guild_name: str,
    ) -> DATestDiscordGuildConfig:
        """Create a registered guild config directly in the database.

        This creates a guild that has already completed registration,
        with guild_id and guild_name set. Use this for testing channel
        endpoints which require a registered guild.
        """
        with get_session_with_current_tenant() as db_session:
            tenant_id = get_current_tenant_id()
            registration_key = generate_discord_registration_key(tenant_id)
            config = create_guild_config(db_session, registration_key)
            config = register_guild(db_session, config, guild_id, guild_name)
            db_session.commit()
            return DATestDiscordGuildConfig(
                id=config.id,
                registration_key=registration_key,
            )

    @staticmethod
    def get_guild_or_none(
        config_id: int,
        user_performing_action: DATestUser,
    ) -> dict | None:
        """Get a guild config, returning None if not found."""
        response = requests.get(
            url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        # 404 is an expected outcome here, not an error.
        if response.status_code == 404:
            return None
        response.raise_for_status()
        return response.json()

    @staticmethod
    def delete_guild_if_exists(
        config_id: int,
        user_performing_action: DATestUser,
    ) -> bool:
        """Delete a guild config if it exists. Returns True if deleted."""
        response = requests.delete(
            url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        if response.status_code == 404:
            return False
        response.raise_for_status()
        return True

    @staticmethod
    def delete_bot_config_if_exists(
        user_performing_action: DATestUser,
    ) -> bool:
        """Delete bot config if it exists. Returns True if deleted."""
        response = requests.delete(
            url=f"{DISCORD_BOT_API_URL}/config",
            headers=user_performing_action.headers,
            cookies=user_performing_action.cookies,
        )
        if response.status_code == 404:
            return False
        response.raise_for_status()
        return True

    @staticmethod
    def create_test_channel_in_db(
        guild_config_id: int,
        channel_id: int,
        channel_name: str,
        channel_type: str = "text",
        is_private: bool = False,
    ) -> DATestDiscordChannelConfig:
        """Create a test channel config directly in the database.

        This is needed because channels are normally synced from Discord,
        not created via API. For testing the channel API endpoints,
        we need to populate test data directly.
        """
        with get_session_with_current_tenant() as db_session:
            channel_view = DiscordChannelView(
                channel_id=channel_id,
                channel_name=channel_name,
                channel_type=channel_type,
                is_private=is_private,
            )
            config = create_channel_config(db_session, guild_config_id, channel_view)
            db_session.commit()
            return DATestDiscordChannelConfig(
                id=config.id,
                guild_config_id=config.guild_config_id,
                channel_id=config.channel_id,
                channel_name=config.channel_name,
                channel_type=config.channel_type,
                is_private=config.is_private,
                enabled=config.enabled,
                thread_only_mode=config.thread_only_mode,
                require_bot_invocation=config.require_bot_invocation,
                persona_override_id=config.persona_override_id,
            )

View File

@@ -273,3 +273,30 @@ class DATestTool(BaseModel):
description: str
display_name: str
in_code_tool_id: str | None
# Discord Bot Models
class DATestDiscordGuildConfig(BaseModel):
    """Discord guild config model for testing."""

    id: int
    registration_key: str | None = None  # Only present on creation
    guild_id: int | None = None  # Discord-side guild ID; None until registered
    guild_name: str | None = None  # None until registered
    enabled: bool = True
    default_persona_id: int | None = None
class DATestDiscordChannelConfig(BaseModel):
    """Discord channel config model for testing."""

    id: int
    guild_config_id: int  # FK to the owning guild config
    channel_id: int  # Discord-side channel ID
    channel_name: str
    channel_type: str  # e.g. "text" / "forum" as used by the tests
    is_private: bool
    # Defaults mirror create_channel_config's defaults.
    enabled: bool = False
    thread_only_mode: bool = False
    require_bot_invocation: bool = True
    persona_override_id: int | None = None

View File

@@ -0,0 +1,456 @@
"""Multi-tenant isolation tests for Discord bot.
These tests ensure tenant isolation and prevent data leakage between tenants.
Tests follow the multi-tenant integration test pattern using API requests.
"""
from unittest.mock import patch
from uuid import uuid4
import pytest
import requests
from onyx.configs.constants import AuthType
from onyx.db.discord_bot import get_guild_config_by_registration_key
from onyx.db.discord_bot import register_guild
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import UserRole
from onyx.onyxbot.discord.cache import DiscordCacheManager
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
from onyx.server.manage.discord_bot.utils import REGISTRATION_KEY_PREFIX
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
class TestBotConfigIsolationCloudMode:
    """Tests for bot config isolation in cloud mode."""

    def test_cannot_create_bot_config_in_cloud_mode(self) -> None:
        """Bot config creation is blocked in cloud mode."""
        # NOTE(review): patching onyx.configs.app_configs.AUTH_TYPE only
        # works if the api module reads it at call time — confirm it does
        # not capture AUTH_TYPE at import time.
        with patch("onyx.configs.app_configs.AUTH_TYPE", AuthType.CLOUD):
            from fastapi import HTTPException
            from onyx.server.manage.discord_bot.api import _check_bot_config_api_access

            with pytest.raises(HTTPException) as exc_info:
                _check_bot_config_api_access()
            assert exc_info.value.status_code == 403
            assert "Cloud" in str(exc_info.value.detail)

    def test_bot_token_from_env_only_in_cloud(self) -> None:
        """Bot token comes from env var in cloud mode, ignores DB."""
        from onyx.onyxbot.discord.utils import get_bot_token

        with (
            patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", "env_token"),
            patch("onyx.onyxbot.discord.utils.AUTH_TYPE", AuthType.CLOUD),
        ):
            result = get_bot_token()
            assert result == "env_token"
class TestGuildRegistrationIsolation:
    """Tests for guild registration isolation between tenants."""

    def test_guild_can_only_register_to_one_tenant(self) -> None:
        """Guild registered to tenant 1 cannot be registered to tenant 2."""
        cache = DiscordCacheManager()
        # Register guild to tenant 1 (directly via the internal mapping)
        cache._guild_tenants[123456789] = "tenant1"
        # Check if guild is already registered
        existing = cache.get_tenant(123456789)
        assert existing is not None
        assert existing == "tenant1"

    def test_registration_key_tenant_mismatch(self) -> None:
        """Key created in tenant 1 cannot be used in tenant 2 context."""
        key = generate_discord_registration_key("tenant1")
        # Parse the key to get tenant
        parsed_tenant = parse_discord_registration_key(key)
        assert parsed_tenant == "tenant1"
        assert parsed_tenant != "tenant2"

    def test_registration_key_encodes_correct_tenant(self) -> None:
        """Key format discord_<tenant_id>.<token> encodes correct tenant."""
        tenant_id = "my_tenant_123"
        key = generate_discord_registration_key(tenant_id)
        assert key.startswith(REGISTRATION_KEY_PREFIX)
        # The tenant id may be embedded raw or percent-encoded.
        assert "my_tenant_123" in key or "my%5Ftenant%5F123" in key
        parsed = parse_discord_registration_key(key)
        assert parsed == tenant_id
class TestGuildDataIsolation:
    """Tests for guild data isolation between tenants via API."""

    def test_tenant_cannot_see_other_tenant_guilds(
        self, reset_multitenant: None
    ) -> None:
        """Guilds created in tenant 1 are not visible from tenant 2.

        Creates guilds via API in tenant 1, then queries from tenant 2
        context to verify the guilds are not visible.
        """
        unique = uuid4().hex
        # Create admin user for tenant 1
        admin_user1: DATestUser = UserManager.create(
            email=f"discord_admin1+{unique}@example.com",
        )
        assert UserManager.is_role(admin_user1, UserRole.ADMIN)
        # Create admin user for tenant 2
        admin_user2: DATestUser = UserManager.create(
            email=f"discord_admin2+{unique}@example.com",
        )
        assert UserManager.is_role(admin_user2, UserRole.ADMIN)
        # Create a guild registration key in tenant 1
        response1 = requests.post(
            f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
            headers=admin_user1.headers,
        )
        # If Discord bot feature is not enabled, skip the test
        if response1.status_code == 404:
            pytest.skip("Discord bot feature not enabled")
        assert response1.ok, f"Failed to create guild in tenant 1: {response1.text}"
        guild1_data = response1.json()
        guild1_id = guild1_data["id"]
        try:
            # List guilds from tenant 1 - should see the guild
            list_response1 = requests.get(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
                headers=admin_user1.headers,
            )
            assert list_response1.ok
            tenant1_guilds = list_response1.json()
            tenant1_guild_ids = [g["id"] for g in tenant1_guilds]
            assert guild1_id in tenant1_guild_ids
            # List guilds from tenant 2 - should NOT see tenant 1's guild
            list_response2 = requests.get(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
                headers=admin_user2.headers,
            )
            assert list_response2.ok
            tenant2_guilds = list_response2.json()
            tenant2_guild_ids = [g["id"] for g in tenant2_guilds]
            assert guild1_id not in tenant2_guild_ids
        finally:
            # Cleanup - delete guild from tenant 1
            requests.delete(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
                headers=admin_user1.headers,
            )

    def test_guild_list_returns_only_own_tenant(self, reset_multitenant: None) -> None:
        """List guilds returns exactly the guilds for that tenant.

        Creates 1 guild in each tenant, registers them with different data,
        and verifies each tenant only sees their own guild.
        """
        unique = uuid4().hex
        # Create admin users for two tenants
        admin_user1: DATestUser = UserManager.create(
            email=f"discord_list1+{unique}@example.com",
        )
        admin_user2: DATestUser = UserManager.create(
            email=f"discord_list2+{unique}@example.com",
        )
        # Create 1 guild in tenant 1
        response1 = requests.post(
            f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
            headers=admin_user1.headers,
        )
        if response1.status_code == 404:
            pytest.skip("Discord bot feature not enabled")
        assert response1.ok, f"Failed to create guild in tenant 1: {response1.text}"
        guild1_data = response1.json()
        guild1_id = guild1_data["id"]
        registration_key1 = guild1_data["registration_key"]
        # The tenant id is embedded in the registration key.
        tenant1_id = parse_discord_registration_key(registration_key1)
        assert (
            tenant1_id is not None
        ), "Failed to parse tenant ID from registration key 1"
        # Create 1 guild in tenant 2
        response2 = requests.post(
            f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
            headers=admin_user2.headers,
        )
        assert response2.ok, f"Failed to create guild in tenant 2: {response2.text}"
        guild2_data = response2.json()
        guild2_id = guild2_data["id"]
        registration_key2 = guild2_data["registration_key"]
        tenant2_id = parse_discord_registration_key(registration_key2)
        assert (
            tenant2_id is not None
        ), "Failed to parse tenant ID from registration key 2"
        # Verify tenant IDs are different
        assert (
            tenant1_id != tenant2_id
        ), "Tenant 1 and tenant 2 should have different tenant IDs"
        # Register guild 1 with tenant 1's context - populate with different data
        with get_session_with_tenant(tenant_id=tenant1_id) as db_session:
            config1 = get_guild_config_by_registration_key(
                db_session, registration_key1
            )
            assert config1 is not None, "Guild config 1 should exist"
            register_guild(
                db_session=db_session,
                config=config1,
                guild_id=111111111111111111,  # Different Discord guild ID
                guild_name="Tenant 1 Server",  # Different guild name
            )
            db_session.commit()
        # Register guild 2 with tenant 2's context - populate with different data
        with get_session_with_tenant(tenant_id=tenant2_id) as db_session:
            config2 = get_guild_config_by_registration_key(
                db_session, registration_key2
            )
            assert config2 is not None, "Guild config 2 should exist"
            register_guild(
                db_session=db_session,
                config=config2,
                guild_id=222222222222222222,  # Different Discord guild ID
                guild_name="Tenant 2 Server",  # Different guild name
            )
            db_session.commit()
        try:
            # Verify tenant 1 sees only their guild
            list_response1 = requests.get(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
                headers=admin_user1.headers,
            )
            assert list_response1.ok
            tenant1_guilds = list_response1.json()
            # Tenant 1 should see exactly 1 guild
            assert (
                len(tenant1_guilds) == 1
            ), f"Tenant 1 should see 1 guild, got {len(tenant1_guilds)}"
            # Verify tenant 1's guild has the correct data
            tenant1_guild = tenant1_guilds[0]
            assert (
                tenant1_guild["id"] == guild1_id
            ), "Tenant 1 should see their own guild"
            assert tenant1_guild["guild_id"] == 111111111111111111, (
                f"Tenant 1's guild should have guild_id 111111111111111111, "
                f"got {tenant1_guild['guild_id']}"
            )
            assert tenant1_guild["guild_name"] == "Tenant 1 Server", (
                f"Tenant 1's guild should have name 'Tenant 1 Server', "
                f"got {tenant1_guild['guild_name']}"
            )
            assert (
                tenant1_guild["registered_at"] is not None
            ), "Tenant 1's guild should be registered"
            # Tenant 1 should NOT see tenant 2's guild
            assert (
                tenant1_guild["guild_id"] != 222222222222222222
            ), "Tenant 1 should not see tenant 2's guild_id"
            assert (
                tenant1_guild["guild_name"] != "Tenant 2 Server"
            ), "Tenant 1 should not see tenant 2's guild_name"
            # Verify tenant 2 sees only their guild
            list_response2 = requests.get(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
                headers=admin_user2.headers,
            )
            assert list_response2.ok
            tenant2_guilds = list_response2.json()
            # Tenant 2 should see exactly 1 guild
            assert (
                len(tenant2_guilds) == 1
            ), f"Tenant 2 should see 1 guild, got {len(tenant2_guilds)}"
            # Verify tenant 2's guild has the correct data
            tenant2_guild = tenant2_guilds[0]
            assert (
                tenant2_guild["id"] == guild2_id
            ), "Tenant 2 should see their own guild"
            assert tenant2_guild["guild_id"] == 222222222222222222, (
                f"Tenant 2's guild should have guild_id 222222222222222222, "
                f"got {tenant2_guild['guild_id']}"
            )
            assert tenant2_guild["guild_name"] == "Tenant 2 Server", (
                f"Tenant 2's guild should have name 'Tenant 2 Server', "
                f"got {tenant2_guild['guild_name']}"
            )
            assert (
                tenant2_guild["registered_at"] is not None
            ), "Tenant 2's guild should be registered"
            # Tenant 2 should NOT see tenant 1's guild
            assert (
                tenant2_guild["guild_id"] != 111111111111111111
            ), "Tenant 2 should not see tenant 1's guild_id"
            assert (
                tenant2_guild["guild_name"] != "Tenant 1 Server"
            ), "Tenant 2 should not see tenant 1's guild_name"
            # Verify the guilds are different (different data)
            assert (
                tenant1_guild["guild_id"] != tenant2_guild["guild_id"]
            ), "Guilds should have different Discord guild IDs"
            assert (
                tenant1_guild["guild_name"] != tenant2_guild["guild_name"]
            ), "Guilds should have different names"
        finally:
            # Cleanup
            requests.delete(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
                headers=admin_user1.headers,
            )
            requests.delete(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild2_id}",
                headers=admin_user2.headers,
            )
class TestGuildAccessIsolation:
    """Tests for guild access isolation between tenants."""

    def test_tenant_cannot_access_other_tenant_guild(
        self, reset_multitenant: None
    ) -> None:
        """Tenant 2 cannot access or modify tenant 1's guild by ID.

        Creates a guild in tenant 1, then attempts to access it from tenant 2.
        """
        # Unique suffix keeps user emails collision-free across test runs.
        unique = uuid4().hex
        # Create admin users for two tenants
        admin_user1: DATestUser = UserManager.create(
            email=f"discord_access1+{unique}@example.com",
        )
        admin_user2: DATestUser = UserManager.create(
            email=f"discord_access2+{unique}@example.com",
        )
        # Create a guild in tenant 1
        response = requests.post(
            f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
            headers=admin_user1.headers,
        )
        # A 404 here means the endpoint is not mounted at all (feature flag off).
        if response.status_code == 404:
            pytest.skip("Discord bot feature not enabled")
        assert response.ok
        guild1_id = response.json()["id"]
        try:
            # Tenant 2 tries to get the guild - should fail (404 or 403)
            get_response = requests.get(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
                headers=admin_user2.headers,
            )
            # Should either return 404 (not found) or 403 (forbidden)
            assert get_response.status_code in [
                403,
                404,
            ], f"Expected 403 or 404, got {get_response.status_code}"
            # Tenant 2 tries to delete the guild - should fail
            delete_response = requests.delete(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
                headers=admin_user2.headers,
            )
            assert delete_response.status_code in [403, 404]
        finally:
            # Cleanup - delete from tenant 1
            requests.delete(
                f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
                headers=admin_user1.headers,
            )
class TestCacheManagerIsolation:
    """Tests for cache manager tenant isolation."""

    def test_cache_maps_guild_to_correct_tenant(self) -> None:
        """Cache correctly maps guild_id to tenant_id."""
        cache = DiscordCacheManager()
        # Set up mappings by seeding the private guild->tenant map directly;
        # note two guilds (111, 333) may legitimately share one tenant.
        cache._guild_tenants[111] = "tenant1"
        cache._guild_tenants[222] = "tenant2"
        cache._guild_tenants[333] = "tenant1"
        assert cache.get_tenant(111) == "tenant1"
        assert cache.get_tenant(222) == "tenant2"
        assert cache.get_tenant(333) == "tenant1"
        # An unregistered guild resolves to None rather than raising.
        assert cache.get_tenant(444) is None

    def test_api_key_per_tenant_isolation(self) -> None:
        """Each tenant has unique API key."""
        cache = DiscordCacheManager()
        # Seed the private per-tenant API-key map directly.
        cache._api_keys["tenant1"] = "key_for_tenant1"
        cache._api_keys["tenant2"] = "key_for_tenant2"
        assert cache.get_api_key("tenant1") == "key_for_tenant1"
        assert cache.get_api_key("tenant2") == "key_for_tenant2"
        assert cache.get_api_key("tenant1") != cache.get_api_key("tenant2")
class TestAPIRequestIsolation:
    """Tests for API request isolation between tenants."""

    @pytest.mark.asyncio
    async def test_discord_bot_uses_tenant_specific_api_key(self) -> None:
        """Message from guild in tenant 1 uses tenant 1's API key."""
        cache = DiscordCacheManager()
        # Seed private maps: one registered guild, two tenants with distinct keys.
        cache._guild_tenants[123456] = "tenant1"
        cache._api_keys["tenant1"] = "tenant1_api_key"
        cache._api_keys["tenant2"] = "tenant2_api_key"
        # When processing message from guild 123456
        tenant = cache.get_tenant(123456)
        assert tenant is not None
        api_key = cache.get_api_key(tenant)
        assert tenant == "tenant1"
        assert api_key == "tenant1_api_key"
        assert api_key != "tenant2_api_key"

    @pytest.mark.asyncio
    async def test_guild_message_routes_to_correct_tenant(self) -> None:
        """Message from registered guild routes to correct tenant context."""
        cache = DiscordCacheManager()
        cache._guild_tenants[999] = "target_tenant"
        cache._api_keys["target_tenant"] = "target_key"
        # Simulate message routing: resolve tenant, then tenant-scoped key.
        guild_id = 999
        tenant = cache.get_tenant(guild_id)
        api_key = cache.get_api_key(tenant) if tenant else None
        assert tenant == "target_tenant"
        assert api_key == "target_key"

View File

@@ -0,0 +1,185 @@
from uuid import uuid4
import pytest
import requests
from requests import HTTPError
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
# autouse=True + module scope: runs exactly once, before the first test here.
@pytest.fixture(scope="module", autouse=True)
def reset_for_module() -> None:
    """Reset all data once before running any tests in this module."""
    reset_all()
@pytest.fixture
def second_user(admin_user: DATestUser) -> DATestUser:
    """Provide a second BASIC-role user, creating it or logging in if it already exists."""
    # Ensure admin exists so this new user is created with BASIC role.
    try:
        return UserManager.create(name="second_basic_user")
    except HTTPError as e:
        response = e.response
        if response is None:
            raise
        # Only 400/409 can mean "user already exists"; anything else is a real failure.
        if response.status_code not in (400, 409):
            raise
        try:
            payload = response.json()
        except ValueError:
            # NOTE(review): bare raise here re-raises the ValueError (non-JSON body),
            # not the original HTTPError — presumably intentional; either fails the fixture.
            raise
        detail = payload.get("detail")
        if not _is_user_already_exists_detail(detail):
            raise
        print("Second basic user already exists; logging in instead.")
        # Fall back to logging in with the well-known default credentials.
        return UserManager.login_as_user(
            DATestUser(
                id="",
                email=build_email("second_basic_user"),
                password=DEFAULT_PASSWORD,
                headers=GENERAL_HEADERS,
                role=UserRole.BASIC,
                is_active=True,
            )
        )
def _is_user_already_exists_detail(detail: object) -> bool:
if isinstance(detail, str):
normalized = detail.lower()
return (
"already exists" in normalized
or "register_user_already_exists" in normalized
)
if isinstance(detail, dict):
code = detail.get("code")
if isinstance(code, str) and code.lower() == "register_user_already_exists":
return True
message = detail.get("message")
if isinstance(message, str) and "already exists" in message.lower():
return True
return False
def _get_chat_session(
    chat_session_id: str,
    user: DATestUser,
    is_shared: bool | None = None,
    include_deleted: bool | None = None,
) -> requests.Response:
    """GET a chat session as ``user``, with optional query flags.

    ``is_shared`` / ``include_deleted`` are only sent when not None, encoded
    as lowercase "true"/"false" query-string values.
    """
    query: dict[str, str] = {
        name: str(flag).lower()
        for name, flag in (
            ("is_shared", is_shared),
            ("include_deleted", include_deleted),
        )
        if flag is not None
    }
    return requests.get(
        f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
        params=query,
        headers=user.headers,
        cookies=user.cookies,
    )
def _set_sharing_status(
    chat_session_id: str, sharing_status: str, user: DATestUser
) -> requests.Response:
    """PATCH the session's ``sharing_status`` (e.g. "public"/"private") as ``user``."""
    url = f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}"
    payload = {"sharing_status": sharing_status}
    return requests.patch(
        url,
        json=payload,
        headers=user.headers,
        cookies=user.cookies,
    )
def test_private_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify private sessions are only accessible by the owner and never via share link."""
    # Create a private chat session owned by basic_user.
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    # Owner can access the private session normally.
    response = _get_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 200
    # Share link should be forbidden when the session is private — even for the owner.
    response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
    assert response.status_code == 403
    # Other users cannot access private sessions directly.
    response = _get_chat_session(str(chat_session.id), second_user)
    assert response.status_code == 403
    # Other users also cannot access private sessions via share link.
    response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
    assert response.status_code == 403
def test_public_shared_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify shared sessions are accessible only via share link for non-owners."""
    # Create a private session, then mark it public.
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    response = _set_sharing_status(str(chat_session.id), "public", basic_user)
    assert response.status_code == 200
    # Owner can access normally.
    response = _get_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 200
    # Owner can also access via share link.
    response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
    assert response.status_code == 200
    # Non-owner cannot access without share link (ownership still enforced).
    response = _get_chat_session(str(chat_session.id), second_user)
    assert response.status_code == 403
    # Non-owner can access with share link for public sessions.
    response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
    assert response.status_code == 200
def test_deleted_chat_session_access(
    basic_user: DATestUser, second_user: DATestUser
) -> None:
    """Verify deleted sessions return 404, with include_deleted gated by access checks."""
    # Create and soft-delete a session.
    chat_session = ChatSessionManager.create(user_performing_action=basic_user)
    deletion_success = ChatSessionManager.soft_delete(
        chat_session=chat_session, user_performing_action=basic_user
    )
    assert deletion_success is True
    # Deleted sessions are not accessible normally (404, not 403, to avoid leaking existence).
    response = _get_chat_session(str(chat_session.id), basic_user)
    assert response.status_code == 404
    # Owner can fetch deleted session only with include_deleted.
    response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
    assert response.status_code == 200
    assert response.json().get("deleted") is True
    # Non-owner should be blocked even with include_deleted.
    response = _get_chat_session(
        str(chat_session.id), second_user, include_deleted=True
    )
    assert response.status_code == 403
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
    """Verify unknown IDs return 404."""
    # A freshly generated UUID cannot match any existing session.
    response = _get_chat_session(str(uuid4()), basic_user)
    assert response.status_code == 404

View File

@@ -0,0 +1,443 @@
"""Integration tests for Discord bot API endpoints.
These tests hit actual API endpoints via HTTP requests.
"""
import pytest
import requests
from onyx.db.discord_bot import get_discord_service_api_key
from onyx.db.discord_bot import get_or_create_discord_service_api_key
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from tests.integration.common_utils.managers.discord_bot import DiscordBotManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
class TestBotConfigEndpoints:
    """Tests for /manage/admin/discord-bot/config endpoints."""

    def test_get_bot_config_not_configured(self, reset: None) -> None:
        """GET /config returns configured=False when no config exists."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Ensure no config exists
        DiscordBotManager.delete_bot_config_if_exists(admin_user)
        config = DiscordBotManager.get_bot_config(admin_user)
        assert config["configured"] is False
        # created_at may be omitted entirely or present-but-null when unconfigured.
        assert "created_at" not in config or config.get("created_at") is None

    def test_create_bot_config(self, reset: None) -> None:
        """POST /config creates a new bot config."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Ensure no config exists
        DiscordBotManager.delete_bot_config_if_exists(admin_user)
        config = DiscordBotManager.create_bot_config(
            bot_token="test_token_123",
            user_performing_action=admin_user,
        )
        assert config["configured"] is True
        assert "created_at" in config
        # Cleanup
        DiscordBotManager.delete_bot_config_if_exists(admin_user)

    def test_create_bot_config_already_exists(self, reset: None) -> None:
        """POST /config returns 409 if config already exists."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Ensure no config exists, then create one
        DiscordBotManager.delete_bot_config_if_exists(admin_user)
        DiscordBotManager.create_bot_config(
            bot_token="token1",
            user_performing_action=admin_user,
        )
        # Try to create another - should fail
        with pytest.raises(requests.HTTPError) as exc_info:
            DiscordBotManager.create_bot_config(
                bot_token="token2",
                user_performing_action=admin_user,
            )
        assert exc_info.value.response.status_code == 409
        # Cleanup
        DiscordBotManager.delete_bot_config_if_exists(admin_user)

    def test_delete_bot_config(self, reset: None) -> None:
        """DELETE /config removes the bot config."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Ensure no config exists, then create one
        DiscordBotManager.delete_bot_config_if_exists(admin_user)
        DiscordBotManager.create_bot_config(
            bot_token="test_token",
            user_performing_action=admin_user,
        )
        # Delete it
        result = DiscordBotManager.delete_bot_config(admin_user)
        assert result["deleted"] is True
        # Verify it's gone
        config = DiscordBotManager.get_bot_config(admin_user)
        assert config["configured"] is False

    def test_delete_bot_config_not_found(self, reset: None) -> None:
        """DELETE /config returns 404 if no config exists."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Ensure no config exists
        DiscordBotManager.delete_bot_config_if_exists(admin_user)
        # Try to delete - should fail
        with pytest.raises(requests.HTTPError) as exc_info:
            DiscordBotManager.delete_bot_config(admin_user)
        assert exc_info.value.response.status_code == 404
class TestGuildConfigEndpoints:
    """Tests for /manage/admin/discord-bot/guilds endpoints."""

    def test_create_guild_config(self, reset: None) -> None:
        """POST /guilds creates a new guild config with registration key."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        guild = DiscordBotManager.create_guild(admin_user)
        assert guild.id is not None
        assert guild.registration_key is not None
        assert guild.registration_key.startswith("discord_")
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_list_guilds(self, reset: None) -> None:
        """GET /guilds returns all guild configs."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create some guilds
        guild1 = DiscordBotManager.create_guild(admin_user)
        guild2 = DiscordBotManager.create_guild(admin_user)
        guilds = DiscordBotManager.list_guilds(admin_user)
        # Membership check only — other tests may have left guilds behind.
        guild_ids = [g["id"] for g in guilds]
        assert guild1.id in guild_ids
        assert guild2.id in guild_ids
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild1.id, admin_user)
        DiscordBotManager.delete_guild_if_exists(guild2.id, admin_user)

    def test_get_guild_config(self, reset: None) -> None:
        """GET /guilds/{config_id} returns the specific guild config."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        guild = DiscordBotManager.create_guild(admin_user)
        fetched = DiscordBotManager.get_guild(guild.id, admin_user)
        assert fetched["id"] == guild.id
        assert fetched["enabled"] is True  # Default
        assert fetched["guild_id"] is None  # Not registered yet
        assert fetched["guild_name"] is None
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_get_guild_config_not_found(self, reset: None) -> None:
        """GET /guilds/{config_id} returns 404 for non-existent guild."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        result = DiscordBotManager.get_guild_or_none(999999, admin_user)
        assert result is None

    def test_update_guild_config(self, reset: None) -> None:
        """PATCH /guilds/{config_id} updates the guild config."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        guild = DiscordBotManager.create_guild(admin_user)
        # Update enabled status
        updated = DiscordBotManager.update_guild(
            guild.id,
            admin_user,
            enabled=False,
        )
        assert updated["enabled"] is False
        # Verify persistence via a fresh GET, not just the PATCH response.
        fetched = DiscordBotManager.get_guild(guild.id, admin_user)
        assert fetched["enabled"] is False
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_delete_guild_config(self, reset: None) -> None:
        """DELETE /guilds/{config_id} removes the guild config."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        guild = DiscordBotManager.create_guild(admin_user)
        # Delete it
        result = DiscordBotManager.delete_guild(guild.id, admin_user)
        assert result["deleted"] is True
        # Verify it's gone
        assert DiscordBotManager.get_guild_or_none(guild.id, admin_user) is None

    def test_delete_guild_config_not_found(self, reset: None) -> None:
        """DELETE /guilds/{config_id} returns 404 for non-existent guild."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        with pytest.raises(requests.HTTPError) as exc_info:
            DiscordBotManager.delete_guild(999999, admin_user)
        assert exc_info.value.response.status_code == 404

    def test_registration_key_format(self, reset: None) -> None:
        """Registration key has proper format with tenant encoded."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        guild = DiscordBotManager.create_guild(admin_user)
        # Key should be: discord_{encoded_tenant}.{random}
        key = guild.registration_key
        assert key is not None
        assert key.startswith("discord_")
        # Should have two parts separated by dot
        key_body = key.removeprefix("discord_")
        parts = key_body.split(".", 1)
        assert len(parts) == 2
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_each_registration_key_is_unique(self, reset: None) -> None:
        """Each created guild gets a unique registration key."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        guilds = [DiscordBotManager.create_guild(admin_user) for _ in range(5)]
        keys = [g.registration_key for g in guilds]
        assert len(set(keys)) == 5  # All unique
        # Cleanup
        for guild in guilds:
            DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
class TestChannelConfigEndpoints:
    """Tests for /manage/admin/discord-bot/guilds/{id}/channels endpoints."""

    def test_list_channels_empty(self, reset: None) -> None:
        """GET /guilds/{id}/channels returns empty list when no channels exist."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create a registered guild (has guild_id set)
        guild = DiscordBotManager.create_registered_guild_in_db(
            guild_id=111111111,
            guild_name="Test Guild",
        )
        channels = DiscordBotManager.list_channels(guild.id, admin_user)
        assert channels == []
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_list_channels_with_data(self, reset: None) -> None:
        """GET /guilds/{id}/channels returns channel configs."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create a registered guild (has guild_id set)
        guild = DiscordBotManager.create_registered_guild_in_db(
            guild_id=222222222,
            guild_name="Test Guild",
        )
        # Create test channels directly in DB (bypasses the bot-sync flow)
        channel1 = DiscordBotManager.create_test_channel_in_db(
            guild_config_id=guild.id,
            channel_id=123456789,
            channel_name="general",
        )
        channel2 = DiscordBotManager.create_test_channel_in_db(
            guild_config_id=guild.id,
            channel_id=987654321,
            channel_name="help",
            channel_type="forum",
        )
        channels = DiscordBotManager.list_channels(guild.id, admin_user)
        assert len(channels) == 2
        channel_ids = [c.id for c in channels]
        assert channel1.id in channel_ids
        assert channel2.id in channel_ids
        # Cleanup (deleting the guild also removes its channels)
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_update_channel_enabled(self, reset: None) -> None:
        """PATCH /guilds/{id}/channels/{id} updates enabled status."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create a registered guild (has guild_id set)
        guild = DiscordBotManager.create_registered_guild_in_db(
            guild_id=333333333,
            guild_name="Test Guild",
        )
        channel = DiscordBotManager.create_test_channel_in_db(
            guild_config_id=guild.id,
            channel_id=123456789,
            channel_name="general",
        )
        # Default is disabled
        assert channel.enabled is False
        # Enable the channel
        updated = DiscordBotManager.update_channel(
            guild.id,
            channel.id,
            admin_user,
            enabled=True,
        )
        assert updated.enabled is True
        # Verify persistence via a fresh list, not just the PATCH response
        channels = DiscordBotManager.list_channels(guild.id, admin_user)
        found = next(c for c in channels if c.id == channel.id)
        assert found.enabled is True
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_update_channel_thread_only_mode(self, reset: None) -> None:
        """PATCH /guilds/{id}/channels/{id} updates thread_only_mode."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create a registered guild (has guild_id set)
        guild = DiscordBotManager.create_registered_guild_in_db(
            guild_id=444444444,
            guild_name="Test Guild",
        )
        channel = DiscordBotManager.create_test_channel_in_db(
            guild_config_id=guild.id,
            channel_id=123456789,
            channel_name="general",
        )
        # Default is False
        assert channel.thread_only_mode is False
        # Enable thread_only_mode
        updated = DiscordBotManager.update_channel(
            guild.id,
            channel.id,
            admin_user,
            thread_only_mode=True,
        )
        assert updated.thread_only_mode is True
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_update_channel_require_bot_invocation(self, reset: None) -> None:
        """PATCH /guilds/{id}/channels/{id} updates require_bot_invocation."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create a registered guild (has guild_id set)
        guild = DiscordBotManager.create_registered_guild_in_db(
            guild_id=555555555,
            guild_name="Test Guild",
        )
        channel = DiscordBotManager.create_test_channel_in_db(
            guild_config_id=guild.id,
            channel_id=123456789,
            channel_name="general",
        )
        # Default is True
        assert channel.require_bot_invocation is True
        # Disable require_bot_invocation
        updated = DiscordBotManager.update_channel(
            guild.id,
            channel.id,
            admin_user,
            require_bot_invocation=False,
        )
        assert updated.require_bot_invocation is False
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)

    def test_update_channel_not_found(self, reset: None) -> None:
        """PATCH /guilds/{id}/channels/{id} returns 404 for non-existent channel."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Create a registered guild (has guild_id set)
        guild = DiscordBotManager.create_registered_guild_in_db(
            guild_id=666666666,
            guild_name="Test Guild",
        )
        with pytest.raises(requests.HTTPError) as exc_info:
            DiscordBotManager.update_channel(
                guild.id,
                999999,
                admin_user,
                enabled=True,
            )
        assert exc_info.value.response.status_code == 404
        # Cleanup
        DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
class TestServiceApiKeyCleanup:
    """Tests for service API key cleanup when bot/guild configs are deleted."""

    def test_delete_bot_config_also_deletes_service_api_key(self, reset: None) -> None:
        """DELETE /config also deletes the service API key (self-hosted flow)."""
        admin_user: DATestUser = UserManager.create(name="admin_user")
        # Setup: create bot config via API
        DiscordBotManager.delete_bot_config_if_exists(admin_user)
        DiscordBotManager.create_bot_config(
            bot_token="test_token",
            user_performing_action=admin_user,
        )
        # Create service API key directly in DB (simulating bot registration)
        with get_session_with_current_tenant() as db_session:
            get_or_create_discord_service_api_key(db_session, "public")
            db_session.commit()
            # Verify it exists
            assert get_discord_service_api_key(db_session) is not None
        # Delete bot config via API
        result = DiscordBotManager.delete_bot_config(admin_user)
        assert result["deleted"] is True
        # Verify service API key was also deleted — use a fresh session so we
        # observe the API server's committed state, not stale identity-map data.
        with get_session_with_current_tenant() as db_session:
            assert get_discord_service_api_key(db_session) is None

View File

@@ -0,0 +1,673 @@
"""Integration tests for Discord bot database operations.
These tests verify CRUD operations for Discord bot models.
"""
from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.discord_bot import bulk_create_channel_configs
from onyx.db.discord_bot import create_discord_bot_config
from onyx.db.discord_bot import create_guild_config
from onyx.db.discord_bot import delete_discord_bot_config
from onyx.db.discord_bot import delete_discord_service_api_key
from onyx.db.discord_bot import delete_guild_config
from onyx.db.discord_bot import get_channel_configs
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.discord_bot import get_discord_service_api_key
from onyx.db.discord_bot import get_guild_config_by_internal_id
from onyx.db.discord_bot import get_guild_config_by_registration_key
from onyx.db.discord_bot import get_guild_configs
from onyx.db.discord_bot import get_or_create_discord_service_api_key
from onyx.db.discord_bot import sync_channel_configs
from onyx.db.discord_bot import update_discord_channel_config
from onyx.db.discord_bot import update_guild_config
from onyx.db.models import Persona
from onyx.db.utils import DiscordChannelView
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Persona:
    """Create a minimal test persona.

    Uses an explicit ``persona_id`` so tests can reference it for foreign-key
    columns. The row is flushed (ID usable) but not committed — callers own
    the transaction and must call :func:`_delete_test_persona` to clean up.
    """
    persona = Persona(
        id=persona_id,
        name=name,
        description="Test persona for Discord bot tests",
        num_chunks=5.0,
        chunks_above=1,
        chunks_below=1,
        llm_relevance_filter=False,
        llm_filter_extraction=False,
        recency_bias=RecencyBiasSetting.FAVOR_RECENT,
        is_visible=True,
        is_default_persona=False,
        deleted=False,
        builtin_persona=False,
    )
    db_session.add(persona)
    db_session.flush()
    return persona
def _delete_test_persona(db_session: Session, persona_id: int) -> None:
    """Remove the persona created by ``_create_test_persona`` and flush (no commit)."""
    db_session.query(Persona).filter_by(id=persona_id).delete()
    db_session.flush()
class TestBotConfigAPI:
    """Tests for bot config API operations."""

    def test_create_bot_config(self, db_session: Session) -> None:
        """Create bot config succeeds with valid token."""
        # Clean up any existing config first
        delete_discord_bot_config(db_session)
        db_session.commit()
        config = create_discord_bot_config(db_session, bot_token="test_token_123")
        db_session.commit()
        assert config is not None
        assert config.bot_token == "test_token_123"
        # Cleanup
        delete_discord_bot_config(db_session)
        db_session.commit()

    def test_create_bot_config_already_exists(self, db_session: Session) -> None:
        """Creating config twice raises ValueError."""
        # Clean up first
        delete_discord_bot_config(db_session)
        db_session.commit()
        create_discord_bot_config(db_session, bot_token="token1")
        db_session.commit()
        with pytest.raises(ValueError):
            create_discord_bot_config(db_session, bot_token="token2")
        # Cleanup
        delete_discord_bot_config(db_session)
        db_session.commit()

    def test_get_bot_config(self, db_session: Session) -> None:
        """Get bot config returns config with masked token."""
        # Clean up first
        delete_discord_bot_config(db_session)
        db_session.commit()
        create_discord_bot_config(db_session, bot_token="my_secret_token")
        db_session.commit()
        config = get_discord_bot_config(db_session)
        assert config is not None
        # Token should be stored (we don't mask in DB, only API response)
        assert config.bot_token is not None
        # Cleanup
        delete_discord_bot_config(db_session)
        db_session.commit()

    def test_delete_bot_config(self, db_session: Session) -> None:
        """Delete bot config removes it from DB."""
        # Clean up first
        delete_discord_bot_config(db_session)
        db_session.commit()
        create_discord_bot_config(db_session, bot_token="token")
        db_session.commit()
        deleted = delete_discord_bot_config(db_session)
        db_session.commit()
        assert deleted is True
        assert get_discord_bot_config(db_session) is None

    def test_delete_bot_config_not_found(self, db_session: Session) -> None:
        """Delete when no config exists returns False."""
        # Ensure no config exists
        delete_discord_bot_config(db_session)
        db_session.commit()
        # Second delete has nothing to remove and must report False, not raise.
        deleted = delete_discord_bot_config(db_session)
        assert deleted is False
class TestRegistrationKeyAPI:
    """Tests for registration key API operations."""

    def test_create_registration_key(self, db_session: Session) -> None:
        """Create registration key with proper format."""
        key = generate_discord_registration_key("test_tenant")
        config = create_guild_config(db_session, registration_key=key)
        db_session.commit()
        assert config is not None
        assert config.registration_key == key
        assert key.startswith("discord_")
        # Tenant id may be embedded verbatim or URL-encoded ("_" -> "%5F").
        assert "test_tenant" in key or "test%5Ftenant" in key
        # Cleanup
        delete_guild_config(db_session, config.id)
        db_session.commit()

    def test_registration_key_is_unique(self, db_session: Session) -> None:
        """Each generated key is unique."""
        # Same tenant each time — uniqueness must come from the random part.
        keys = [generate_discord_registration_key("tenant") for _ in range(5)]
        assert len(set(keys)) == 5

    def test_delete_registration_key(self, db_session: Session) -> None:
        """Deleted key can no longer be used."""
        key = generate_discord_registration_key("tenant")
        config = create_guild_config(db_session, registration_key=key)
        db_session.commit()
        config_id = config.id
        # Delete
        deleted = delete_guild_config(db_session, config_id)
        db_session.commit()
        assert deleted is True
        # Should not find it anymore
        found = get_guild_config_by_registration_key(db_session, key)
        assert found is None
class TestGuildConfigAPI:
    """Tests for guild config API operations."""

    def test_list_guilds(self, db_session: Session) -> None:
        """List guilds returns all guild configs."""
        # Create some guild configs
        key1 = generate_discord_registration_key("t1")
        key2 = generate_discord_registration_key("t2")
        config1 = create_guild_config(db_session, registration_key=key1)
        config2 = create_guild_config(db_session, registration_key=key2)
        db_session.commit()
        configs = get_guild_configs(db_session)
        # Lower bound only — other tests may have left configs behind.
        assert len(configs) >= 2
        # Cleanup
        delete_guild_config(db_session, config1.id)
        delete_guild_config(db_session, config2.id)
        db_session.commit()

    def test_get_guild_config(self, db_session: Session) -> None:
        """Get specific guild config by ID."""
        key = generate_discord_registration_key("tenant")
        config = create_guild_config(db_session, registration_key=key)
        db_session.commit()
        found = get_guild_config_by_internal_id(db_session, config.id)
        assert found is not None
        assert found.id == config.id
        assert found.registration_key == key
        # Cleanup
        delete_guild_config(db_session, config.id)
        db_session.commit()

    def test_update_guild_enabled(self, db_session: Session) -> None:
        """Update guild enabled status."""
        key = generate_discord_registration_key("tenant")
        config = create_guild_config(db_session, registration_key=key)
        db_session.commit()
        # Initially enabled is True by default
        assert config.enabled is True
        # Disable
        updated = update_guild_config(
            db_session, config, enabled=False, default_persona_id=None
        )
        db_session.commit()
        assert updated.enabled is False
        # Cleanup
        delete_guild_config(db_session, config.id)
        db_session.commit()

    def test_update_guild_persona(self, db_session: Session) -> None:
        """Update guild default persona."""
        # Create test persona first to satisfy foreign key constraint
        _create_test_persona(db_session, 5, "Test Persona 5")
        db_session.commit()
        key = generate_discord_registration_key("tenant")
        config = create_guild_config(db_session, registration_key=key)
        db_session.commit()
        # Set persona
        updated = update_guild_config(
            db_session, config, enabled=True, default_persona_id=5
        )
        db_session.commit()
        assert updated.default_persona_id == 5
        # Cleanup — guild first, then the persona it references
        delete_guild_config(db_session, config.id)
        _delete_test_persona(db_session, 5)
        db_session.commit()
class TestChannelConfigAPI:
"""Tests for channel config API operations."""
def test_list_channels_for_guild(self, db_session: Session) -> None:
    """List channels returns all channel configs for guild."""
    key = generate_discord_registration_key("tenant")
    guild = create_guild_config(db_session, registration_key=key)
    db_session.commit()
    # Create some channels via the bulk-create path used during bot sync
    channels = [
        DiscordChannelView(
            channel_id=111,
            channel_name="general",
            channel_type="text",
            is_private=False,
        ),
        DiscordChannelView(
            channel_id=222,
            channel_name="help",
            channel_type="text",
            is_private=False,
        ),
    ]
    bulk_create_channel_configs(db_session, guild.id, channels)
    db_session.commit()
    channel_configs = get_channel_configs(db_session, guild.id)
    assert len(channel_configs) == 2
    # Cleanup (deleting the guild removes its channel configs)
    delete_guild_config(db_session, guild.id)
    db_session.commit()
def test_update_channel_enabled(self, db_session: Session) -> None:
    """Update channel enabled status."""
    key = generate_discord_registration_key("tenant")
    guild = create_guild_config(db_session, registration_key=key)
    db_session.commit()
    channels = [
        DiscordChannelView(
            channel_id=111,
            channel_name="general",
            channel_type="text",
            is_private=False,
        ),
    ]
    created = bulk_create_channel_configs(db_session, guild.id, channels)
    db_session.commit()
    # Channels are disabled by default
    assert created[0].enabled is False
    # Enable — update_discord_channel_config requires all fields, so the
    # unchanged ones are passed through with their current values.
    updated = update_discord_channel_config(
        db_session,
        created[0],
        channel_name="general",
        thread_only_mode=False,
        require_bot_invocation=True,
        enabled=True,
    )
    db_session.commit()
    assert updated.enabled is True
    # Cleanup
    delete_guild_config(db_session, guild.id)
    db_session.commit()
def test_update_channel_thread_only_mode(self, db_session: Session) -> None:
"""Update channel thread_only_mode setting."""
key = generate_discord_registration_key("tenant")
guild = create_guild_config(db_session, registration_key=key)
db_session.commit()
channels = [
DiscordChannelView(
channel_id=111,
channel_name="general",
channel_type="text",
is_private=False,
),
]
created = bulk_create_channel_configs(db_session, guild.id, channels)
db_session.commit()
# Update thread_only_mode
updated = update_discord_channel_config(
db_session,
created[0],
channel_name="general",
thread_only_mode=True,
require_bot_invocation=True,
enabled=True,
)
db_session.commit()
assert updated.thread_only_mode is True
# Cleanup
delete_guild_config(db_session, guild.id)
db_session.commit()
def test_sync_channels_adds_new(self, db_session: Session) -> None:
"""Sync channels adds new channels."""
key = generate_discord_registration_key("tenant")
guild = create_guild_config(db_session, registration_key=key)
db_session.commit()
# Initial channels
initial = [
DiscordChannelView(
channel_id=111,
channel_name="general",
channel_type="text",
is_private=False,
),
]
bulk_create_channel_configs(db_session, guild.id, initial)
db_session.commit()
# Sync with new channel
current = [
DiscordChannelView(
channel_id=111,
channel_name="general",
channel_type="text",
is_private=False,
),
DiscordChannelView(
channel_id=222,
channel_name="new-channel",
channel_type="text",
is_private=False,
),
]
added, removed, updated = sync_channel_configs(db_session, guild.id, current)
db_session.commit()
assert added == 1
assert removed == 0
# Cleanup
delete_guild_config(db_session, guild.id)
db_session.commit()
def test_sync_channels_removes_deleted(self, db_session: Session) -> None:
"""Sync channels removes deleted channels."""
key = generate_discord_registration_key("tenant")
guild = create_guild_config(db_session, registration_key=key)
db_session.commit()
# Initial channels
initial = [
DiscordChannelView(
channel_id=111,
channel_name="general",
channel_type="text",
is_private=False,
),
DiscordChannelView(
channel_id=222,
channel_name="old-channel",
channel_type="text",
is_private=False,
),
]
bulk_create_channel_configs(db_session, guild.id, initial)
db_session.commit()
# Sync with one channel removed
current = [
DiscordChannelView(
channel_id=111,
channel_name="general",
channel_type="text",
is_private=False,
),
]
added, removed, updated = sync_channel_configs(db_session, guild.id, current)
db_session.commit()
assert added == 0
assert removed == 1
# Cleanup
delete_guild_config(db_session, guild.id)
db_session.commit()
def test_sync_channels_updates_renamed(self, db_session: Session) -> None:
"""Sync channels updates renamed channels."""
key = generate_discord_registration_key("tenant")
guild = create_guild_config(db_session, registration_key=key)
db_session.commit()
# Initial channels
initial = [
DiscordChannelView(
channel_id=111,
channel_name="old-name",
channel_type="text",
is_private=False,
),
]
bulk_create_channel_configs(db_session, guild.id, initial)
db_session.commit()
# Sync with renamed channel
current = [
DiscordChannelView(
channel_id=111,
channel_name="new-name",
channel_type="text",
is_private=False,
),
]
added, removed, updated = sync_channel_configs(db_session, guild.id, current)
db_session.commit()
assert added == 0
assert removed == 0
assert updated == 1
# Verify name was updated
configs = get_channel_configs(db_session, guild.id)
assert configs[0].channel_name == "new-name"
# Cleanup
delete_guild_config(db_session, guild.id)
db_session.commit()
class TestPersonaConfigurationAPI:
    """Tests for persona configuration in API.

    Covers the persona-resolution precedence: channel override > guild
    default > API default (None).
    """

    def test_guild_persona_used_in_api_call(self, db_session: Session) -> None:
        """Guild default_persona_id is used when no channel override."""
        # Create test persona first
        _create_test_persona(db_session, 42, "Test Persona 42")
        db_session.commit()
        key = generate_discord_registration_key("tenant")
        guild = create_guild_config(db_session, registration_key=key)
        update_guild_config(db_session, guild, enabled=True, default_persona_id=42)
        db_session.commit()
        # Verify persona is set when the config is re-read from the DB
        config = get_guild_config_by_internal_id(db_session, guild.id)
        assert config is not None
        assert config.default_persona_id == 42
        # Cleanup
        delete_guild_config(db_session, guild.id)
        _delete_test_persona(db_session, 42)
        db_session.commit()

    def test_channel_persona_override_in_api_call(self, db_session: Session) -> None:
        """Channel persona_override_id takes precedence over guild default."""
        # Create test personas first
        _create_test_persona(db_session, 42, "Test Persona 42")
        _create_test_persona(db_session, 99, "Test Persona 99")
        db_session.commit()
        key = generate_discord_registration_key("tenant")
        guild = create_guild_config(db_session, registration_key=key)
        update_guild_config(db_session, guild, enabled=True, default_persona_id=42)
        db_session.commit()
        channels = [
            DiscordChannelView(
                channel_id=111,
                channel_name="general",
                channel_type="text",
                is_private=False,
            ),
        ]
        created = bulk_create_channel_configs(db_session, guild.id, channels)
        db_session.commit()
        # Set channel persona override
        updated = update_discord_channel_config(
            db_session,
            created[0],
            channel_name="general",
            thread_only_mode=False,
            require_bot_invocation=True,
            enabled=True,
            persona_override_id=99,  # Override!
        )
        db_session.commit()
        assert updated.persona_override_id == 99
        # Cleanup
        delete_guild_config(db_session, guild.id)
        _delete_test_persona(db_session, 42)
        _delete_test_persona(db_session, 99)
        db_session.commit()

    def test_no_persona_uses_default(self, db_session: Session) -> None:
        """Neither guild nor channel has persona - uses API default."""
        key = generate_discord_registration_key("tenant")
        guild = create_guild_config(db_session, registration_key=key)
        # No persona set
        db_session.commit()
        config = get_guild_config_by_internal_id(db_session, guild.id)
        assert config is not None
        assert config.default_persona_id is None
        # Cleanup
        delete_guild_config(db_session, guild.id)
        db_session.commit()
class TestServiceApiKeyAPI:
    """Tests for Discord service API key operations.

    The service key appears to be a singleton per tenant: each test first
    deletes any pre-existing key so state from other tests cannot leak in.
    """

    def test_create_service_api_key(self, db_session: Session) -> None:
        """Create service API key returns valid key."""
        # Clean up any existing key first
        delete_discord_service_api_key(db_session)
        db_session.commit()
        api_key = get_or_create_discord_service_api_key(db_session, "public")
        db_session.commit()
        # The returned value is the plaintext key handed to the bot.
        assert api_key is not None
        assert len(api_key) > 0
        # Verify key was stored in database
        stored_key = get_discord_service_api_key(db_session)
        assert stored_key is not None
        # Cleanup
        delete_discord_service_api_key(db_session)
        db_session.commit()

    def test_get_or_create_returns_existing(self, db_session: Session) -> None:
        """get_or_create_discord_service_api_key regenerates key if exists."""
        # Clean up any existing key first
        delete_discord_service_api_key(db_session)
        db_session.commit()
        # Create first key
        key1 = get_or_create_discord_service_api_key(db_session, "public")
        db_session.commit()
        # Call again - should regenerate (per implementation, it regenerates to update cache)
        key2 = get_or_create_discord_service_api_key(db_session, "public")
        db_session.commit()
        # Keys should be different since it regenerates
        assert key1 != key2
        # But there should still be only one key in the database
        stored_key = get_discord_service_api_key(db_session)
        assert stored_key is not None
        # Cleanup
        delete_discord_service_api_key(db_session)
        db_session.commit()

    def test_delete_service_api_key(self, db_session: Session) -> None:
        """Delete service API key removes it from DB."""
        # Clean up any existing key first
        delete_discord_service_api_key(db_session)
        db_session.commit()
        # Create a key
        get_or_create_discord_service_api_key(db_session, "public")
        db_session.commit()
        # Delete it
        deleted = delete_discord_service_api_key(db_session)
        db_session.commit()
        assert deleted is True
        assert get_discord_service_api_key(db_session) is None

    def test_delete_service_api_key_not_found(self, db_session: Session) -> None:
        """Delete when no key exists returns False."""
        # Ensure no key exists
        delete_discord_service_api_key(db_session)
        db_session.commit()
        deleted = delete_discord_service_api_key(db_session)
        assert deleted is False
# Pytest fixture providing a tenant-scoped DB session for the tests above.
@pytest.fixture
def db_session() -> Generator[Session, None, None]:
    """Create database session for tests.

    Initializes the shared SQL engine, pins the current-tenant contextvar
    to "public" for the duration of the test, and always restores the
    previous tenant token afterwards (even if the test raises).
    """
    from onyx.db.engine.sql_engine import get_session_with_current_tenant
    from onyx.db.engine.sql_engine import SqlEngine
    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

    # NOTE(review): presumably init_engine is idempotent when called once per
    # test — confirm; otherwise this should be session-scoped.
    SqlEngine.init_engine(pool_size=10, max_overflow=5)
    token = CURRENT_TENANT_ID_CONTEXTVAR.set("public")
    try:
        with get_session_with_current_tenant() as session:
            yield session
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
assert fallback_llm.config.model_name == default_provider.default_model_name
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
    users: tuple[DATestUser, DATestUser],
) -> None:
    """Test that the /llm/provider endpoint correctly excludes non-public providers
    with no group/persona restrictions.

    This tests the fix for the bug where non-public providers with no restrictions
    were incorrectly shown to all users instead of being admin-only.

    Args:
        users: (admin_user, basic_user) pair supplied by a fixture.
    """
    admin_user, basic_user = users
    # Create a public provider (should be visible to all)
    public_provider = LLMProviderManager.create(
        name="public-provider",
        is_public=True,
        set_as_default=True,
        user_performing_action=admin_user,
    )
    # Create a non-public provider with no restrictions (should be admin-only)
    non_public_provider = LLMProviderManager.create(
        name="non-public-unrestricted",
        is_public=False,
        groups=[],
        personas=[],
        set_as_default=False,
        user_performing_action=admin_user,
    )
    # Non-admin user calls the /llm/provider endpoint
    response = requests.get(
        f"{API_SERVER_URL}/llm/provider",
        headers=basic_user.headers,
    )
    assert response.status_code == 200
    providers = response.json()
    provider_names = [p["name"] for p in providers]
    # Public provider should be visible
    assert public_provider.name in provider_names
    # Non-public provider with no restrictions should NOT be visible to non-admin
    assert non_public_provider.name not in provider_names
    # Admin user should see both providers
    admin_response = requests.get(
        f"{API_SERVER_URL}/llm/provider",
        headers=admin_user.headers,
    )
    assert admin_response.status_code == 200
    admin_providers = admin_response.json()
    admin_provider_names = [p["name"] for p in admin_providers]
    assert public_provider.name in admin_provider_names
    assert non_public_provider.name in admin_provider_names
def test_provider_delete_clears_persona_references(reset: None) -> None:
"""Test that deleting a provider automatically clears persona references."""
admin_user = UserManager.create(name="admin_user")

View File

@@ -0,0 +1,65 @@
from uuid import uuid4
import requests
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.persona import PersonaLabelManager
from tests.integration.common_utils.managers.persona import PersonaManager
from tests.integration.common_utils.test_models import DATestPersonaLabel
from tests.integration.common_utils.test_models import DATestUser
def test_update_persona_with_null_label_ids_preserves_labels(
    reset: None, admin_user: DATestUser
) -> None:
    """PATCHing a persona with label_ids=None must not strip existing labels.

    Regression test: a null label_ids in the upsert payload should mean
    "leave labels unchanged", not "clear labels".
    """
    # Create a label and a persona that carries it.
    persona_label = PersonaLabelManager.create(
        label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
        user_performing_action=admin_user,
    )
    assert persona_label.id is not None
    persona = PersonaManager.create(
        label_ids=[persona_label.id],
        user_performing_action=admin_user,
    )
    # Build an update that changes the description but sends label_ids=None.
    updated_description = f"{persona.description}-updated"
    update_request = PersonaUpsertRequest(
        name=persona.name,
        description=updated_description,
        system_prompt=persona.system_prompt or "",
        task_prompt=persona.task_prompt or "",
        datetime_aware=persona.datetime_aware,
        document_set_ids=persona.document_set_ids,
        num_chunks=persona.num_chunks,
        is_public=persona.is_public,
        recency_bias=persona.recency_bias,
        llm_filter_extraction=persona.llm_filter_extraction,
        llm_relevance_filter=persona.llm_relevance_filter,
        llm_model_provider_override=persona.llm_model_provider_override,
        llm_model_version_override=persona.llm_model_version_override,
        tool_ids=persona.tool_ids,
        users=[],
        groups=[],
        label_ids=None,
    )
    # exclude_none=False so label_ids is serialized as an explicit null.
    response = requests.patch(
        f"{API_SERVER_URL}/persona/{persona.id}",
        json=update_request.model_dump(mode="json", exclude_none=False),
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    response.raise_for_status()
    fetched = requests.get(
        f"{API_SERVER_URL}/persona/{persona.id}",
        headers=admin_user.headers,
        cookies=admin_user.cookies,
    )
    fetched.raise_for_status()
    fetched_persona = fetched.json()
    # Description changed, but the original label must still be attached.
    assert fetched_persona["description"] == updated_description
    fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
    assert persona_label.id in fetched_label_ids

View File

@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
provider_id = _activate_exa_provider(admin_user)
assert isinstance(provider_id, int)
search_request = {"queries": ["latest ai research news"], "max_results": 3}
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
lite_response = requests.post(
f"{API_SERVER_URL}/web-search/search-lite",

View File

@@ -0,0 +1,159 @@
"""Tests for license enforcement middleware."""
from collections.abc import Awaitable
from collections.abc import Callable
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from starlette.requests import Request
from starlette.responses import Response
from ee.onyx.server.middleware.license_enforcement import _is_path_allowed
# Type alias for the middleware harness tuple
MiddlewareHarness = tuple[
Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]],
Callable[[Request], Awaitable[Response]],
]
class TestPathAllowlist:
    """Tests for the path allowlist logic."""

    @pytest.mark.parametrize(
        "path,expected",
        [
            # Each allowlisted prefix (one example each)
            ("/auth", True),
            ("/license", True),
            ("/health", True),
            ("/me", True),
            ("/settings", True),
            ("/enterprise-settings", True),
            ("/tenants/billing-information", True),
            ("/tenants/create-customer-portal-session", True),
            # Verify prefix matching works (subpath of allowlisted)
            ("/auth/callback/google", True),
            # Blocked paths (core functionality that requires license)
            ("/chat", False),
            ("/search", False),
            ("/admin", False),
            ("/connector", False),
            ("/persona", False),
        ],
    )
    def test_path_allowlist(self, path: str, expected: bool) -> None:
        """Verify correct paths are allowed/blocked when license is gated."""
        # `is` comparison also pins the return type to a real bool.
        assert _is_path_allowed(path) is expected
class TestLicenseEnforcementMiddleware:
    """Tests for middleware behavior under different conditions.

    The harness captures the function registered via `app.middleware(...)`
    so the middleware can be invoked directly with mock requests.
    """

    @pytest.fixture
    def middleware_harness(self) -> MiddlewareHarness:
        """Create a test harness for the middleware.

        Returns:
            (middleware, call_next) — the captured middleware callable and a
            stub downstream handler that always returns HTTP 200.
        """
        from ee.onyx.server.middleware.license_enforcement import (
            add_license_enforcement_middleware,
        )

        app = MagicMock()
        logger = MagicMock()
        captured_middleware: Any = None

        # Fake FastAPI's @app.middleware("http") decorator: instead of
        # registering, it stashes the decorated function for direct calls.
        def capture_middleware(middleware_type: str) -> Callable[[Any], Any]:
            def decorator(func: Any) -> Any:
                nonlocal captured_middleware
                captured_middleware = func
                return func

            return decorator

        app.middleware = capture_middleware
        add_license_enforcement_middleware(app, logger)

        async def call_next(req: Request) -> Response:
            response = MagicMock()
            response.status_code = 200
            return response

        return captured_middleware, call_next

    # NOTE: @patch decorators apply bottom-up, so the first mock parameter
    # corresponds to the innermost (lowest) decorator.
    @pytest.mark.asyncio
    @patch(
        "ee.onyx.server.middleware.license_enforcement.LICENSE_ENFORCEMENT_ENABLED",
        True,
    )
    @patch("ee.onyx.server.middleware.license_enforcement.MULTI_TENANT", True)
    @patch("ee.onyx.server.middleware.license_enforcement.get_current_tenant_id")
    @patch("ee.onyx.server.middleware.license_enforcement.is_tenant_gated")
    async def test_gated_tenant_gets_402(
        self,
        mock_is_gated: MagicMock,
        mock_get_tenant: MagicMock,
        middleware_harness: MiddlewareHarness,
    ) -> None:
        """Gated tenants receive 402 Payment Required on non-allowlisted paths."""
        mock_get_tenant.return_value = "gated_tenant"
        mock_is_gated.return_value = True
        middleware, call_next = middleware_harness
        mock_request = MagicMock()
        mock_request.url.path = "/api/chat"
        response = await middleware(mock_request, call_next)
        assert response.status_code == 402

    @pytest.mark.asyncio
    @patch(
        "ee.onyx.server.middleware.license_enforcement.LICENSE_ENFORCEMENT_ENABLED",
        True,
    )
    @patch("ee.onyx.server.middleware.license_enforcement.MULTI_TENANT", False)
    @patch("ee.onyx.server.middleware.license_enforcement.get_current_tenant_id")
    @patch("ee.onyx.server.middleware.license_enforcement.get_cached_license_metadata")
    async def test_no_license_self_hosted_gets_402(
        self,
        mock_get_metadata: MagicMock,
        mock_get_tenant: MagicMock,
        middleware_harness: MiddlewareHarness,
    ) -> None:
        """Self-hosted with no license receives 402 on non-allowlisted paths."""
        mock_get_tenant.return_value = "default"
        mock_get_metadata.return_value = None
        middleware, call_next = middleware_harness
        mock_request = MagicMock()
        mock_request.url.path = "/api/chat"
        response = await middleware(mock_request, call_next)
        assert response.status_code == 402

    @pytest.mark.asyncio
    @patch(
        "ee.onyx.server.middleware.license_enforcement.LICENSE_ENFORCEMENT_ENABLED",
        True,
    )
    @patch("ee.onyx.server.middleware.license_enforcement.MULTI_TENANT", True)
    @patch("ee.onyx.server.middleware.license_enforcement.get_current_tenant_id")
    @patch("ee.onyx.server.middleware.license_enforcement.is_tenant_gated")
    async def test_redis_error_fails_open(
        self,
        mock_is_gated: MagicMock,
        mock_get_tenant: MagicMock,
        middleware_harness: MiddlewareHarness,
    ) -> None:
        """Redis errors should not block users - fail open to allow access."""
        from redis.exceptions import RedisError

        mock_get_tenant.return_value = "test_tenant"
        mock_is_gated.side_effect = RedisError("Connection failed")
        middleware, call_next = middleware_harness
        mock_request = MagicMock()
        mock_request.url.path = "/api/chat"
        response = await middleware(mock_request, call_next)
        assert response.status_code == 200  # Fail open

View File

@@ -0,0 +1,93 @@
"""Tests for license enforcement in settings API."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from redis.exceptions import RedisError
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
@pytest.fixture
def base_settings() -> Settings:
    """Create base settings for testing.

    Starts from an ACTIVE application so tests can observe whether license
    enforcement downgrades the status.
    """
    return Settings(
        maximum_chat_retention_days=None,
        gpu_enabled=False,
        application_status=ApplicationStatus.ACTIVE,
    )
class TestApplyLicenseStatusToSettings:
    """Tests for apply_license_status_to_settings function."""

    @patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", False)
    def test_enforcement_disabled_returns_unchanged(
        self, base_settings: Settings
    ) -> None:
        """Critical: When LICENSE_ENFORCEMENT_ENABLED=False, settings remain unchanged.

        This is the key behavior that allows disabling enforcement for rollback.
        """
        from ee.onyx.server.settings.api import apply_license_status_to_settings

        result = apply_license_status_to_settings(base_settings)
        assert result.application_status == ApplicationStatus.ACTIVE

    # @patch decorators apply bottom-up: mock_get_metadata is the innermost.
    @pytest.mark.parametrize(
        "license_status,expected_status",
        [
            (None, ApplicationStatus.GATED_ACCESS),  # No license = gated
            (
                ApplicationStatus.GATED_ACCESS,
                ApplicationStatus.GATED_ACCESS,
            ),  # Gated status propagated
            (ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE),  # Active stays active
        ],
    )
    @patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
    @patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
    @patch("ee.onyx.server.settings.api.get_current_tenant_id")
    @patch("ee.onyx.server.settings.api.get_cached_license_metadata")
    def test_self_hosted_license_status_propagation(
        self,
        mock_get_metadata: MagicMock,
        mock_get_tenant: MagicMock,
        license_status: ApplicationStatus | None,
        expected_status: ApplicationStatus,
        base_settings: Settings,
    ) -> None:
        """Self-hosted: license status is propagated to settings correctly."""
        from ee.onyx.server.settings.api import apply_license_status_to_settings

        mock_get_tenant.return_value = "test_tenant"
        if license_status is None:
            # Simulate "no license present".
            mock_get_metadata.return_value = None
        else:
            mock_metadata = MagicMock()
            mock_metadata.status = license_status
            mock_get_metadata.return_value = mock_metadata
        result = apply_license_status_to_settings(base_settings)
        assert result.application_status == expected_status

    @patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
    @patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
    @patch("ee.onyx.server.settings.api.get_current_tenant_id")
    @patch("ee.onyx.server.settings.api.get_cached_license_metadata")
    def test_redis_error_fails_open(
        self,
        mock_get_metadata: MagicMock,
        mock_get_tenant: MagicMock,
        base_settings: Settings,
    ) -> None:
        """Redis errors should not block users - fail open."""
        from ee.onyx.server.settings.api import apply_license_status_to_settings

        mock_get_tenant.return_value = "test_tenant"
        mock_get_metadata.side_effect = RedisError("Connection failed")
        result = apply_license_status_to_settings(base_settings)
        # On infrastructure failure the app stays ACTIVE rather than gating.
        assert result.application_status == ApplicationStatus.ACTIVE

View File

@@ -0,0 +1,69 @@
"""Tests for product gating functions."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
class TestIsTenantGated:
    """Tests for is_tenant_gated - the O(1) Redis check used by middleware."""

    @pytest.mark.parametrize(
        "redis_result,expected",
        [
            (True, True),
            (False, False),
            (1, True),  # Redis sismember can return int
            (0, False),
        ],
    )
    @patch("ee.onyx.server.tenants.product_gating.get_redis_replica_client")
    def test_tenant_gated_status(
        self,
        mock_get_redis: MagicMock,
        redis_result: bool | int,
        expected: bool,
    ) -> None:
        """is_tenant_gated correctly interprets Redis sismember result."""
        from ee.onyx.server.tenants.product_gating import is_tenant_gated

        # Stub the replica client so sismember yields the parametrized value.
        mock_redis = MagicMock()
        mock_redis.sismember.return_value = redis_result
        mock_get_redis.return_value = mock_redis
        # `is` check: the function must coerce int results to a real bool.
        assert is_tenant_gated("test_tenant") is expected
class TestUpdateTenantGating:
    """Tests for update_tenant_gating - modifies Redis gated set."""

    @pytest.mark.parametrize(
        "status,should_add_to_set",
        [
            ("gated_access", True),  # Only GATED_ACCESS adds to set
            ("active", False),  # All other statuses remove from set
        ],
    )
    @patch("ee.onyx.server.tenants.product_gating.get_redis_client")
    def test_gating_set_modification(
        self,
        mock_get_redis: MagicMock,
        status: str,
        should_add_to_set: bool,
    ) -> None:
        """update_tenant_gating adds tenant to set only for GATED_ACCESS status."""
        from ee.onyx.server.tenants.product_gating import update_tenant_gating
        from onyx.server.settings.models import ApplicationStatus

        mock_redis = MagicMock()
        mock_get_redis.return_value = mock_redis
        update_tenant_gating("test_tenant", ApplicationStatus(status))
        # Exactly one of sadd/srem must fire — never both, never neither.
        if should_add_to_set:
            mock_redis.sadd.assert_called_once()
            mock_redis.srem.assert_not_called()
        else:
            mock_redis.srem.assert_called_once()
            mock_redis.sadd.assert_not_called()

View File

@@ -0,0 +1,50 @@
"""Tests for Asana connector configuration parsing."""
import pytest
from onyx.connectors.asana.connector import AsanaConnector
@pytest.mark.parametrize(
    "project_ids,expected",
    [
        (None, None),
        ("", None),  # empty string normalizes to None, not []
        (" ", None),  # whitespace-only likewise
        (" 123 ", ["123"]),  # surrounding whitespace stripped
        (" 123 , , 456 , ", ["123", "456"]),  # empty CSV entries dropped
    ],
)
def test_asana_connector_project_ids_normalization(
    project_ids: str | None, expected: list[str] | None
) -> None:
    """AsanaConnector trims/splits its comma-separated project id config."""
    connector = AsanaConnector(
        asana_workspace_id=" 1153293530468850 ",
        asana_project_ids=project_ids,
        asana_team_id=" 1210918501948021 ",
    )
    # Workspace and team ids are also whitespace-normalized.
    assert connector.workspace_id == "1153293530468850"
    assert connector.project_ids_to_index == expected
    assert connector.asana_team_id == "1210918501948021"
@pytest.mark.parametrize(
    "team_id,expected",
    [
        (None, None),
        ("", None),  # empty → None
        (" ", None),  # whitespace-only → None
        (" 1210918501948021 ", "1210918501948021"),  # stripped
    ],
)
def test_asana_connector_team_id_normalization(
    team_id: str | None, expected: str | None
) -> None:
    """AsanaConnector normalizes the optional team id the same way as project ids."""
    connector = AsanaConnector(
        asana_workspace_id="1153293530468850",
        asana_project_ids=None,
        asana_team_id=team_id,
    )
    assert connector.asana_team_id == expected

View File

@@ -409,6 +409,53 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
)
def test_vertex_stream_omits_stream_options() -> None:
    """Vertex AI streaming calls must not pass `stream_options` to litellm.

    Presumably Vertex rejects that OpenAI-specific parameter — confirm
    against the provider handling in LitellmLLM.
    """
    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=LlmProviderNames.VERTEX_AI,
        model_name="claude-opus-4-5@20251101",
        max_input_tokens=get_max_input_tokens(
            model_provider=LlmProviderNames.VERTEX_AI,
            model_name="claude-opus-4-5@20251101",
        ),
    )
    with patch("litellm.completion") as mock_completion:
        # Empty iterable: we only care about the kwargs of the call.
        mock_completion.return_value = []
        messages: LanguageModelInput = [UserMessage(content="Hi")]
        list(llm.stream(messages))
        kwargs = mock_completion.call_args.kwargs
        assert "stream_options" not in kwargs
def test_vertex_opus_4_5_omits_reasoning_effort() -> None:
    """Vertex Opus 4.5 must not receive `reasoning_effort`, even when the
    model is classified as a reasoning model (forced via patch below)."""
    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=LlmProviderNames.VERTEX_AI,
        model_name="claude-opus-4-5@20251101",
        max_input_tokens=get_max_input_tokens(
            model_provider=LlmProviderNames.VERTEX_AI,
            model_name="claude-opus-4-5@20251101",
        ),
    )
    with (
        patch("litellm.completion") as mock_completion,
        # Force the reasoning-model path so the omission is actually tested.
        patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
    ):
        mock_completion.return_value = []
        messages: LanguageModelInput = [UserMessage(content="Hi")]
        list(llm.stream(messages))
        kwargs = mock_completion.call_args.kwargs
        assert "reasoning_effort" not in kwargs
def test_user_identity_metadata_enabled(default_multi_llm: LitellmLLM) -> None:
with (
patch("litellm.completion") as mock_completion,

View File

@@ -0,0 +1,281 @@
"""Fixtures for Discord bot unit tests."""
import random
from collections.abc import Callable
from typing import Any
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import discord
import pytest
class AsyncIteratorMock:
    """Async iterator over a fixed list, for mocking async-iterable APIs
    such as ``channel.history()``.

    Iterates the caller's list in order and raises ``StopAsyncIteration``
    when exhausted. Single-use: ``index`` is never reset.
    """

    def __init__(self, items: list[Any]) -> None:
        self.items = items
        self.index = 0

    def __aiter__(self) -> "AsyncIteratorMock":
        # The mock is its own iterator.
        return self

    async def __anext__(self) -> Any:
        # Out-of-range lookup marks exhaustion (index only ever grows).
        try:
            item = self.items[self.index]
        except IndexError:
            raise StopAsyncIteration
        self.index += 1
        return item
def mock_message(
    content: str = "Test message",
    author_bot: bool = False,
    message_type: discord.MessageType = discord.MessageType.default,
    reference: MagicMock | None = None,
    message_id: int | None = None,
    author_id: int | None = None,
    author_display_name: str | None = None,
) -> MagicMock:
    """Helper to create mock Discord messages.

    Falsy ids (None/0) are replaced with random 6-digit snowflakes; a falsy
    display name falls back to "Bot"/"User" depending on `author_bot`.
    """
    message = MagicMock(spec=discord.Message)
    message.id = message_id or random.randint(100000, 999999)

    # Build the author sub-mock separately, then attach it.
    author = MagicMock()
    author.id = author_id or random.randint(100000, 999999)
    author.bot = author_bot
    fallback_name = "Bot" if author_bot else "User"
    author.display_name = author_display_name or fallback_name
    message.author = author

    message.content = content
    message.type = message_type
    message.reference = reference
    # No mentions of any kind by default.
    message.mentions = []
    message.role_mentions = []
    message.channel_mentions = []
    return message
@pytest.fixture
def mock_bot_user() -> MagicMock:
    """Mock Discord bot user (the bot's own ClientUser account)."""
    bot_user = MagicMock(spec=discord.ClientUser)
    bot_user.bot = True
    bot_user.id = 123456789
    bot_user.display_name = "OnyxBot"
    return bot_user
@pytest.fixture
def mock_discord_guild() -> MagicMock:
    """Mock Discord guild with channels.

    Provides one text channel ("general") and one forum channel ("forum"),
    both visible to the bot via permissions_for().view_channel == True.
    """
    guild = MagicMock(spec=discord.Guild)
    guild.id = 987654321
    guild.name = "Test Server"
    guild.default_role = MagicMock()
    # Create some mock channels
    text_channel = MagicMock(spec=discord.TextChannel)
    text_channel.id = 111111111
    text_channel.name = "general"
    text_channel.type = discord.ChannelType.text
    perms = MagicMock()
    perms.view_channel = True
    text_channel.permissions_for.return_value = perms
    forum_channel = MagicMock(spec=discord.ForumChannel)
    forum_channel.id = 222222222
    forum_channel.name = "forum"
    forum_channel.type = discord.ChannelType.forum
    # NOTE: both channels share the same `perms` mock object.
    forum_channel.permissions_for.return_value = perms
    guild.channels = [text_channel, forum_channel]
    guild.text_channels = [text_channel]
    guild.forum_channels = [forum_channel]
    return guild
@pytest.fixture
def mock_discord_message(mock_bot_user: MagicMock) -> MagicMock:
    """Mock Discord message for testing.

    A plain human-authored message in #general with no mentions and no
    reply reference. Ids line up with the other fixtures in this file
    (guild 987654321, channel 111111111).
    """
    msg = MagicMock(spec=discord.Message)
    msg.id = 555555555
    msg.author = MagicMock()
    msg.author.id = 444444444
    msg.author.bot = False
    msg.author.display_name = "TestUser"
    msg.content = "Hello bot"
    msg.guild = MagicMock()
    msg.guild.id = 987654321
    msg.guild.name = "Test Server"
    msg.channel = MagicMock()
    msg.channel.id = 111111111
    msg.channel.name = "general"
    msg.type = discord.MessageType.default
    msg.mentions = []
    msg.role_mentions = []
    msg.channel_mentions = []
    msg.reference = None
    return msg
@pytest.fixture
def mock_thread_with_messages(mock_bot_user: MagicMock) -> MagicMock:
    """Mock Discord thread with message history.

    The thread is owned by the bot; its parent text channel can fetch the
    starter message (whose id equals the thread id, per Discord's model),
    and history() yields three user/bot messages.
    """
    thread = MagicMock(spec=discord.Thread)
    thread.id = 666666666
    thread.name = "Test Thread"
    thread.owner_id = mock_bot_user.id
    thread.parent = MagicMock(spec=discord.TextChannel)
    thread.parent.id = 111111111
    # Mock starter message
    starter = mock_message(
        content="Thread starter message",
        author_bot=False,
        message_id=thread.id,
    )
    messages = [
        mock_message(author_bot=False, content="User msg 1", message_id=100),
        mock_message(author_bot=True, content="Bot response", message_id=101),
        mock_message(author_bot=False, content="User msg 2", message_id=102),
    ]

    # Setup async iterator for history; a fresh iterator per call so the
    # history can be consumed more than once.
    def history(**kwargs: Any) -> AsyncIteratorMock:
        return AsyncIteratorMock(messages)

    thread.history = history

    # Mock parent.fetch_message: only the starter id resolves.
    async def fetch_starter(msg_id: int) -> MagicMock:
        if msg_id == thread.id:
            return starter
        raise discord.NotFound(MagicMock(), "Not found")

    thread.parent.fetch_message = AsyncMock(side_effect=fetch_starter)
    return thread
@pytest.fixture
def mock_thread_forum_parent() -> MagicMock:
    """Mock thread with ForumChannel parent (special case)."""
    # Build the forum parent first, then hang the thread off it.
    forum_parent = MagicMock(spec=discord.ForumChannel)
    forum_parent.id = 222222222

    thread = MagicMock(spec=discord.Thread)
    thread.id = 777777777
    thread.name = "Forum Post"
    thread.parent = forum_parent
    return thread
@pytest.fixture
def mock_reply_chain() -> MagicMock:
    """Mock message with reply chain.

    Returns msg1, which replies to msg2, which replies to msg3. Each
    message stores a `_chain` id->message map so a test can resolve a
    reference.message_id without a real fetch.
    """
    # Build chain backwards: msg3 -> msg2 -> msg1
    ref3 = MagicMock()
    ref3.message_id = 1003
    ref2 = MagicMock()
    ref2.message_id = 1002
    msg3 = mock_message(content="Third message", reference=None, message_id=1003)
    msg2 = mock_message(content="Second message", reference=ref3, message_id=1002)
    msg1 = mock_message(content="First message", reference=ref2, message_id=1001)
    # Store messages for lookup
    msg1._chain = {1002: msg2, 1003: msg3}
    msg2._chain = {1003: msg3}
    return msg1
@pytest.fixture
def mock_guild_config_enabled() -> MagicMock:
    """Guild config that is enabled (with a default persona set)."""
    guild_config = MagicMock()
    guild_config.id = 1
    guild_config.guild_id = 987654321
    guild_config.enabled = True
    guild_config.default_persona_id = 1
    return guild_config
@pytest.fixture
def mock_guild_config_disabled() -> MagicMock:
    """Guild config that is disabled (no default persona)."""
    guild_config = MagicMock()
    guild_config.id = 2
    guild_config.guild_id = 987654321
    guild_config.enabled = False
    guild_config.default_persona_id = None
    return guild_config
@pytest.fixture
def mock_channel_config_factory() -> Callable[..., MagicMock]:
    """Factory fixture for creating channel configs with various settings.

    The returned callable builds a MagicMock channel config; the id is a
    random 1-1000 int, the channel_id is fixed to match the other fixtures.
    """

    def _build_config(
        enabled: bool = True,
        require_bot_invocation: bool = True,
        thread_only_mode: bool = False,
        persona_override_id: int | None = None,
    ) -> MagicMock:
        channel_config = MagicMock()
        channel_config.id = random.randint(1, 1000)
        channel_config.channel_id = 111111111
        channel_config.enabled = enabled
        channel_config.require_bot_invocation = require_bot_invocation
        channel_config.thread_only_mode = thread_only_mode
        channel_config.persona_override_id = persona_override_id
        return channel_config

    return _build_config
@pytest.fixture
def mock_message_with_bot_mention(mock_bot_user: MagicMock) -> MagicMock:
    """Message that mentions the bot.

    Both `mentions` and `content` carry the bot reference, matching how
    Discord delivers a real mention.
    """
    msg = MagicMock(spec=discord.Message)
    msg.id = 888888888
    msg.mentions = [mock_bot_user]
    msg.author = MagicMock()
    msg.author.id = 444444444
    msg.author.bot = False
    msg.author.display_name = "TestUser"
    msg.type = discord.MessageType.default
    # Raw mention syntax as Discord would render it in message content.
    msg.content = f"<@{mock_bot_user.id}> hello"
    msg.reference = None
    msg.guild = MagicMock()
    msg.guild.id = 987654321
    msg.channel = MagicMock()
    msg.channel.id = 111111111
    msg.role_mentions = []
    msg.channel_mentions = []
    return msg
@pytest.fixture
def mock_guild_with_members() -> MagicMock:
    """Mock guild for mention resolution.

    Every member/role/channel lookup succeeds and returns a mock whose
    display attribute encodes the requested id.
    """
    guild = MagicMock(spec=discord.Guild)

    def _mock_with(attr: str, value: str) -> MagicMock:
        entity = MagicMock()
        setattr(entity, attr, value)
        return entity

    guild.get_member = lambda member_id: _mock_with(
        "display_name", f"User{member_id}"
    )
    guild.get_role = lambda role_id: _mock_with("name", f"Role{role_id}")
    guild.get_channel = lambda channel_id: _mock_with(
        "name", f"channel{channel_id}"
    )
    return guild

View File

@@ -0,0 +1,441 @@
"""Unit tests for Discord bot API client.
Tests for OnyxAPIClient class functionality.
"""
from typing import Any
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import aiohttp
import pytest
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError
class MockAsyncContextManager:
"""Helper class to create proper async context managers for testing."""
def __init__(
self, return_value: Any = None, enter_side_effect: Exception | None = None
) -> None:
self.return_value = return_value
self.enter_side_effect = enter_side_effect
async def __aenter__(self) -> Any:
if self.enter_side_effect:
raise self.enter_side_effect
return self.return_value
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
pass
class TestClientLifecycle:
    """Tests for API client lifecycle management (initialize/close)."""

    @pytest.mark.asyncio
    async def test_initialize_creates_session(self) -> None:
        """initialize() creates aiohttp session."""
        client = OnyxAPIClient()
        # A fresh client must not own a session before initialize().
        assert client._session is None
        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = MagicMock()
            mock_session_class.return_value = mock_session
            await client.initialize()
            assert client._session is not None
            mock_session_class.assert_called_once()

    def test_is_initialized_before_init(self) -> None:
        """is_initialized returns False before initialize()."""
        client = OnyxAPIClient()
        assert client.is_initialized is False

    @pytest.mark.asyncio
    async def test_is_initialized_after_init(self) -> None:
        """is_initialized returns True after initialize()."""
        client = OnyxAPIClient()
        with patch("aiohttp.ClientSession"):
            await client.initialize()
            assert client.is_initialized is True

    @pytest.mark.asyncio
    async def test_close_closes_session(self) -> None:
        """close() closes session and resets is_initialized."""
        client = OnyxAPIClient()
        # AsyncMock so session.close() is awaitable like the real aiohttp API.
        mock_session = AsyncMock()
        with patch("aiohttp.ClientSession", return_value=mock_session):
            await client.initialize()
            assert client.is_initialized is True
            await client.close()
            assert client.is_initialized is False
            mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_send_message_not_initialized(self) -> None:
        """send_chat_message() before initialize() raises APIConnectionError."""
        client = OnyxAPIClient()
        with pytest.raises(APIConnectionError) as exc_info:
            await client.send_chat_message("test", "api_key")
        assert "not initialized" in str(exc_info.value)
class TestSendChatMessage:
    """Tests for send_chat_message functionality.

    Each test injects a fake session directly into ``client._session`` to
    bypass initialize(); ``session.post`` returns a MockAsyncContextManager
    because aiohttp's post() is used as an async context manager.
    """

    @pytest.mark.asyncio
    async def test_send_message_success(self) -> None:
        """Valid request returns ChatFullResponse."""
        client = OnyxAPIClient()
        response_data = {
            "answer": "Test response",
            "citations": [],
            "error_msg": None,
        }
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value=response_data)
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        # Patch model parsing so the test does not depend on the real
        # ChatFullResponse schema.
        with patch.object(
            ChatFullResponse,
            "model_validate",
            return_value=MagicMock(answer="Test response", error_msg=None),
        ):
            result = await client.send_chat_message("Hello", "api_key_123")
            assert result is not None

    @pytest.mark.asyncio
    async def test_send_message_with_persona(self) -> None:
        """persona_id is passed to API."""
        client = OnyxAPIClient()
        response_data = {"answer": "Response", "citations": [], "error_msg": None}
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value=response_data)
        mock_session = MagicMock()
        mock_post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        mock_session.post = mock_post
        client._session = mock_session
        with patch.object(
            ChatFullResponse,
            "model_validate",
            return_value=MagicMock(answer="Response", error_msg=None),
        ):
            await client.send_chat_message("Hello", "api_key", persona_id=5)
        # Verify persona was included in request
        # NOTE(review): this only checks that a json payload exists; it does
        # not assert persona_id=5 actually landed in it.
        call_args = mock_post.call_args
        json_data = call_args.kwargs.get("json") or call_args[1].get("json")
        assert json_data is not None

    @pytest.mark.asyncio
    async def test_send_message_401_error(self) -> None:
        """Invalid API key returns APIResponseError with 401."""
        client = OnyxAPIClient()
        mock_response = MagicMock()
        mock_response.status = 401
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        with pytest.raises(APIResponseError) as exc_info:
            await client.send_chat_message("Hello", "bad_key")
        assert exc_info.value.status_code == 401

    @pytest.mark.asyncio
    async def test_send_message_403_error(self) -> None:
        """Persona not accessible returns APIResponseError with 403."""
        client = OnyxAPIClient()
        mock_response = MagicMock()
        mock_response.status = 403
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        with pytest.raises(APIResponseError) as exc_info:
            await client.send_chat_message("Hello", "api_key", persona_id=999)
        assert exc_info.value.status_code == 403

    @pytest.mark.asyncio
    async def test_send_message_timeout(self) -> None:
        """Request timeout raises APITimeoutError."""
        client = OnyxAPIClient()
        mock_session = MagicMock()
        # Raise on context entry to simulate the request timing out.
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(
                enter_side_effect=TimeoutError("Timeout")
            )
        )
        client._session = mock_session
        with pytest.raises(APITimeoutError):
            await client.send_chat_message("Hello", "api_key")

    @pytest.mark.asyncio
    async def test_send_message_connection_error(self) -> None:
        """Network failure raises APIConnectionError."""
        client = OnyxAPIClient()
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(
                enter_side_effect=aiohttp.ClientConnectorError(
                    MagicMock(), OSError("Connection refused")
                )
            )
        )
        client._session = mock_session
        with pytest.raises(APIConnectionError):
            await client.send_chat_message("Hello", "api_key")

    @pytest.mark.asyncio
    async def test_send_message_server_error(self) -> None:
        """500 response raises APIResponseError with 500."""
        client = OnyxAPIClient()
        mock_response = MagicMock()
        mock_response.status = 500
        # text() is async on real aiohttp responses.
        mock_response.text = AsyncMock(return_value="Internal Server Error")
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        with pytest.raises(APIResponseError) as exc_info:
            await client.send_chat_message("Hello", "api_key")
        assert exc_info.value.status_code == 500
class TestHealthCheck:
    """Tests for health_check functionality.

    health_check() is expected to be non-raising: failures and timeouts are
    reported as a False return, never as an exception.
    """

    @pytest.mark.asyncio
    async def test_health_check_success(self) -> None:
        """Server healthy returns True."""
        client = OnyxAPIClient()
        mock_response = MagicMock()
        mock_response.status = 200
        mock_session = MagicMock()
        mock_session.get = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        # Inject fake session directly to bypass initialize().
        client._session = mock_session
        result = await client.health_check()
        assert result is True

    @pytest.mark.asyncio
    async def test_health_check_failure(self) -> None:
        """Server unhealthy returns False."""
        client = OnyxAPIClient()
        mock_response = MagicMock()
        mock_response.status = 503
        mock_session = MagicMock()
        mock_session.get = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        result = await client.health_check()
        assert result is False

    @pytest.mark.asyncio
    async def test_health_check_timeout(self) -> None:
        """Request times out returns False."""
        client = OnyxAPIClient()
        mock_session = MagicMock()
        mock_session.get = MagicMock(
            return_value=MockAsyncContextManager(
                enter_side_effect=TimeoutError("Timeout")
            )
        )
        client._session = mock_session
        result = await client.health_check()
        assert result is False

    @pytest.mark.asyncio
    async def test_health_check_not_initialized(self) -> None:
        """Health check before initialize returns False."""
        client = OnyxAPIClient()
        result = await client.health_check()
        assert result is False
class TestResponseParsing:
    """Tests for API response parsing."""

    @pytest.mark.asyncio
    async def test_response_malformed_json(self) -> None:
        """API returns invalid JSON raises exception."""
        client = OnyxAPIClient()
        mock_response = MagicMock()
        mock_response.status = 200
        # json() failing mimics a non-JSON body on an otherwise-OK response.
        mock_response.json = AsyncMock(side_effect=ValueError("Invalid JSON"))
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        with pytest.raises(ValueError):
            await client.send_chat_message("Hello", "api_key")

    @pytest.mark.asyncio
    async def test_response_with_error_msg(self) -> None:
        """200 status but error_msg present - warning logged, response returned."""
        client = OnyxAPIClient()
        response_data = {
            "answer": "Partial response",
            "citations": [],
            "error_msg": "Some warning",
        }
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value=response_data)
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        mock_result = MagicMock()
        mock_result.answer = "Partial response"
        mock_result.error_msg = "Some warning"
        with patch.object(ChatFullResponse, "model_validate", return_value=mock_result):
            result = await client.send_chat_message("Hello", "api_key")
        # Should still return response
        assert result is not None

    @pytest.mark.asyncio
    async def test_response_empty_answer(self) -> None:
        """answer field is empty string - handled gracefully."""
        client = OnyxAPIClient()
        response_data = {
            "answer": "",
            "citations": [],
            "error_msg": None,
        }
        mock_response = MagicMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value=response_data)
        mock_session = MagicMock()
        mock_session.post = MagicMock(
            return_value=MockAsyncContextManager(return_value=mock_response)
        )
        client._session = mock_session
        mock_result = MagicMock()
        mock_result.answer = ""
        mock_result.error_msg = None
        with patch.object(ChatFullResponse, "model_validate", return_value=mock_result):
            result = await client.send_chat_message("Hello", "api_key")
        # Should return response even with empty answer
        assert result is not None
class TestClientConfiguration:
    """Tests for client configuration."""

    def test_default_timeout(self) -> None:
        """Client uses API_REQUEST_TIMEOUT by default."""
        client = OnyxAPIClient()
        assert client._timeout == API_REQUEST_TIMEOUT

    def test_custom_timeout(self) -> None:
        """Client accepts custom timeout."""
        client = OnyxAPIClient(timeout=60)
        assert client._timeout == 60

    @pytest.mark.asyncio
    async def test_double_initialize_warning(self) -> None:
        """Calling initialize() twice logs warning but doesn't error."""
        client = OnyxAPIClient()
        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = MagicMock()
            mock_session_class.return_value = mock_session
            await client.initialize()
            # Second call should be safe
            await client.initialize()
            # Should only create one session
            assert mock_session_class.call_count == 1

View File

@@ -0,0 +1,520 @@
"""Unit tests for Discord bot cache manager.
Tests for DiscordCacheManager class functionality.
"""
import asyncio
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.onyxbot.discord.cache import DiscordCacheManager
class TestCacheInitialization:
    """Tests for cache initialization.

    refresh_all() is exercised with all of its collaborators patched at the
    onyx.onyxbot.discord.cache module level: tenant enumeration, EE gating,
    DB session factory, guild-config query, and API-key provisioning.
    """

    def test_cache_starts_empty(self) -> None:
        """New cache manager has empty caches."""
        cache = DiscordCacheManager()
        assert cache._guild_tenants == {}
        assert cache._api_keys == {}
        assert cache.is_initialized is False

    @pytest.mark.asyncio
    async def test_cache_refresh_all_loads_guilds(self) -> None:
        """refresh_all() loads all active guilds."""
        cache = DiscordCacheManager()
        mock_config1 = MagicMock()
        mock_config1.guild_id = 111111
        mock_config1.enabled = True
        mock_config2 = MagicMock()
        mock_config2.guild_id = 222222
        mock_config2.enabled = True
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                # No gated tenants.
                return_value=lambda: set(),
            ),
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config1, mock_config2],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
                return_value="test_api_key",
            ),
        ):
            # get_session_with_tenant is used as a sync context manager.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        assert cache.is_initialized is True
        assert 111111 in cache._guild_tenants
        assert 222222 in cache._guild_tenants
        assert cache._guild_tenants[111111] == "tenant1"
        assert cache._guild_tenants[222222] == "tenant1"

    @pytest.mark.asyncio
    async def test_cache_refresh_provisions_api_key(self) -> None:
        """Refresh for tenant without key creates API key."""
        cache = DiscordCacheManager()
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = True
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
                return_value="new_api_key",
            ) as mock_provision,
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        assert cache._api_keys.get("tenant1") == "new_api_key"
        mock_provision.assert_called()
class TestCacheLookups:
    """Tests for cache lookup operations.

    Internal dicts are seeded directly so lookups are tested in isolation
    from the refresh machinery.
    """

    def test_get_tenant_returns_correct(self) -> None:
        """Lookup registered guild returns correct tenant ID."""
        manager = DiscordCacheManager()
        manager._guild_tenants[123456] = "tenant1"
        assert manager.get_tenant(123456) == "tenant1"

    def test_get_tenant_returns_none_unknown(self) -> None:
        """Lookup unregistered guild returns None."""
        manager = DiscordCacheManager()
        assert manager.get_tenant(999999) is None

    def test_get_api_key_returns_correct(self) -> None:
        """Lookup tenant's API key returns valid key."""
        manager = DiscordCacheManager()
        manager._api_keys["tenant1"] = "api_key_123"
        assert manager.get_api_key("tenant1") == "api_key_123"

    def test_get_api_key_returns_none_unknown(self) -> None:
        """Lookup unknown tenant returns None."""
        manager = DiscordCacheManager()
        assert manager.get_api_key("unknown_tenant") is None

    def test_get_all_guild_ids(self) -> None:
        """After loading returns all cached guild IDs."""
        manager = DiscordCacheManager()
        manager._guild_tenants = {111: "t1", 222: "t2", 333: "t1"}
        assert set(manager.get_all_guild_ids()) == {111, 222, 333}
class TestCacheUpdates:
    """Tests for cache update operations (refresh_guild/remove_guild/clear)."""

    @pytest.mark.asyncio
    async def test_refresh_guild_adds_new(self) -> None:
        """refresh_guild() for new guild adds it to cache."""
        cache = DiscordCacheManager()
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = True
        with (
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
                return_value="api_key",
            ),
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_guild(111111, "tenant1")
        assert cache.get_tenant(111111) == "tenant1"

    @pytest.mark.asyncio
    async def test_refresh_guild_verifies_active(self) -> None:
        """refresh_guild() for disabled guild doesn't add it."""
        cache = DiscordCacheManager()
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = False  # Disabled!
        with (
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config],
            ),
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_guild(111111, "tenant1")
        # Should not be added because it's disabled
        assert cache.get_tenant(111111) is None

    def test_remove_guild(self) -> None:
        """remove_guild() removes guild from cache."""
        cache = DiscordCacheManager()
        cache._guild_tenants[111111] = "tenant1"
        cache.remove_guild(111111)
        assert cache.get_tenant(111111) is None

    def test_clear_removes_all(self) -> None:
        """clear() empties all caches."""
        cache = DiscordCacheManager()
        # Seed internal state directly, including the initialized flag.
        cache._guild_tenants = {111: "t1", 222: "t2"}
        cache._api_keys = {"t1": "key1", "t2": "key2"}
        cache._initialized = True
        cache.clear()
        assert cache._guild_tenants == {}
        assert cache._api_keys == {}
        assert cache.is_initialized is False
class TestThreadSafety:
    """Tests for thread/async safety.

    These are smoke tests: they assert "no exception / no corruption", not
    strict mutual exclusion.
    """

    @pytest.mark.asyncio
    async def test_concurrent_refresh_no_race(self) -> None:
        """Multiple concurrent refresh_all() calls don't corrupt data."""
        cache = DiscordCacheManager()
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = True
        call_count = 0

        # Stand-in for _load_tenant_data; presumably the real method returns
        # (guild_ids, api_key) for a tenant -- TODO confirm its contract.
        async def slow_refresh() -> tuple[list[int], str]:
            nonlocal call_count
            call_count += 1
            # Simulate slow operation
            await asyncio.sleep(0.01)
            return ([111111], "api_key")

        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch.object(cache, "_load_tenant_data", side_effect=slow_refresh),
        ):
            # Run multiple concurrent refreshes
            await asyncio.gather(
                cache.refresh_all(),
                cache.refresh_all(),
                cache.refresh_all(),
            )
        # Each refresh should complete without error
        assert cache.is_initialized is True

    @pytest.mark.asyncio
    async def test_concurrent_read_write(self) -> None:
        """Read during refresh doesn't cause exceptions."""
        cache = DiscordCacheManager()
        cache._guild_tenants[111111] = "tenant1"

        async def read_loop() -> None:
            for _ in range(10):
                cache.get_tenant(111111)
                await asyncio.sleep(0.001)

        async def write_loop() -> None:
            for i in range(10):
                cache._guild_tenants[200000 + i] = f"tenant{i}"
                await asyncio.sleep(0.001)

        # Should not raise any exceptions
        await asyncio.gather(read_loop(), write_loop())
class TestAPIKeyProvisioning:
    """Tests for API key provisioning via cache refresh."""

    @pytest.mark.asyncio
    async def test_api_key_created_on_first_refresh(self) -> None:
        """Cache refresh with no existing key creates new API key."""
        cache = DiscordCacheManager()
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = True
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
                return_value="new_api_key_123",
            ) as mock_create,
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        mock_create.assert_called_once()
        assert cache.get_api_key("tenant1") == "new_api_key_123"

    @pytest.mark.asyncio
    async def test_api_key_cached_after_creation(self) -> None:
        """Subsequent lookups after creation use cached key."""
        cache = DiscordCacheManager()
        # Pre-seed the key cache so refresh should skip provisioning.
        cache._api_keys["tenant1"] = "cached_key"
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = True
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
            ) as mock_create,
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        # Should NOT call create because key is already cached
        mock_create.assert_not_called()
        # Cached key should be preserved after refresh
        assert cache.get_api_key("tenant1") == "cached_key"
class TestGatedTenantHandling:
    """Tests for gated tenant filtering.

    A "gated" tenant (reported by the patched EE hook) must be skipped by
    refresh_all(): its guild configs are never queried and nothing for it
    lands in the cache.
    """

    @pytest.mark.asyncio
    async def test_refresh_skips_gated_tenants(self) -> None:
        """Gated tenant's guilds are not loaded."""
        cache = DiscordCacheManager()
        # tenant2 is gated
        gated_tenants = {"tenant2"}
        mock_config_t1 = MagicMock()
        mock_config_t1.guild_id = 111111
        mock_config_t1.enabled = True
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1", "tenant2"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: gated_tenants,
            ),
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                # Only ungated tenants should ever reach this query, so a
                # single return value for tenant1 is sufficient.
                return_value=[mock_config_t1],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
                return_value="api_key",
            ),
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        # Only tenant1 should be loaded (tenant2 is gated). Separate asserts
        # so a failure pinpoints exactly which condition broke.
        assert "tenant1" in cache._api_keys
        assert 111111 in cache._guild_tenants
        # tenant2's data should NOT be in cache
        assert "tenant2" not in cache._api_keys
        assert 222222 not in cache._guild_tenants

    @pytest.mark.asyncio
    async def test_gated_check_calls_ee_function(self) -> None:
        """Refresh all tenants calls fetch_ee_implementation_or_noop."""
        cache = DiscordCacheManager()
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ) as mock_ee,
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[],
            ),
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        # The gating lookup must be resolved exactly once per refresh.
        mock_ee.assert_called_once()

    @pytest.mark.asyncio
    async def test_ungated_tenant_included(self) -> None:
        """Regular (ungated) tenant has guilds loaded normally."""
        cache = DiscordCacheManager()
        mock_config = MagicMock()
        mock_config.guild_id = 111111
        mock_config.enabled = True
        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),  # No gated tenants
            ),
            patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
            patch(
                "onyx.onyxbot.discord.cache.get_guild_configs",
                return_value=[mock_config],
            ),
            patch(
                "onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
                return_value="api_key",
            ),
        ):
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            await cache.refresh_all()
        assert cache.get_tenant(111111) == "tenant1"
class TestCacheErrorHandling:
    """Tests for error handling in cache operations."""

    @pytest.mark.asyncio
    async def test_refresh_all_handles_tenant_error(self) -> None:
        """Error loading one tenant doesn't stop others."""
        cache = DiscordCacheManager()
        call_count = 0

        # Replacement for _load_tenant_data: first tenant raises, second
        # succeeds -- refresh_all() must isolate the failure and continue.
        async def mock_load(tenant_id: str) -> tuple[list[int], str]:
            nonlocal call_count
            call_count += 1
            if tenant_id == "tenant1":
                raise Exception("Tenant 1 error")
            return ([222222], "api_key")

        with (
            patch(
                "onyx.onyxbot.discord.cache.get_all_tenant_ids",
                return_value=["tenant1", "tenant2"],
            ),
            patch(
                "onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
                return_value=lambda: set(),
            ),
            patch.object(cache, "_load_tenant_data", side_effect=mock_load),
        ):
            await cache.refresh_all()
        # Should still complete and load tenant2
        assert call_count == 2  # Both tenants attempted
        assert cache.get_tenant(222222) == "tenant2"

View File

@@ -0,0 +1,645 @@
"""Unit tests for Discord bot context builders.
Tests the thread and reply context building logic with mocked Discord API.
"""
from typing import Any
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
import discord
import pytest
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
from onyx.onyxbot.discord.handle_message import _build_conversation_context
from onyx.onyxbot.discord.handle_message import _build_reply_chain_context
from onyx.onyxbot.discord.handle_message import _build_thread_context
from onyx.onyxbot.discord.handle_message import _format_messages_as_context
from onyx.onyxbot.discord.handle_message import format_message_content
from tests.unit.onyx.onyxbot.discord.conftest import AsyncIteratorMock
from tests.unit.onyx.onyxbot.discord.conftest import mock_message
class TestThreadContextBuilder:
"""Tests for _build_thread_context function."""
@pytest.mark.asyncio
async def test_build_thread_context_basic(
    self, mock_thread_with_messages: MagicMock, mock_bot_user: MagicMock
) -> None:
    """Thread with messages returns context in order."""
    msg = MagicMock(spec=discord.Message)
    msg.id = 999  # Current message ID
    msg.channel = mock_thread_with_messages
    result = await _build_thread_context(msg, mock_bot_user)
    assert result is not None
    # Header text expected in the rendered context block.
    assert "Conversation history" in result
    # Should contain message content
    assert "User msg" in result or "Bot response" in result
@pytest.mark.asyncio
async def test_build_thread_context_max_limit(
    self, mock_bot_user: MagicMock
) -> None:
    """Thread with 20 messages returns only MAX_CONTEXT_MESSAGES."""
    # Create 20 messages
    messages = [
        mock_message(content=f"Message {i}", message_id=i) for i in range(20)
    ]
    thread = MagicMock(spec=discord.Thread)
    thread.id = 666666
    thread.parent = MagicMock(spec=discord.TextChannel)

    # Honor the limit= kwarg the implementation passes to thread.history().
    def history(**kwargs: Any) -> AsyncIteratorMock:
        limit = kwargs.get("limit", MAX_CONTEXT_MESSAGES)
        return AsyncIteratorMock(messages[:limit])

    thread.history = history
    # No starter message available for this thread.
    thread.parent.fetch_message = AsyncMock(
        side_effect=discord.NotFound(MagicMock(), "")
    )
    msg = MagicMock(spec=discord.Message)
    msg.id = 999
    msg.channel = thread
    result = await _build_thread_context(msg, mock_bot_user)
    assert result is not None
    # Should only have MAX_CONTEXT_MESSAGES worth of content
    # NOTE(review): no assertion actually enforces the cap; consider counting
    # rendered messages in `result` to make this test meaningful.
@pytest.mark.asyncio
async def test_build_thread_context_includes_starter(
    self, mock_bot_user: MagicMock
) -> None:
    """Thread with starter message includes it at beginning."""
    # Starter id matches the thread id (Discord convention for thread
    # starter messages).
    starter = mock_message(
        content="This is the thread starter",
        message_id=666666,
    )
    thread = MagicMock(spec=discord.Thread)
    thread.id = 666666
    thread.parent = MagicMock(spec=discord.TextChannel)
    thread.parent.fetch_message = AsyncMock(return_value=starter)
    messages = [
        mock_message(content="Reply 1", message_id=1),
        mock_message(content="Reply 2", message_id=2),
    ]

    def history(**kwargs: Any) -> AsyncIteratorMock:
        return AsyncIteratorMock(messages)

    thread.history = history
    msg = MagicMock(spec=discord.Message)
    msg.id = 999
    msg.channel = thread
    result = await _build_thread_context(msg, mock_bot_user)
    assert result is not None
    assert "thread starter" in result
@pytest.mark.asyncio
async def test_build_thread_context_filters_system_messages(
    self, mock_bot_user: MagicMock
) -> None:
    """Thread with system messages only includes content messages."""
    messages = [
        mock_message(
            content="Normal message", message_type=discord.MessageType.default
        ),
        mock_message(
            content="", message_type=discord.MessageType.pins_add
        ),  # System
        mock_message(
            content="Another normal", message_type=discord.MessageType.reply
        ),
    ]
    thread = MagicMock(spec=discord.Thread)
    thread.id = 666666
    thread.parent = MagicMock(spec=discord.TextChannel)
    thread.parent.fetch_message = AsyncMock(
        side_effect=discord.NotFound(MagicMock(), "")
    )

    def history(**kwargs: Any) -> AsyncIteratorMock:
        return AsyncIteratorMock(messages)

    thread.history = history
    msg = MagicMock(spec=discord.Message)
    msg.id = 999
    msg.channel = thread
    result = await _build_thread_context(msg, mock_bot_user)
    # Should not include system message type
    # NOTE(review): only non-None is asserted; the filtering claim in the
    # docstring is not actually verified -- consider asserting that the
    # pins_add entry is absent from `result`.
    assert result is not None
@pytest.mark.asyncio
async def test_build_thread_context_includes_bot_messages(
    self, mock_bot_user: MagicMock
) -> None:
    """Bot messages in thread are included for context."""
    messages = [
        mock_message(content="User question", author_bot=False),
        # Authored by the bot itself; must still appear in the context.
        mock_message(
            content="Bot response",
            author_bot=True,
            author_id=mock_bot_user.id,
            author_display_name="OnyxBot",
        ),
    ]
    thread = MagicMock(spec=discord.Thread)
    thread.id = 666666
    thread.parent = MagicMock(spec=discord.TextChannel)
    thread.parent.fetch_message = AsyncMock(
        side_effect=discord.NotFound(MagicMock(), "")
    )

    def history(**kwargs: Any) -> AsyncIteratorMock:
        return AsyncIteratorMock(messages)

    thread.history = history
    msg = MagicMock(spec=discord.Message)
    msg.id = 999
    msg.channel = thread
    result = await _build_thread_context(msg, mock_bot_user)
    assert result is not None
    assert "Bot response" in result
@pytest.mark.asyncio
async def test_build_thread_context_empty_thread(
    self, mock_bot_user: MagicMock
) -> None:
    """Thread with only system messages returns None."""
    messages = [
        mock_message(content="", message_type=discord.MessageType.pins_add),
    ]
    thread = MagicMock(spec=discord.Thread)
    thread.id = 666666
    thread.parent = MagicMock(spec=discord.TextChannel)
    thread.parent.fetch_message = AsyncMock(
        side_effect=discord.NotFound(MagicMock(), "")
    )

    def history(**kwargs: Any) -> AsyncIteratorMock:
        return AsyncIteratorMock(messages)

    thread.history = history
    msg = MagicMock(spec=discord.Message)
    msg.id = 999
    msg.channel = thread
    await _build_thread_context(msg, mock_bot_user)
    # Should return None for empty context
    # (depends on implementation - may return None or empty string)
    # NOTE(review): this test has no assertion at all; it only verifies the
    # call does not raise. Pin down the contract and assert it.
@pytest.mark.asyncio
async def test_build_thread_context_forum_channel(
    self, mock_bot_user: MagicMock
) -> None:
    """Forum-channel parents must not have their starter message fetched."""
    forum_thread = MagicMock(spec=discord.Thread)
    forum_thread.id = 666666
    forum_thread.parent = MagicMock(spec=discord.ForumChannel)
    # Attach the mock up front so we can later prove it was never awaited.
    forum_thread.parent.fetch_message = AsyncMock()
    forum_thread.history = lambda **kwargs: AsyncIteratorMock(
        [mock_message(content="Forum reply", message_id=1)]
    )

    incoming = MagicMock(spec=discord.Message)
    incoming.id = 999
    incoming.channel = forum_thread

    await _build_thread_context(incoming, mock_bot_user)

    # Forum threads have no separate starter message to pull in.
    forum_thread.parent.fetch_message.assert_not_called()
@pytest.mark.asyncio
async def test_build_thread_context_starter_fetch_fails(
    self, mock_bot_user: MagicMock
) -> None:
    """A NotFound while fetching the starter message degrades gracefully."""
    replies = [mock_message(content="Reply message", message_id=1)]

    broken_thread = MagicMock(spec=discord.Thread)
    broken_thread.id = 666666
    broken_thread.parent = MagicMock(spec=discord.TextChannel)
    broken_thread.parent.fetch_message = AsyncMock(
        side_effect=discord.NotFound(MagicMock(), "Not found")
    )
    broken_thread.history = lambda **kwargs: AsyncIteratorMock(replies)

    incoming = MagicMock(spec=discord.Message)
    incoming.id = 999
    incoming.channel = broken_thread

    context = await _build_thread_context(incoming, mock_bot_user)

    # History alone is enough; the missing starter must not abort the build.
    assert context is not None
@pytest.mark.asyncio
async def test_build_thread_context_deduplicates_starter(
    self, mock_bot_user: MagicMock
) -> None:
    """A starter that also shows up in recent history is not repeated."""
    starter = mock_message(content="Thread starter", message_id=666666)
    history_messages = [
        starter,  # starter is also present in the fetched history
        mock_message(content="Reply", message_id=1),
    ]

    dedup_thread = MagicMock(spec=discord.Thread)
    dedup_thread.id = 666666
    dedup_thread.parent = MagicMock(spec=discord.TextChannel)
    dedup_thread.parent.fetch_message = AsyncMock(return_value=starter)
    dedup_thread.history = lambda **kwargs: AsyncIteratorMock(history_messages)

    incoming = MagicMock(spec=discord.Message)
    incoming.id = 999
    incoming.channel = dedup_thread

    context = await _build_thread_context(incoming, mock_bot_user)

    if context:
        # NOTE(review): the bound is <= 2, which tolerates the starter text
        # appearing once in a header and once in the body. If deduplication
        # is meant to be strict, tighten this to == 1.
        assert context.count("Thread starter") <= 2
class TestReplyChainContextBuilder:
    """Unit tests for _build_reply_chain_context."""

    @pytest.mark.asyncio
    async def test_build_reply_chain_single_reply(
        self, mock_bot_user: MagicMock
    ) -> None:
        """A direct reply produces a one-message chain containing the parent."""
        root = mock_message(content="Parent message", message_id=100)
        root.reference = None

        reply = MagicMock(spec=discord.Message)
        reply.id = 200
        reply.reference = MagicMock()
        reply.reference.message_id = 100
        reply.channel = MagicMock()
        reply.channel.fetch_message = AsyncMock(return_value=root)
        reply.channel.name = "general"

        context = await _build_reply_chain_context(reply, mock_bot_user)

        assert context is not None
        assert "Parent message" in context

    @pytest.mark.asyncio
    async def test_build_reply_chain_deep_chain(self, mock_bot_user: MagicMock) -> None:
        """A -> B -> C -> D reply chain is walked back to the root."""
        tail = mock_message(content="Message D", message_id=4)
        tail.reference = None

        middle = mock_message(content="Message C", message_id=3)
        middle.reference = MagicMock()
        middle.reference.message_id = 4

        near = mock_message(content="Message B", message_id=2)
        near.reference = MagicMock()
        near.reference.message_id = 3

        # The incoming message replies to B.
        head = MagicMock(spec=discord.Message)
        head.id = 1
        head.reference = MagicMock()
        head.reference.message_id = 2
        head.channel = MagicMock()
        head.channel.name = "general"

        by_id = {2: near, 3: middle, 4: tail}

        async def lookup(message_id: int) -> MagicMock:
            if message_id in by_id:
                return by_id[message_id]
            raise discord.NotFound(MagicMock(), "Not found")

        head.channel.fetch_message = AsyncMock(side_effect=lookup)

        context = await _build_reply_chain_context(head, mock_bot_user)

        # The walk should succeed even though it crosses several hops.
        assert context is not None

    @pytest.mark.asyncio
    async def test_build_reply_chain_max_depth(self, mock_bot_user: MagicMock) -> None:
        """Chains longer than MAX_CONTEXT_MESSAGES stop at the cap."""
        chain_len = MAX_CONTEXT_MESSAGES + 5
        chain: dict[int, MagicMock] = {}
        for idx in range(chain_len, 0, -1):
            node = mock_message(content=f"Message {idx}", message_id=idx)
            if idx == chain_len:
                node.reference = None  # root of the chain
            else:
                node.reference = MagicMock()
                node.reference.message_id = idx + 1
            chain[idx] = node

        # Start walking from a message that replies to node 1.
        entry = MagicMock(spec=discord.Message)
        entry.id = 0
        entry.reference = MagicMock()
        entry.reference.message_id = 1
        entry.channel = MagicMock()
        entry.channel.name = "general"

        async def lookup(message_id: int) -> MagicMock:
            if message_id not in chain:
                raise discord.NotFound(MagicMock(), "Not found")
            return chain[message_id]

        entry.channel.fetch_message = AsyncMock(side_effect=lookup)

        context = await _build_reply_chain_context(entry, mock_bot_user)

        # The builder should truncate rather than fail on an over-long chain.
        assert context is not None

    @pytest.mark.asyncio
    async def test_build_reply_chain_no_reply(self, mock_bot_user: MagicMock) -> None:
        """A message with no reference yields no reply-chain context."""
        standalone = MagicMock(spec=discord.Message)
        standalone.reference = None
        assert await _build_reply_chain_context(standalone, mock_bot_user) is None

    @pytest.mark.asyncio
    async def test_build_reply_chain_deleted_message(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Replying to a deleted message must not raise."""
        orphan = MagicMock(spec=discord.Message)
        orphan.id = 200
        orphan.reference = MagicMock()
        orphan.reference.message_id = 100
        orphan.channel = MagicMock()
        orphan.channel.name = "general"
        orphan.channel.fetch_message = AsyncMock(
            side_effect=discord.NotFound(MagicMock(), "Not found")
        )

        # Either None or a partial context is acceptable; only an uncaught
        # exception would be a failure.
        await _build_reply_chain_context(orphan, mock_bot_user)

    @pytest.mark.asyncio
    async def test_build_reply_chain_missing_reference_data(
        self, mock_bot_user: MagicMock
    ) -> None:
        """A reference whose message_id is None yields no context."""
        dangling = MagicMock(spec=discord.Message)
        dangling.reference = MagicMock()
        dangling.reference.message_id = None
        assert await _build_reply_chain_context(dangling, mock_bot_user) is None

    @pytest.mark.asyncio
    async def test_build_reply_chain_http_exception(
        self, mock_bot_user: MagicMock
    ) -> None:
        """An HTTPException during fetch stops the walk without raising."""
        flaky = MagicMock(spec=discord.Message)
        flaky.id = 200
        flaky.reference = MagicMock()
        flaky.reference.message_id = 100
        flaky.channel = MagicMock()
        flaky.channel.name = "general"
        flaky.channel.fetch_message = AsyncMock(
            side_effect=discord.HTTPException(MagicMock(), "HTTP error")
        )

        # Only an uncaught exception would fail this test.
        await _build_reply_chain_context(flaky, mock_bot_user)
class TestCombinedContext:
    """Tests covering combined thread + reply context routing."""

    @pytest.mark.asyncio
    async def test_combined_context_thread_with_reply(
        self, mock_bot_user: MagicMock
    ) -> None:
        """A reply posted inside a thread still produces thread context."""
        in_thread = MagicMock(spec=discord.Thread)
        in_thread.id = 666666
        in_thread.parent = MagicMock(spec=discord.TextChannel)
        in_thread.parent.fetch_message = AsyncMock(
            side_effect=discord.NotFound(MagicMock(), "")
        )
        in_thread.history = lambda **kwargs: AsyncIteratorMock(
            [
                mock_message(content="Thread msg 1", message_id=1),
                mock_message(content="Thread msg 2", message_id=2),
            ]
        )

        # The incoming message replies to another message inside the thread.
        replied_to = mock_message(content="Parent message", message_id=2)
        replied_to.reference = None

        incoming = MagicMock(spec=discord.Message)
        incoming.id = 999
        incoming.channel = in_thread
        incoming.reference = MagicMock()
        incoming.reference.message_id = 2
        incoming.channel.fetch_message = AsyncMock(return_value=replied_to)
        incoming.channel.name = "test-thread"

        context = await _build_conversation_context(incoming, mock_bot_user)

        assert context is not None
        assert "Conversation history" in context

    @pytest.mark.asyncio
    async def test_build_conversation_context_routes_to_thread(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Non-reply messages inside a thread use the thread builder."""
        a_thread = MagicMock(spec=discord.Thread)
        a_thread.id = 666666
        a_thread.parent = MagicMock(spec=discord.TextChannel)
        a_thread.parent.fetch_message = AsyncMock(
            side_effect=discord.NotFound(MagicMock(), "")
        )
        a_thread.history = lambda **kwargs: AsyncIteratorMock(
            [mock_message(content="Thread msg")]
        )

        incoming = MagicMock(spec=discord.Message)
        incoming.id = 999
        incoming.channel = a_thread
        incoming.reference = None

        assert await _build_conversation_context(incoming, mock_bot_user) is not None

    @pytest.mark.asyncio
    async def test_build_conversation_context_routes_to_reply(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Replies outside a thread use the reply-chain builder."""
        origin = mock_message(content="Parent", message_id=100)
        origin.reference = None

        incoming = MagicMock(spec=discord.Message)
        incoming.id = 200
        incoming.channel = MagicMock(spec=discord.TextChannel)  # plain channel
        incoming.reference = MagicMock()
        incoming.reference.message_id = 100
        incoming.channel.fetch_message = AsyncMock(return_value=origin)
        incoming.channel.name = "general"

        assert await _build_conversation_context(incoming, mock_bot_user) is not None
class TestContextFormatting:
    """Tests for mention expansion and context formatting helpers."""

    def test_format_message_content_mentions(self) -> None:
        """User mentions like <@123> are rewritten to @display_name."""
        mentioned = MagicMock()
        mentioned.id = 123456789
        mentioned.display_name = "TestUser"

        raw = MagicMock(spec=discord.Message)
        raw.content = "Hello <@123456789> how are you?"
        raw.mentions = [mentioned]
        raw.role_mentions = []
        raw.channel_mentions = []

        rendered = format_message_content(raw)
        assert "@TestUser" in rendered
        assert "<@123456789>" not in rendered

    def test_format_message_content_roles(self) -> None:
        """Role mentions like <@&456> are rewritten to @role_name."""
        role = MagicMock()
        role.id = 456789
        role.name = "Moderators"

        raw = MagicMock(spec=discord.Message)
        raw.content = "Attention <@&456789> members"
        raw.mentions = []
        raw.role_mentions = [role]
        raw.channel_mentions = []

        rendered = format_message_content(raw)
        assert "@Moderators" in rendered
        assert "<@&456789>" not in rendered

    def test_format_message_content_channels(self) -> None:
        """Channel mentions like <#789> are rewritten to #channel_name."""
        channel = MagicMock()
        channel.id = 789012
        channel.name = "announcements"

        raw = MagicMock(spec=discord.Message)
        raw.content = "Check out <#789012>"
        raw.mentions = []
        raw.role_mentions = []
        raw.channel_mentions = [channel]

        rendered = format_message_content(raw)
        assert "#announcements" in rendered
        assert "<#789012>" not in rendered

    def test_context_format_output(self, mock_bot_user: MagicMock) -> None:
        """Formatted context carries the expected header and separator."""
        entries: list[Any] = [mock_message(content="Hello bot", author_bot=False)]
        entries[0].type = discord.MessageType.default

        rendered = _format_messages_as_context(entries, mock_bot_user)
        assert rendered is not None
        assert "Conversation history" in rendered
        assert "---" in rendered

    def test_context_format_with_username(self, mock_bot_user: MagicMock) -> None:
        """Human-authored lines carry an @username: prefix."""
        entry = mock_message(content="User message", author_bot=False)
        entry.author.display_name = "TestUser"
        entry.type = discord.MessageType.default

        rendered = _format_messages_as_context([entry], mock_bot_user)
        assert rendered is not None
        assert "@TestUser:" in rendered

    def test_context_format_bot_marker(self, mock_bot_user: MagicMock) -> None:
        """Bot-authored lines are labelled with OnyxBot:."""
        entry = mock_message(
            content="Bot response",
            author_bot=True,
            author_id=mock_bot_user.id,
        )
        entry.type = discord.MessageType.default

        rendered = _format_messages_as_context([entry], mock_bot_user)
        assert rendered is not None
        assert "OnyxBot:" in rendered

View File

@@ -0,0 +1,157 @@
"""Unit tests for Discord bot utilities.
Tests for:
- Token management (get_bot_token)
- Registration key parsing (parse_discord_registration_key, generate_discord_registration_key)
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.onyxbot.discord.utils import get_bot_token
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
from onyx.server.manage.discord_bot.utils import REGISTRATION_KEY_PREFIX
class TestGetBotToken:
    """Tests for the env-then-DB token resolution in get_bot_token."""

    def test_get_token_from_env(self) -> None:
        """An env-provided token is returned directly."""
        with patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", "env_token_123"):
            assert get_bot_token() == "env_token_123"

    def test_get_token_from_db(self) -> None:
        """Without an env token, the DB-stored token is used."""
        db_config = MagicMock()
        db_config.bot_token = "db_token_456"
        with (
            patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", None),
            patch("onyx.onyxbot.discord.utils.AUTH_TYPE", "basic"),  # non-cloud path
            patch("onyx.onyxbot.discord.utils.get_session_with_tenant") as session,
            patch(
                "onyx.onyxbot.discord.utils.get_discord_bot_config",
                return_value=db_config,
            ),
        ):
            session.return_value.__enter__ = MagicMock()
            session.return_value.__exit__ = MagicMock()
            assert get_bot_token() == "db_token_456"

    def test_get_token_none(self) -> None:
        """No env token and no DB config resolves to None."""
        with (
            patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", None),
            patch("onyx.onyxbot.discord.utils.AUTH_TYPE", "basic"),  # non-cloud path
            patch("onyx.onyxbot.discord.utils.get_session_with_tenant") as session,
            patch(
                "onyx.onyxbot.discord.utils.get_discord_bot_config",
                return_value=None,
            ),
        ):
            session.return_value.__enter__ = MagicMock()
            session.return_value.__exit__ = MagicMock()
            assert get_bot_token() is None

    def test_get_token_env_priority(self) -> None:
        """The env token wins over a configured DB token."""
        db_config = MagicMock()
        db_config.bot_token = "db_token_456"
        with (
            patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", "env_token_123"),
            patch(
                "onyx.onyxbot.discord.utils.get_discord_bot_config",
                return_value=db_config,
            ),
        ):
            assert get_bot_token() == "env_token_123"
class TestParseRegistrationKey:
    """Tests for parse_discord_registration_key."""

    def test_parse_registration_key_valid(self) -> None:
        """A well-formed key yields its tenant_id."""
        assert (
            parse_discord_registration_key("discord_tenant123.randomtoken")
            == "tenant123"
        )

    def test_parse_registration_key_invalid(self) -> None:
        """A malformed key yields None."""
        assert parse_discord_registration_key("malformed_key") is None

    def test_parse_registration_key_missing_prefix(self) -> None:
        """Keys lacking the discord_ prefix yield None."""
        assert parse_discord_registration_key("tenant123.randomtoken") is None

    def test_parse_registration_key_missing_dot(self) -> None:
        """Keys lacking the '.' separator yield None."""
        assert parse_discord_registration_key("discord_tenant123randomtoken") is None

    def test_parse_registration_key_empty_token(self) -> None:
        """A key whose token part (after the dot) is empty."""
        outcome = parse_discord_registration_key("discord_tenant123.")
        # The current implementation tolerates an empty token and still
        # returns the tenant; if that should be invalid, update the
        # implementation and tighten this to `is None`.
        assert outcome == "tenant123" or outcome is None

    def test_parse_registration_key_url_encoded_tenant(self) -> None:
        """URL-encoded tenant ids are decoded ("my%20tenant" -> "my tenant")."""
        assert (
            parse_discord_registration_key("discord_my%20tenant.randomtoken")
            == "my tenant"
        )

    def test_parse_registration_key_special_chars(self) -> None:
        """Encoded slashes in the tenant id round-trip correctly."""
        assert (
            parse_discord_registration_key(
                "discord_tenant%2Fwith%2Fslashes.randomtoken"
            )
            == "tenant/with/slashes"
        )
class TestGenerateRegistrationKey:
    """Tests for generate_discord_registration_key."""

    def test_generate_registration_key(self) -> None:
        """Generated keys have prefix, tenant, and separator — and round-trip."""
        key = generate_discord_registration_key("tenant123")
        assert key.startswith(REGISTRATION_KEY_PREFIX)
        assert "tenant123" in key
        assert "." in key
        # Round-trip through the parser to verify the format end to end.
        assert parse_discord_registration_key(key) == "tenant123"

    def test_generate_registration_key_unique(self) -> None:
        """Repeated generation for the same tenant never repeats a key."""
        sample = {generate_discord_registration_key("tenant123") for _ in range(10)}
        assert len(sample) == 10  # all distinct

    def test_generate_registration_key_special_tenant(self) -> None:
        """Special characters in the tenant id are URL-encoded and round-trip."""
        key = generate_discord_registration_key("my tenant/id")
        assert "%20" in key or "%2F" in key  # space / slash got encoded
        assert parse_discord_registration_key(key) == "my tenant/id"

View File

@@ -0,0 +1,316 @@
"""Unit tests for Discord bot message utilities.
Tests for:
- Message splitting (_split_message)
- Citation formatting (_append_citations)
"""
from unittest.mock import MagicMock
from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
from onyx.onyxbot.discord.handle_message import _append_citations
from onyx.onyxbot.discord.handle_message import _split_message
class TestSplitMessage:
    """Tests for the _split_message chunking helper."""

    def test_split_message_under_limit(self) -> None:
        """Content shorter than the limit stays as one chunk."""
        text = "x" * 1999
        assert _split_message(text) == [text]

    def test_split_message_at_limit(self) -> None:
        """Content exactly at the limit stays as one chunk."""
        text = "x" * MAX_MESSAGE_LENGTH
        assert _split_message(text) == [text]

    def test_split_message_over_limit(self) -> None:
        """Content one char over the limit becomes two bounded chunks."""
        pieces = _split_message("x" * 2001)
        assert len(pieces) == 2
        assert all(len(piece) <= MAX_MESSAGE_LENGTH for piece in pieces)

    def test_split_at_double_newline(self) -> None:
        """A paragraph break (double newline) is the preferred split point."""
        head, tail = "x" * 1500, "y" * 1000
        pieces = _split_message(f"{head}\n\n{tail}")
        assert len(pieces) == 2
        # First chunk ends at (or just after) the paragraph break.
        assert pieces[0].endswith("\n\n") or head in pieces[0]

    def test_split_at_single_newline(self) -> None:
        """Falls back to a single newline when no paragraph break exists."""
        pieces = _split_message("x" * 1500 + "\n" + "y" * 1000)
        assert len(pieces) == 2

    def test_split_at_period_space(self) -> None:
        """Falls back to sentence boundaries ('. ') when no newlines exist."""
        pieces = _split_message("x" * 1500 + ". " + "y" * 1000)
        assert len(pieces) == 2
        # The sentence-ending period stays with the first chunk.
        assert pieces[0].endswith(". ") or pieces[0].endswith(".")

    def test_split_at_space(self) -> None:
        """Falls back to a plain space as the last readable breakpoint."""
        pieces = _split_message("x" * 1500 + " " + "y" * 1000)
        assert len(pieces) == 2

    def test_split_no_breakpoint(self) -> None:
        """With no breakpoints at all, a hard split preserves every char."""
        text = "x" * 2001  # no spaces or newlines anywhere
        pieces = _split_message(text)
        assert len(pieces) == 2
        assert "".join(pieces) == text

    def test_split_threshold_50_percent(self) -> None:
        """Breakpoints before 50% of the limit are passed over for later ones."""
        # Breakpoints at ~40% (800 chars) and ~80% (1600 chars) of the limit.
        text = "x" * 800 + "\n\n" + "m" * 800 + "\n\n" + "y" * 600
        pieces = _split_message(text)
        assert len(pieces) == 2
        assert len(pieces[0]) > 800  # the 40% breakpoint was skipped

    def test_split_multiple_chunks(self) -> None:
        """A 5000-char message yields three bounded chunks."""
        pieces = _split_message("x" * 5000)
        assert len(pieces) == 3
        assert all(len(piece) <= MAX_MESSAGE_LENGTH for piece in pieces)

    def test_split_preserves_content(self) -> None:
        """Re-joining the chunks reproduces the original text exactly."""
        text = "Hello world! " * 200  # roughly 2600 chars
        assert "".join(_split_message(text)) == text

    def test_split_with_unicode(self) -> None:
        """Multi-byte characters survive splitting without corruption."""
        text = "Hello " + "🎉" * 500 + " World " + "x" * 1500
        # Rejoining must reproduce the input, i.e. no emoji was torn apart.
        assert "".join(_split_message(text)) == text
class TestAppendCitations:
    """Tests for the _append_citations footer builder."""

    def _make_response(
        self,
        answer: str,
        citations: list[dict] | None = None,
        documents: list[dict] | None = None,
    ) -> ChatFullResponse:
        """Build a mocked ChatFullResponse carrying the given citations/docs."""
        response = MagicMock(spec=ChatFullResponse)
        response.answer = answer

        if citations:
            response.citation_info = []
            for entry in citations:
                info = MagicMock()
                info.citation_number = entry.get("num", 1)
                info.document_id = entry.get("doc_id", "doc1")
                response.citation_info.append(info)
        else:
            response.citation_info = None

        if documents:
            response.top_documents = []
            for entry in documents:
                doc = MagicMock()
                doc.document_id = entry.get("doc_id", "doc1")
                doc.semantic_identifier = entry.get("name", "Source")
                doc.link = entry.get("link")
                response.top_documents.append(doc)
        else:
            response.top_documents = None

        return response

    def test_format_citations_empty_list(self) -> None:
        """With no citations, the answer passes through untouched."""
        payload = self._make_response("Test answer")
        output = _append_citations("Test answer", payload)
        assert output == "Test answer"
        assert "Sources:" not in output

    def test_format_citations_single(self) -> None:
        """A single citation renders as a markdown source link."""
        payload = self._make_response(
            "Test answer",
            citations=[{"num": 1, "doc_id": "doc1"}],
            documents=[
                {
                    "doc_id": "doc1",
                    "name": "Document One",
                    "link": "https://example.com",
                }
            ],
        )
        output = _append_citations("Test answer", payload)
        assert "**Sources:**" in output
        assert "[Document One](<https://example.com>)" in output

    def test_format_citations_multiple(self) -> None:
        """Several citations are each numbered and rendered."""
        payload = self._make_response(
            "Test answer",
            citations=[
                {"num": 1, "doc_id": "doc1"},
                {"num": 2, "doc_id": "doc2"},
                {"num": 3, "doc_id": "doc3"},
            ],
            documents=[
                {"doc_id": "doc1", "name": "Doc 1", "link": "https://example.com/1"},
                {"doc_id": "doc2", "name": "Doc 2", "link": "https://example.com/2"},
                {"doc_id": "doc3", "name": "Doc 3", "link": "https://example.com/3"},
            ],
        )
        output = _append_citations("Test answer", payload)
        assert "1. [Doc 1]" in output
        assert "2. [Doc 2]" in output
        assert "3. [Doc 3]" in output

    def test_format_citations_max_five(self) -> None:
        """The footer is capped at the first five citations."""
        many_citations = [{"num": n, "doc_id": f"doc{n}"} for n in range(1, 11)]
        many_documents = [
            {
                "doc_id": f"doc{n}",
                "name": f"Doc {n}",
                "link": f"https://example.com/{n}",
            }
            for n in range(1, 11)
        ]
        payload = self._make_response(
            "Test answer", citations=many_citations, documents=many_documents
        )
        output = _append_citations("Test answer", payload)
        assert "1. [Doc 1]" in output
        assert "5. [Doc 5]" in output
        assert "6. [Doc 6]" not in output  # sixth and later are dropped

    def test_format_citation_no_link(self) -> None:
        """A linkless citation renders as plain text, not a markdown link."""
        payload = self._make_response(
            "Test answer",
            citations=[{"num": 1, "doc_id": "doc1"}],
            documents=[{"doc_id": "doc1", "name": "No Link Doc", "link": None}],
        )
        output = _append_citations("Test answer", payload)
        assert "1. No Link Doc" in output
        assert "[No Link Doc](<" not in output

    def test_format_citation_empty_name(self) -> None:
        """An empty semantic_identifier falls back to 'Source'."""
        payload = self._make_response(
            "Test answer",
            citations=[{"num": 1, "doc_id": "doc1"}],
            documents=[{"doc_id": "doc1", "name": "", "link": "https://example.com"}],
        )
        output = _append_citations("Test answer", payload)
        assert "[Source]" in output or "Source" in output

    def test_format_citation_link_with_brackets(self) -> None:
        """Links are wrapped in <...> so Discord does not embed them."""
        payload = self._make_response(
            "Test answer",
            citations=[{"num": 1, "doc_id": "doc1"}],
            documents=[
                {
                    "doc_id": "doc1",
                    "name": "Special Doc",
                    "link": "https://example.com/path?query=value&other=123",
                }
            ],
        )
        output = _append_citations("Test answer", payload)
        assert "(<https://example.com" in output

    def test_format_citations_sorted_by_number(self) -> None:
        """Citations appear in ascending citation-number order."""
        # Supply them deliberately out of order.
        payload = self._make_response(
            "Test answer",
            citations=[
                {"num": 3, "doc_id": "doc3"},
                {"num": 1, "doc_id": "doc1"},
                {"num": 2, "doc_id": "doc2"},
            ],
            documents=[
                {"doc_id": "doc1", "name": "Doc 1", "link": "https://example.com/1"},
                {"doc_id": "doc2", "name": "Doc 2", "link": "https://example.com/2"},
                {"doc_id": "doc3", "name": "Doc 3", "link": "https://example.com/3"},
            ],
        )
        output = _append_citations("Test answer", payload)
        first = output.find("1. [Doc 1]")
        second = output.find("2. [Doc 2]")
        third = output.find("3. [Doc 3]")
        assert first < second < third

    def test_format_citations_with_missing_document(self) -> None:
        """A citation whose document is absent is silently skipped."""
        payload = self._make_response(
            "Test answer",
            citations=[
                {"num": 1, "doc_id": "doc1"},
                {"num": 2, "doc_id": "doc_missing"},  # no matching document
            ],
            documents=[
                {"doc_id": "doc1", "name": "Doc 1", "link": "https://example.com/1"},
            ],
        )
        output = _append_citations("Test answer", payload)
        assert "Doc 1" in output
        assert "doc_missing" not in output.lower()

View File

@@ -0,0 +1,645 @@
"""Unit tests for Discord bot should_respond logic.
Tests the decision tree for when the bot should respond to messages.
"""
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import discord
import pytest
from onyx.onyxbot.discord.handle_message import check_implicit_invocation
from onyx.onyxbot.discord.handle_message import should_respond
class TestBasicShouldRespond:
"""Tests for basic should_respond decision logic."""
@pytest.mark.asyncio
async def test_should_respond_guild_disabled(
    self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
) -> None:
    """A guild whose config is disabled never gets a response."""
    guild_cfg = MagicMock()
    guild_cfg.enabled = False
    with patch(
        "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
    ) as session_factory:
        db = MagicMock()
        session_factory.return_value.__enter__ = MagicMock(return_value=db)
        session_factory.return_value.__exit__ = MagicMock()
        with patch(
            "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
            return_value=guild_cfg,
        ):
            decision = await should_respond(
                mock_discord_message, "tenant1", mock_bot_user
            )
    assert decision.should_respond is False
@pytest.mark.asyncio
async def test_should_respond_guild_enabled(
    self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
) -> None:
    """An enabled guild falls through to the channel-level checks."""
    guild_cfg = MagicMock()
    guild_cfg.enabled = True
    guild_cfg.default_persona_id = 1

    channel_cfg = MagicMock()
    channel_cfg.enabled = True
    channel_cfg.require_bot_invocation = False
    channel_cfg.thread_only_mode = False
    channel_cfg.persona_override_id = None

    with patch(
        "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
    ) as session_factory:
        db = MagicMock()
        session_factory.return_value.__enter__ = MagicMock(return_value=db)
        session_factory.return_value.__exit__ = MagicMock()
        with (
            patch(
                "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                return_value=guild_cfg,
            ),
            patch(
                "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                return_value=channel_cfg,
            ),
        ):
            decision = await should_respond(
                mock_discord_message, "tenant1", mock_bot_user
            )
    assert decision.should_respond is True
@pytest.mark.asyncio
async def test_should_respond_channel_disabled(
    self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
) -> None:
    """A disabled channel config blocks the response."""
    guild_cfg = MagicMock()
    guild_cfg.enabled = True

    channel_cfg = MagicMock()
    channel_cfg.enabled = False

    with patch(
        "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
    ) as session_factory:
        db = MagicMock()
        session_factory.return_value.__enter__ = MagicMock(return_value=db)
        session_factory.return_value.__exit__ = MagicMock()
        with (
            patch(
                "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                return_value=guild_cfg,
            ),
            patch(
                "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                return_value=channel_cfg,
            ),
        ):
            decision = await should_respond(
                mock_discord_message, "tenant1", mock_bot_user
            )
    assert decision.should_respond is False
@pytest.mark.asyncio
async def test_should_respond_channel_enabled(
    self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
) -> None:
    """An enabled channel responds and inherits the guild's default persona."""
    guild_cfg = MagicMock()
    guild_cfg.enabled = True
    guild_cfg.default_persona_id = 2

    channel_cfg = MagicMock()
    channel_cfg.enabled = True
    channel_cfg.require_bot_invocation = False
    channel_cfg.thread_only_mode = False
    channel_cfg.persona_override_id = None

    with patch(
        "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
    ) as session_factory:
        db = MagicMock()
        session_factory.return_value.__enter__ = MagicMock(return_value=db)
        session_factory.return_value.__exit__ = MagicMock()
        with (
            patch(
                "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                return_value=guild_cfg,
            ),
            patch(
                "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                return_value=channel_cfg,
            ),
        ):
            decision = await should_respond(
                mock_discord_message, "tenant1", mock_bot_user
            )
    assert decision.should_respond is True
    assert decision.persona_id == 2  # guild default, no channel override
@pytest.mark.asyncio
async def test_should_respond_channel_not_found(
    self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
) -> None:
    """A channel with no config (not whitelisted) gets no response."""
    guild_cfg = MagicMock()
    guild_cfg.enabled = True

    with patch(
        "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
    ) as session_factory:
        db = MagicMock()
        session_factory.return_value.__enter__ = MagicMock(return_value=db)
        session_factory.return_value.__exit__ = MagicMock()
        with (
            patch(
                "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                return_value=guild_cfg,
            ),
            patch(
                "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                return_value=None,  # channel has no config row
            ),
        ):
            decision = await should_respond(
                mock_discord_message, "tenant1", mock_bot_user
            )
    assert decision.should_respond is False
    @pytest.mark.asyncio
    async def test_should_respond_require_mention_true_no_mention(
        self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
    ) -> None:
        """require_bot_invocation=true with no @mention returns False.

        Implicit invocation is also stubbed to False, so neither explicit
        mention nor implicit context can trigger a response.
        """
        mock_guild_config = MagicMock()
        mock_guild_config.enabled = True
        mock_guild_config.default_persona_id = 1
        mock_channel_config = MagicMock()
        mock_channel_config.enabled = True
        mock_channel_config.require_bot_invocation = True
        mock_channel_config.thread_only_mode = False
        mock_channel_config.persona_override_id = None
        # No bot mention
        mock_discord_message.mentions = []
        with patch(
            "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
        ) as mock_session:
            # Stub the session context manager so no real DB is touched.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                    return_value=mock_guild_config,
                ),
                patch(
                    "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                    return_value=mock_channel_config,
                ),
                # Rule out the implicit-invocation path (replies / bot threads).
                patch(
                    "onyx.onyxbot.discord.handle_message.check_implicit_invocation",
                    return_value=False,
                ),
            ):
                result = await should_respond(
                    mock_discord_message, "tenant1", mock_bot_user
                )
        assert result.should_respond is False
    @pytest.mark.asyncio
    async def test_should_respond_require_mention_true_with_mention(
        self, mock_message_with_bot_mention: MagicMock, mock_bot_user: MagicMock
    ) -> None:
        """require_bot_invocation=true with @mention returns True.

        The fixture message already mentions the bot, satisfying the
        invocation requirement.
        """
        mock_guild_config = MagicMock()
        mock_guild_config.enabled = True
        mock_guild_config.default_persona_id = 1
        mock_channel_config = MagicMock()
        mock_channel_config.enabled = True
        mock_channel_config.require_bot_invocation = True
        mock_channel_config.thread_only_mode = False
        mock_channel_config.persona_override_id = None
        with patch(
            "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
        ) as mock_session:
            # Stub the session context manager so no real DB is touched.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                    return_value=mock_guild_config,
                ),
                patch(
                    "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                    return_value=mock_channel_config,
                ),
            ):
                result = await should_respond(
                    mock_message_with_bot_mention, "tenant1", mock_bot_user
                )
        assert result.should_respond is True
    @pytest.mark.asyncio
    async def test_should_respond_require_mention_false_no_mention(
        self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
    ) -> None:
        """require_bot_invocation=false with no @mention returns True.

        When the channel does not require invocation, any message in the
        channel is answered.
        """
        mock_guild_config = MagicMock()
        mock_guild_config.enabled = True
        mock_guild_config.default_persona_id = 1
        mock_channel_config = MagicMock()
        mock_channel_config.enabled = True
        mock_channel_config.require_bot_invocation = False
        mock_channel_config.thread_only_mode = False
        mock_channel_config.persona_override_id = None
        with patch(
            "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
        ) as mock_session:
            # Stub the session context manager so no real DB is touched.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                    return_value=mock_guild_config,
                ),
                patch(
                    "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                    return_value=mock_channel_config,
                ),
            ):
                result = await should_respond(
                    mock_discord_message, "tenant1", mock_bot_user
                )
        assert result.should_respond is True
class TestImplicitShouldRespond:
    """Tests for implicit invocation (no @mention required in certain contexts).

    check_implicit_invocation should fire when the user replies to a bot
    message, posts in a bot-owned thread, or posts in a thread started from
    a bot message — and must fail closed on Discord API errors.
    """

    @pytest.mark.asyncio
    async def test_implicit_respond_reply_to_bot_message(
        self, mock_bot_user: MagicMock
    ) -> None:
        """User replies to a bot message returns True."""
        # Create a message that replies to the bot
        msg = MagicMock(spec=discord.Message)
        msg.reference = MagicMock()
        msg.reference.message_id = 12345
        # Mock the referenced message as a bot message
        referenced_msg = MagicMock()
        referenced_msg.author.id = mock_bot_user.id
        msg.channel = MagicMock()
        msg.channel.fetch_message = AsyncMock(return_value=referenced_msg)
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is True

    @pytest.mark.asyncio
    async def test_implicit_respond_reply_to_user_message(
        self, mock_bot_user: MagicMock
    ) -> None:
        """User replies to another user's message returns False."""
        msg = MagicMock(spec=discord.Message)
        msg.reference = MagicMock()
        msg.reference.message_id = 12345
        # Mock the referenced message as a user message
        referenced_msg = MagicMock()
        referenced_msg.author.id = 999999  # Different from bot
        msg.channel = MagicMock()
        msg.channel.fetch_message = AsyncMock(return_value=referenced_msg)
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is False

    @pytest.mark.asyncio
    async def test_implicit_respond_in_bot_owned_thread(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Message in thread owned by bot returns True."""
        thread = MagicMock(spec=discord.Thread)
        thread.owner_id = mock_bot_user.id  # Bot owns the thread
        thread.parent = MagicMock(spec=discord.TextChannel)
        msg = MagicMock(spec=discord.Message)
        msg.reference = None
        msg.channel = thread
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is True

    @pytest.mark.asyncio
    async def test_implicit_respond_in_user_owned_thread(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Message in thread owned by user returns False."""
        thread = MagicMock(spec=discord.Thread)
        thread.owner_id = 999999  # User owns the thread
        thread.parent = MagicMock(spec=discord.TextChannel)
        msg = MagicMock(spec=discord.Message)
        msg.reference = None
        msg.channel = thread
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is False

    @pytest.mark.asyncio
    async def test_implicit_respond_reply_in_bot_thread(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Reply to user in bot-owned thread returns True (thread context)."""
        thread = MagicMock(spec=discord.Thread)
        thread.owner_id = mock_bot_user.id
        thread.parent = MagicMock(spec=discord.TextChannel)
        # User replying to another user in bot's thread
        referenced_msg = MagicMock()
        referenced_msg.author.id = 888888  # Another user
        msg = MagicMock(spec=discord.Message)
        msg.reference = MagicMock()
        msg.reference.message_id = 12345
        msg.channel = thread
        msg.channel.fetch_message = AsyncMock(return_value=referenced_msg)
        result = await check_implicit_invocation(msg, mock_bot_user)
        # Should return True because it's in bot's thread
        assert result is True

    @pytest.mark.asyncio
    async def test_implicit_respond_thread_from_bot_message(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Thread created from bot message (non-forum) returns True."""
        thread = MagicMock(spec=discord.Thread)
        thread.id = 777777
        thread.owner_id = 999999  # User owns thread but...
        thread.parent = MagicMock(spec=discord.TextChannel)
        # The starter message is from the bot
        # (for non-forum threads the starter message shares the thread id
        # and is fetched from the parent channel).
        starter_msg = MagicMock()
        starter_msg.author.id = mock_bot_user.id
        thread.parent.fetch_message = AsyncMock(return_value=starter_msg)
        msg = MagicMock(spec=discord.Message)
        msg.reference = None
        msg.channel = thread
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is True

    @pytest.mark.asyncio
    async def test_implicit_respond_forum_channel_excluded(
        self, mock_bot_user: MagicMock, mock_thread_forum_parent: MagicMock
    ) -> None:
        """Thread parent is ForumChannel - does NOT check starter message."""
        msg = MagicMock(spec=discord.Message)
        msg.reference = None
        msg.channel = mock_thread_forum_parent
        mock_thread_forum_parent.owner_id = 999999  # Not bot
        result = await check_implicit_invocation(msg, mock_bot_user)
        # Should be False - forum threads don't use starter message check
        assert result is False

    @pytest.mark.asyncio
    async def test_implicit_respond_combined_with_mention(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Has @mention AND is implicit - should return True (either works)."""
        thread = MagicMock(spec=discord.Thread)
        thread.owner_id = mock_bot_user.id
        thread.parent = MagicMock(spec=discord.TextChannel)
        msg = MagicMock(spec=discord.Message)
        msg.reference = None
        msg.channel = thread
        msg.mentions = [mock_bot_user]
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is True

    @pytest.mark.asyncio
    async def test_implicit_respond_reference_fetch_fails(
        self, mock_bot_user: MagicMock
    ) -> None:
        """discord.NotFound when fetching reply reference returns False."""
        msg = MagicMock(spec=discord.Message)
        msg.reference = MagicMock()
        msg.reference.message_id = 12345
        msg.channel = MagicMock()
        # Simulate a deleted/unavailable referenced message.
        msg.channel.fetch_message = AsyncMock(
            side_effect=discord.NotFound(MagicMock(), "Not found")
        )
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is False

    @pytest.mark.asyncio
    async def test_implicit_respond_http_exception(
        self, mock_bot_user: MagicMock
    ) -> None:
        """discord.HTTPException during check returns False."""
        msg = MagicMock(spec=discord.Message)
        msg.reference = MagicMock()
        msg.reference.message_id = 12345
        msg.channel = MagicMock()
        # Simulate a transient Discord API failure; the check must fail closed.
        msg.channel.fetch_message = AsyncMock(
            side_effect=discord.HTTPException(MagicMock(), "HTTP error")
        )
        result = await check_implicit_invocation(msg, mock_bot_user)
        assert result is False
class TestThreadOnlyMode:
    """Tests for thread_only_mode behavior.

    thread_only_mode restricts where replies are posted; should_respond is
    expected to surface the flag on its result so the caller can route the
    response into a thread.
    """

    @pytest.mark.asyncio
    async def test_thread_only_mode_message_in_thread(
        self, mock_bot_user: MagicMock
    ) -> None:
        """thread_only_mode=true, message in thread returns True."""
        mock_guild_config = MagicMock()
        mock_guild_config.enabled = True
        mock_guild_config.default_persona_id = 1
        mock_channel_config = MagicMock()
        mock_channel_config.enabled = True
        mock_channel_config.require_bot_invocation = False
        mock_channel_config.thread_only_mode = True
        mock_channel_config.persona_override_id = None
        # Create thread message
        thread = MagicMock(spec=discord.Thread)
        thread.parent = MagicMock(spec=discord.TextChannel)
        thread.parent.id = 111111111
        msg = MagicMock(spec=discord.Message)
        msg.guild = MagicMock()
        msg.guild.id = 987654321
        msg.channel = thread
        msg.mentions = []
        msg.reference = None
        with patch(
            "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
        ) as mock_session:
            # Stub the session context manager so no real DB is touched.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                    return_value=mock_guild_config,
                ),
                patch(
                    "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                    return_value=mock_channel_config,
                ),
            ):
                result = await should_respond(msg, "tenant1", mock_bot_user)
        assert result.should_respond is True
        # The flag must be propagated so the caller replies inside a thread.
        assert result.thread_only_mode is True

    @pytest.mark.asyncio
    async def test_thread_only_mode_false_message_in_channel(
        self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
    ) -> None:
        """thread_only_mode=false, message in channel returns True."""
        mock_guild_config = MagicMock()
        mock_guild_config.enabled = True
        mock_guild_config.default_persona_id = 1
        mock_channel_config = MagicMock()
        mock_channel_config.enabled = True
        mock_channel_config.require_bot_invocation = False
        mock_channel_config.thread_only_mode = False
        mock_channel_config.persona_override_id = None
        with patch(
            "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
        ) as mock_session:
            # Stub the session context manager so no real DB is touched.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                    return_value=mock_guild_config,
                ),
                patch(
                    "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                    return_value=mock_channel_config,
                ),
            ):
                result = await should_respond(
                    mock_discord_message, "tenant1", mock_bot_user
                )
        assert result.should_respond is True
        assert result.thread_only_mode is False
class TestEdgeCases:
    """Edge case tests for should_respond."""

    @pytest.mark.asyncio
    async def test_should_respond_no_guild(self, mock_bot_user: MagicMock) -> None:
        """Message without guild (DM) returns False."""
        msg = MagicMock(spec=discord.Message)
        msg.guild = None
        result = await should_respond(msg, "tenant1", mock_bot_user)
        assert result.should_respond is False

    @pytest.mark.asyncio
    async def test_should_respond_thread_uses_parent_channel_config(
        self, mock_bot_user: MagicMock
    ) -> None:
        """Thread under channel uses parent channel's config.

        The channel-config lookup is expected to resolve via the parent
        channel id, and the parent's persona_override_id must win over the
        guild default persona.
        """
        mock_guild_config = MagicMock()
        mock_guild_config.enabled = True
        mock_guild_config.default_persona_id = 1
        mock_channel_config = MagicMock()
        mock_channel_config.enabled = True
        mock_channel_config.require_bot_invocation = False
        mock_channel_config.thread_only_mode = False
        mock_channel_config.persona_override_id = 5  # Specific persona
        # Create thread message
        thread = MagicMock(spec=discord.Thread)
        thread.id = 666666
        thread.parent = MagicMock(spec=discord.TextChannel)
        thread.parent.id = 111111111  # Parent channel ID
        msg = MagicMock(spec=discord.Message)
        msg.guild = MagicMock()
        msg.guild.id = 987654321
        msg.channel = thread
        msg.mentions = []
        msg.reference = None
        with patch(
            "onyx.onyxbot.discord.handle_message.get_session_with_tenant"
        ) as mock_session:
            # Stub the session context manager so no real DB is touched.
            mock_db = MagicMock()
            mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
            mock_session.return_value.__exit__ = MagicMock()
            with (
                patch(
                    "onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
                    return_value=mock_guild_config,
                ),
                patch(
                    "onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
                    return_value=mock_channel_config,
                ),
            ):
                result = await should_respond(msg, "tenant1", mock_bot_user)
        assert result.should_respond is True
        # Should use parent's persona override
        assert result.persona_id == 5

View File

@@ -0,0 +1,106 @@
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
from onyx.onyxbot.slack.formatting import _sanitize_html
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.utils.text_processing import decode_escapes
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
    """A citation URL containing ')' gets wrapped in <...> so the closing
    paren is not mistaken for the end of the markdown link."""
    message = (
        "See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
    )
    normalized = _normalize_link_destinations(message)
    assert (
        "See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
        == normalized
    )
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
    """Already-wrapped destinations must pass through unchanged (idempotent)."""
    message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
    normalized = _normalize_link_destinations(message)
    assert message == normalized
def test_normalize_citation_link_handles_multiple_links() -> None:
    """Every citation link in the message is normalized, not just the first."""
    message = (
        "[[1]](https://example.com/(USA)%20Guide.pdf) "
        "[[2]](https://example.com/Plan(s)%20Overview.pdf)"
    )
    normalized = _normalize_link_destinations(message)
    assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
    assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
    """End-to-end: a parenthesized URL survives formatting as one Slack link.

    The rendered output must contain the full URL inside a single
    <url|[1]> link, with no URL remainder leaking out after the label.
    """
    message = (
        "Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
    )
    formatted = format_slack_message(message)
    rendered = decode_escapes(remove_slack_text_interactions(formatted))
    assert (
        "<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
        in rendered
    )
    # Regression guard: the URL tail must not appear outside the link.
    assert "|[1]>%20Access%20ID%20Card" not in rendered
def test_slack_style_links_converted_to_clickable_links() -> None:
    """Slack-style <url|label> links stay clickable and are not HTML-escaped."""
    message = "Visit <https://example.com/page|Example Page> for details."
    formatted = format_slack_message(message)
    assert "<https://example.com/page|Example Page>" in formatted
    # The angle brackets must not have been escaped to &lt;/&gt;.
    assert "&lt;" not in formatted
def test_slack_style_links_preserved_inside_code_blocks() -> None:
    """Link conversion must not touch content inside fenced code blocks."""
    message = "```\n<https://example.com|click>\n```"
    converted = _convert_slack_links_to_markdown(message)
    assert "<https://example.com|click>" in converted
def test_html_tags_stripped_outside_code_blocks() -> None:
    """HTML tags are sanitized in prose but preserved verbatim inside code fences."""
    message = "Hello<br/>world ```<div>code</div>``` after"
    sanitized = _transform_outside_code_blocks(message, _sanitize_html)
    assert "<br" not in sanitized
    assert "<div>code</div>" in sanitized
def test_format_slack_message_block_spacing() -> None:
    """Paragraph separation (blank line) is preserved exactly by formatting."""
    message = "Paragraph one.\n\nParagraph two."
    formatted = format_slack_message(message)
    assert "Paragraph one.\n\nParagraph two." == formatted
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
    """Formatting must not append a blank line after a closing code fence."""
    message = "```python\nprint('hi')\n```"
    formatted = format_slack_message(message)
    assert formatted.endswith("print('hi')\n```")
def test_format_slack_message_ampersand_not_double_escaped() -> None:
    """'&' is escaped once to &amp;; quotes are left alone (no &quot;)."""
    message = 'She said "hello" & goodbye.'
    formatted = format_slack_message(message)
    assert "&amp;" in formatted
    assert "&quot;" not in formatted

View File

@@ -0,0 +1,57 @@
from typing import Any
from unittest.mock import Mock
from onyx.configs.constants import MilestoneRecordType
from onyx.utils import telemetry as telemetry_utils
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
    """mt_cloud_telemetry is a no-op for single-tenant deployments.

    With MULTI_TENANT patched to False, the versioned-implementation
    fetcher (and therefore the telemetry sink) must never be invoked.
    """
    fetch_impl = Mock()
    monkeypatch.setattr(
        telemetry_utils,
        "fetch_versioned_implementation_with_fallback",
        fetch_impl,
    )
    # mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
    monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
    telemetry_utils.mt_cloud_telemetry(
        tenant_id="tenant-1",
        distinct_id="user@example.com",
        event=MilestoneRecordType.USER_MESSAGE_SENT,
        properties={"origin": "web"},
    )
    fetch_impl.assert_not_called()
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
    monkeypatch: Any,
) -> None:
    """When MULTI_TENANT is True, the EE event_telemetry hook is resolved and
    called with the caller's properties plus an injected tenant_id."""
    event_telemetry = Mock()
    fetch_impl = Mock(return_value=event_telemetry)
    monkeypatch.setattr(
        telemetry_utils,
        "fetch_versioned_implementation_with_fallback",
        fetch_impl,
    )
    # mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
    monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
    telemetry_utils.mt_cloud_telemetry(
        tenant_id="tenant-1",
        distinct_id="user@example.com",
        event=MilestoneRecordType.USER_MESSAGE_SENT,
        properties={"origin": "web"},
    )
    # The EE implementation is looked up with a safe no-op fallback.
    fetch_impl.assert_called_once_with(
        module="onyx.utils.telemetry",
        attribute="event_telemetry",
        fallback=telemetry_utils.noop_fallback,
    )
    # tenant_id is merged into the outgoing properties dict.
    event_telemetry.assert_called_once_with(
        "user@example.com",
        MilestoneRecordType.USER_MESSAGE_SENT,
        {"origin": "web", "tenant_id": "tenant-1"},
    )

View File

@@ -221,6 +221,13 @@ services:
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
- ONYX_BOT_MAX_QPM=${ONYX_BOT_MAX_QPM:-}
- ONYX_BOT_MAX_WAIT_TIME=${ONYX_BOT_MAX_WAIT_TIME:-}
# Discord Bot Configuration (runs via supervisord, requires DISCORD_BOT_TOKEN to be set)
# IMPORTANT: Only one Discord bot instance can run per token - do not scale background workers
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
# Logging
# Leave this on pretty please? Nothing sensitive is collected!
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}

View File

@@ -63,6 +63,11 @@ services:
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
env_file:
- path: .env
required: false

View File

@@ -82,6 +82,11 @@ services:
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
env_file:
- path: .env
required: false

View File

@@ -129,6 +129,11 @@ services:
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
# API Server connection for Discord bot message processing
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
# PRODUCTION: Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro

View File

@@ -77,6 +77,13 @@ MINIO_ROOT_PASSWORD=minioadmin
## CORS origins for MCP clients (comma-separated list)
# MCP_SERVER_CORS_ORIGINS=
## Discord Bot Configuration
## The Discord bot allows users to interact with Onyx from Discord servers
## Bot token from Discord Developer Portal (required to enable the bot)
# DISCORD_BOT_TOKEN=
## Command prefix for bot commands (default: "!")
# DISCORD_BOT_INVOKE_CHAR=!
## Celery Configuration
# CELERY_BROKER_POOL_LIMIT=
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=

View File

@@ -0,0 +1,98 @@
{{- if .Values.discordbot.enabled }}
# Discord bot MUST run as a single replica - Discord only allows one client connection per bot token.
# Do NOT enable HPA or increase replicas. Message processing is offloaded to scalable API pods via HTTP.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: {{ include "onyx.fullname" . }}-discordbot
  labels:
    {{- include "onyx.labels" . | nindent 4 }}
    {{- with .Values.discordbot.deploymentLabels }}
    {{- toYaml . | nindent 4 }}
    {{- end }}
spec:
  # CRITICAL: Discord bots cannot be horizontally scaled - only one WebSocket connection per token is allowed
  replicas: 1
  strategy:
    type: Recreate # Ensure old pod is terminated before new one starts to avoid duplicate connections
  selector:
    matchLabels:
      {{- include "onyx.selectorLabels" . | nindent 6 }}
      {{- if .Values.discordbot.deploymentLabels }}
      {{- toYaml .Values.discordbot.deploymentLabels | nindent 6 }}
      {{- end }}
  template:
    metadata:
      annotations:
        # Roll the pod automatically when the shared configmap changes.
        checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
        {{- with .Values.discordbot.podAnnotations }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
      labels:
        {{- include "onyx.labels" . | nindent 8 }}
        {{- with .Values.discordbot.deploymentLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
        {{- with .Values.discordbot.podLabels }}
        {{- toYaml . | nindent 8 }}
        {{- end }}
    spec:
      {{- with .Values.imagePullSecrets }}
      imagePullSecrets:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      serviceAccountName: {{ include "onyx.serviceAccountName" . }}
      securityContext:
        {{- toYaml .Values.discordbot.podSecurityContext | nindent 8 }}
      {{- with .Values.discordbot.nodeSelector }}
      nodeSelector:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.discordbot.affinity }}
      affinity:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      {{- with .Values.discordbot.tolerations }}
      tolerations:
        {{- toYaml . | nindent 8 }}
      {{- end }}
      containers:
        - name: discordbot
          securityContext:
            {{- toYaml .Values.discordbot.securityContext | nindent 12 }}
          image: "{{ .Values.discordbot.image.repository }}:{{ .Values.discordbot.image.tag | default .Values.global.version }}"
          imagePullPolicy: {{ .Values.global.pullPolicy }}
          command: ["python", "onyx/onyxbot/discord/client.py"]
          resources:
            {{- toYaml .Values.discordbot.resources | nindent 12 }}
          envFrom:
            - configMapRef:
                name: {{ .Values.config.envConfigMapName }}
          env:
            {{- include "onyx.envSecrets" . | nindent 12}}
            # Discord bot token - required for bot to connect.
            # A secret reference takes precedence over a plain-text token so
            # that setting both values cannot emit a duplicate
            # DISCORD_BOT_TOKEN env entry (duplicate names are invalid and
            # resolved inconsistently by tooling).
            {{- if .Values.discordbot.botTokenSecretName }}
            - name: DISCORD_BOT_TOKEN
              valueFrom:
                secretKeyRef:
                  name: {{ .Values.discordbot.botTokenSecretName }}
                  key: {{ .Values.discordbot.botTokenSecretKey | default "token" }}
            {{- else if .Values.discordbot.botToken }}
            - name: DISCORD_BOT_TOKEN
              value: {{ .Values.discordbot.botToken | quote }}
            {{- end }}
            # Command prefix for bot commands (default: "!")
            {{- if .Values.discordbot.invokeChar }}
            - name: DISCORD_BOT_INVOKE_CHAR
              value: {{ .Values.discordbot.invokeChar | quote }}
            {{- end }}
          {{- with .Values.discordbot.volumeMounts }}
          volumeMounts:
            {{- toYaml . | nindent 12 }}
          {{- end }}
      {{- with .Values.discordbot.volumes }}
      volumes:
        {{- toYaml . | nindent 8 }}
      {{- end }}
{{- end }}

View File

@@ -655,6 +655,44 @@ celery_worker_user_file_processing:
tolerations: []
affinity: {}
# Discord bot for Onyx
# The bot offloads message processing to scalable API pods via HTTP requests.
discordbot:
enabled: false # Disabled by default - requires bot token configuration
# Bot token can be provided directly or via a Kubernetes secret
# Option 1: Direct token (not recommended for production)
botToken: ""
# Option 2: Reference a Kubernetes secret (recommended)
botTokenSecretName: "" # Name of the secret containing the bot token
botTokenSecretKey: "token" # Key within the secret (default: "token")
# Command prefix for bot commands (default: "!")
invokeChar: "!"
image:
repository: onyxdotapp/onyx-backend
tag: "" # Overrides the image tag whose default is the chart appVersion.
podAnnotations: {}
podLabels:
scope: onyx-backend
app: discord-bot
deploymentLabels:
app: discord-bot
podSecurityContext:
{}
securityContext:
{}
resources:
requests:
cpu: "500m"
memory: "512Mi"
limits:
cpu: "1000m"
memory: "2000Mi"
volumes: []
volumeMounts: []
nodeSelector: {}
tolerations: []
affinity: {}
slackbot:
enabled: true
replicaCount: 1
@@ -1090,6 +1128,8 @@ configMap:
ONYX_BOT_DISPLAY_ERROR_MSGS: ""
ONYX_BOT_RESPOND_EVERY_CHANNEL: ""
NOTIFY_SLACKBOT_NO_ANSWER: ""
DISCORD_BOT_TOKEN: ""
DISCORD_BOT_INVOKE_CHAR: ""
# Logging
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
DISABLE_TELEMETRY: ""

Some files were not shown because too many files have changed in this diff Show More