mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 22:25:47 +00:00
Compare commits
60 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9fb76042a2 | ||
|
|
caad67a34a | ||
|
|
c33437488f | ||
|
|
9f66ee7240 | ||
|
|
e6ef2b5074 | ||
|
|
74132175a8 | ||
|
|
29f707ee2d | ||
|
|
f0eb86fb9f | ||
|
|
b422496a4c | ||
|
|
31d6a45b23 | ||
|
|
36f3ac1ec5 | ||
|
|
74f5b3025a | ||
|
|
c18545d74c | ||
|
|
48171e3700 | ||
|
|
f5a5709876 | ||
|
|
85868b1b83 | ||
|
|
8dc14c23e6 | ||
|
|
23821cc0e8 | ||
|
|
b359e13281 | ||
|
|
717f410a4a | ||
|
|
ada0946a62 | ||
|
|
eb2ac8f5a3 | ||
|
|
fbeb57c592 | ||
|
|
d6da9c9b85 | ||
|
|
5aea2e223e | ||
|
|
1ff91de07e | ||
|
|
b3dbc69faf | ||
|
|
431597b0f9 | ||
|
|
51b4e5f2fb | ||
|
|
9afa04a26b | ||
|
|
70a3a9c0cd | ||
|
|
080165356c | ||
|
|
3ae974bdf6 | ||
|
|
1471658151 | ||
|
|
3e85e9c1a3 | ||
|
|
851033be5f | ||
|
|
91e974a6cc | ||
|
|
38ba4f8a1c | ||
|
|
6f02473064 | ||
|
|
f89432009f | ||
|
|
8ab2bab34e | ||
|
|
59e0d62512 | ||
|
|
a1471b16a5 | ||
|
|
9d3811cb58 | ||
|
|
3cd9505383 | ||
|
|
d11829b393 | ||
|
|
f6e068e914 | ||
|
|
0c84edd980 | ||
|
|
2b274a7683 | ||
|
|
ddd91f2d71 | ||
|
|
a7c7da0dfc | ||
|
|
b00a3e8b5d | ||
|
|
d77d1a48f1 | ||
|
|
7b4fc6729c | ||
|
|
1f113c86ef | ||
|
|
8e38ba3e21 | ||
|
|
bb9708a64f | ||
|
|
8cae97e145 | ||
|
|
7e4abca224 | ||
|
|
233a91ea65 |
387
.github/workflows/deployment.yml
vendored
387
.github/workflows/deployment.yml
vendored
@@ -8,7 +8,9 @@ on:
|
||||
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions: {}
|
||||
permissions:
|
||||
# Required for OIDC authentication with AWS
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
@@ -150,16 +152,30 @@ jobs:
|
||||
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• check-version-tag"
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -168,6 +184,7 @@ jobs:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
@@ -185,12 +202,33 @@ jobs:
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
APPLE_ID, deploy/apple-id
|
||||
APPLE_PASSWORD, deploy/apple-password
|
||||
APPLE_CERTIFICATE, deploy/apple-certificate
|
||||
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
|
||||
KEYCHAIN_PASSWORD, deploy/keychain-password
|
||||
APPLE_TEAM_ID, deploy/apple-team-id
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
@@ -285,15 +323,40 @@ jobs:
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- name: Import Apple Developer Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
|
||||
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security set-keychain-settings -t 3600 -u build.keychain
|
||||
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security find-identity -v -p codesigning build.keychain
|
||||
|
||||
- name: Verify Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
|
||||
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
|
||||
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
|
||||
echo "Certificate imported."
|
||||
|
||||
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
APPLE_ID: ${{ env.APPLE_ID }}
|
||||
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
@@ -305,6 +368,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -317,6 +381,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -331,8 +409,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -363,6 +441,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -375,6 +454,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -389,8 +482,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -423,19 +516,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -471,6 +579,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -483,6 +592,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -497,8 +620,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -537,6 +660,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -549,6 +673,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -563,8 +701,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -605,19 +743,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -650,6 +803,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -662,6 +816,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -676,8 +844,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -707,6 +875,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -719,6 +888,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -733,8 +916,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -766,19 +949,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -815,6 +1013,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -827,6 +1026,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -843,8 +1056,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -879,6 +1092,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -891,6 +1105,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -907,8 +1135,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -944,19 +1172,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -994,11 +1237,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1014,8 +1272,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1034,11 +1292,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1054,8 +1327,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1074,6 +1347,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
@@ -1084,6 +1358,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1100,8 +1388,8 @@ jobs:
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1121,11 +1409,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1141,8 +1444,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1170,12 +1473,26 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Determine failed jobs
|
||||
id: failed-jobs
|
||||
shell: bash
|
||||
@@ -1241,7 +1558,7 @@ jobs:
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
|
||||
title: "🚨 Deployment Workflow Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
18
.vscode/launch.template.jsonc
vendored
18
.vscode/launch.template.jsonc
vendored
@@ -151,6 +151,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
|
||||
116
backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
Normal file
116
backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Add Discord bot tables
|
||||
|
||||
Revision ID: 8b5ce697290e
|
||||
Revises: a1b2c3d4e5f7
|
||||
Create Date: 2025-01-14
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8b5ce697290e"
|
||||
down_revision = "a1b2c3d4e5f7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# DiscordBotConfig (singleton table - one per tenant)
|
||||
op.create_table(
|
||||
"discord_bot_config",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.String(),
|
||||
primary_key=True,
|
||||
server_default=sa.text("'SINGLETON'"),
|
||||
),
|
||||
sa.Column("bot_token", sa.LargeBinary(), nullable=False), # EncryptedString
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
|
||||
)
|
||||
|
||||
# DiscordGuildConfig
|
||||
op.create_table(
|
||||
"discord_guild_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
|
||||
sa.Column("guild_name", sa.String(), nullable=True),
|
||||
sa.Column("registration_key", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"default_persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
# DiscordChannelConfig
|
||||
op.create_table(
|
||||
"discord_channel_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column(
|
||||
"guild_config_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("channel_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("channel_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"channel_type",
|
||||
sa.String(20),
|
||||
server_default=sa.text("'text'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_private",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"thread_only_mode",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"require_bot_invocation",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"persona_override_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
# Unique constraint: one config per channel per guild
|
||||
op.create_unique_constraint(
|
||||
"uq_discord_channel_guild_channel",
|
||||
"discord_channel_config",
|
||||
["guild_config_id", "channel_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("discord_channel_config")
|
||||
op.drop_table("discord_guild_config")
|
||||
op.drop_table("discord_bot_config")
|
||||
@@ -0,0 +1,47 @@
|
||||
"""drop agent_search_metrics table
|
||||
|
||||
Revision ID: a1b2c3d4e5f7
|
||||
Revises: 73e9983e5091
|
||||
Create Date: 2026-01-17
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f7"
|
||||
down_revision = "73e9983e5091"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("agent__search_metrics")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"agent__search_metrics",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("agent_type", sa.String(), nullable=False),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("base_duration_s", sa.Float(), nullable=False),
|
||||
sa.Column("full_duration_s", sa.Float(), nullable=False),
|
||||
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
@@ -129,3 +128,8 @@ MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
# License enforcement - when True, blocks API access for gated/expired licenses
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -16,6 +16,9 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
add_license_enforcement_middleware,
|
||||
)
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
)
|
||||
@@ -83,6 +86,10 @@ def get_application() -> FastAPI:
|
||||
if MULTI_TENANT:
|
||||
add_api_server_tenant_id_middleware(application, logger)
|
||||
|
||||
# Add license enforcement middleware (runs after tenant tracking)
|
||||
# This blocks access when license is expired/gated
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
|
||||
102
backend/ee/onyx/server/middleware/license_enforcement.py
Normal file
102
backend/ee/onyx/server/middleware/license_enforcement.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Middleware to enforce license status application-wide."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.server.tenants.product_gating import is_tenant_gated
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /tenants/billing-* - Manage subscription to resolve gating
|
||||
ALLOWED_PATH_PREFIXES = {
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
}
|
||||
|
||||
|
||||
def _is_path_allowed(path: str) -> bool:
|
||||
"""Check if path is in allowlist (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
|
||||
|
||||
|
||||
def add_license_enforcement_middleware(
|
||||
app: FastAPI, logger: logging.LoggerAdapter
|
||||
) -> None:
|
||||
logger.info("License enforcement middleware registered")
|
||||
|
||||
@app.middleware("http")
|
||||
async def enforce_license(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""Block requests when license is expired/gated."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
if path.startswith("/api"):
|
||||
path = path[4:]
|
||||
|
||||
if _is_path_allowed(path):
|
||||
return await call_next(request)
|
||||
|
||||
is_gated = False
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
is_gated = is_tenant_gated(tenant_id)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check tenant gating status: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
else:
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata:
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
is_gated = True
|
||||
else:
|
||||
# No license metadata = gated for self-hosted EE
|
||||
is_gated = True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "license_expired",
|
||||
"message": "Your subscription has expired. Please update your billing.",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
54
backend/ee/onyx/server/settings/api.py
Normal file
54
backend/ee/onyx/server/settings/api.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Statuses that indicate a billing/license problem - propagate these to settings
|
||||
_GATED_STATUSES = frozenset(
|
||||
{
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
ApplicationStatus.PAYMENT_REMINDER,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license is missing or indicates a problem (expired, grace period, etc.).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
|
||||
If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
|
||||
allowing the product to function normally without license checks.
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return settings
|
||||
|
||||
if MULTI_TENANT:
|
||||
return settings
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status in _GATED_STATUSES:
|
||||
settings.application_status = metadata.status
|
||||
elif not metadata:
|
||||
# No license = gated access for self-hosted EE
|
||||
settings.application_status = ApplicationStatus.GATED_ACCESS
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
return settings
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -70,24 +76,46 @@ def fetch_billing_information(
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
|
||||
"""
|
||||
Fetch a Stripe customer portal session URL from the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
|
||||
payload = {"tenant_id": tenant_id}
|
||||
if return_url:
|
||||
payload["return_url"] = return_url
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["url"]
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
@@ -23,7 +22,6 @@ from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
@@ -82,21 +80,17 @@ async def billing_information(
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Stripe customer portal session via the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
return_url = f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"url": portal_url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -104,15 +98,18 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
@@ -65,3 +65,9 @@ def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
|
||||
|
||||
def is_tenant_gated(tenant_id: str) -> bool:
|
||||
"""Fast O(1) check if tenant is in gated set (multi-tenant only)."""
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))
|
||||
|
||||
@@ -12,6 +12,7 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.citation_processor import CitationMode
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
@@ -38,6 +39,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
@@ -51,6 +53,78 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
fallback_extraction_attempted: bool,
|
||||
tool_defs: list[dict],
|
||||
turn_index: int,
|
||||
) -> tuple[LlmStepResult, bool]:
|
||||
"""Attempt to extract tool calls from response text as a fallback.
|
||||
|
||||
This is a last resort fallback for low quality LLMs or those that don't have
|
||||
tool calling from the serving layer. Also triggers if there's reasoning but
|
||||
no answer and no tool calls.
|
||||
|
||||
Args:
|
||||
llm_step_result: The result from the LLM step
|
||||
tool_choice: The tool choice option used for this step
|
||||
fallback_extraction_attempted: Whether fallback extraction was already attempted
|
||||
tool_defs: List of tool definitions
|
||||
turn_index: The current turn index for placement
|
||||
|
||||
Returns:
|
||||
Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call)
|
||||
"""
|
||||
if fallback_extraction_attempted:
|
||||
return llm_step_result, False
|
||||
|
||||
no_tool_calls = (
|
||||
not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0
|
||||
)
|
||||
reasoning_but_no_answer_or_tools = (
|
||||
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
|
||||
)
|
||||
should_try_fallback = (
|
||||
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
|
||||
) or reasoning_but_no_answer_or_tools
|
||||
|
||||
if not should_try_fallback:
|
||||
return llm_step_result, False
|
||||
|
||||
# Try to extract from answer first, then fall back to reasoning
|
||||
extracted_tool_calls: list[ToolCallKickoff] = []
|
||||
if llm_step_result.answer:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if not extracted_tool_calls and llm_step_result.reasoning:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.reasoning,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=llm_step_result.reasoning,
|
||||
answer=llm_step_result.answer,
|
||||
tool_calls=extracted_tool_calls,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
return llm_step_result, True
|
||||
|
||||
|
||||
# Hardcoded oppinionated value, might breaks down to something like:
|
||||
# Cycle 1: Calls web_search for something
|
||||
# Cycle 2: Calls open_url for some results
|
||||
@@ -352,6 +426,7 @@ def run_llm_loop(
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
@@ -470,10 +545,11 @@ def run_llm_loop(
|
||||
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
# It also pre-processes the tool calls in preparation for running them
|
||||
tool_defs = [tool.tool_definition() for tool in final_tools]
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[tool.tool_definition() for tool in final_tools],
|
||||
tool_definitions=tool_defs,
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
@@ -488,6 +564,19 @@ def run_llm_loop(
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
# Fallback extraction for LLMs that don't support tool calling natively or are lower quality
|
||||
# and might incorrectly output tool calls in other channels
|
||||
llm_step_result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=tool_choice,
|
||||
fallback_extraction_attempted=fallback_extraction_attempted,
|
||||
tool_defs=tool_defs,
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
)
|
||||
if attempted:
|
||||
# To prevent the case of excessive looping with bad models, we only allow one fallback attempt
|
||||
fallback_extraction_attempted = True
|
||||
|
||||
# Save citation mapping after each LLM step for incremental state updates
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
@@ -580,6 +669,12 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
saved_response = (
|
||||
tool_response.rich_response
|
||||
if isinstance(tool_response.rich_response, str)
|
||||
else tool_response.llm_facing_response
|
||||
)
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
@@ -589,7 +684,7 @@ def run_llm_loop(
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
tool_call_response=saved_response,
|
||||
search_docs=search_docs,
|
||||
generated_images=generated_images,
|
||||
)
|
||||
@@ -645,7 +740,12 @@ def run_llm_loop(
|
||||
should_cite_documents = True
|
||||
|
||||
if not llm_step_result or not llm_step_result.answer:
|
||||
raise RuntimeError("LLM did not return an answer.")
|
||||
raise RuntimeError(
|
||||
"The LLM did not return an answer. "
|
||||
"Typically this is an issue with LLMs that do not support tool calling natively, "
|
||||
"or the model serving API is not configured correctly. "
|
||||
"This may also happen with models that are lower quality outputting invalid tool calls."
|
||||
)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -49,6 +49,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -278,6 +279,144 @@ def _extract_tool_call_kickoffs(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def extract_tool_calls_from_response_text(
|
||||
response_text: str | None,
|
||||
tool_definitions: list[dict],
|
||||
placement: Placement,
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
tool_definitions: List of tool definitions to match against
|
||||
placement: Placement information for the tool calls
|
||||
|
||||
Returns:
|
||||
List of ToolCallKickoff objects for any matched tool calls
|
||||
"""
|
||||
if not response_text or not tool_definitions:
|
||||
return []
|
||||
|
||||
# Build a map of tool names to their definitions
|
||||
tool_name_to_def: dict[str, dict] = {}
|
||||
for tool_def in tool_definitions:
|
||||
if tool_def.get("type") == "function" and "function" in tool_def:
|
||||
func_def = tool_def["function"]
|
||||
tool_name = func_def.get("name")
|
||||
if tool_name:
|
||||
tool_name_to_def[tool_name] = func_def
|
||||
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index = 0
|
||||
|
||||
for json_obj in json_objects:
|
||||
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
|
||||
if matched_tool_call:
|
||||
tool_name, tool_args = matched_tool_call
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
tab_index += 1
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
"""Try to match a JSON object to a tool definition.
|
||||
|
||||
Supports several formats:
|
||||
1. Direct tool call format: {"name": "tool_name", "arguments": {...}}
|
||||
2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}}
|
||||
3. Tool name as key: {"tool_name": {...arguments...}}
|
||||
4. Arguments matching a tool's parameter schema
|
||||
|
||||
Args:
|
||||
json_obj: The JSON object to match
|
||||
tool_name_to_def: Map of tool names to their function definitions
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_name, tool_args) if matched, None otherwise
|
||||
"""
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
if "function" in json_obj and isinstance(json_obj["function"], dict):
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
for tool_name in tool_name_to_def:
|
||||
if tool_name in json_obj:
|
||||
arguments = json_obj[tool_name]
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 4: Check if the JSON object matches a tool's parameter schema
|
||||
for tool_name, func_def in tool_name_to_def.items():
|
||||
params = func_def.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
required = params.get("required", [])
|
||||
|
||||
if not properties:
|
||||
continue
|
||||
|
||||
# Check if all required parameters are present (empty required = all optional)
|
||||
if all(req in json_obj for req in required):
|
||||
# Check if any of the tool's properties are in the JSON object
|
||||
matching_props = [prop for prop in properties if prop in json_obj]
|
||||
if matching_props:
|
||||
# Filter to only include known properties
|
||||
filtered_args = {k: v for k, v in json_obj.items() if k in properties}
|
||||
return (tool_name, filtered_args)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
|
||||
@@ -86,10 +86,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -362,21 +358,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -28,6 +29,7 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@@ -178,8 +180,9 @@ def build_system_prompt(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
return system_prompt
|
||||
|
||||
@@ -193,6 +196,7 @@ def build_system_prompt(
|
||||
has_generate_image = any(
|
||||
isinstance(tool, ImageGenerationTool) for tool in tools
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -222,4 +226,7 @@ def build_system_prompt(
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -1011,3 +1011,8 @@ INSTANCE_TYPE = (
|
||||
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
|
||||
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
|
||||
)
|
||||
|
||||
|
||||
## Discord Bot Configuration
|
||||
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
@@ -93,6 +93,7 @@ SSL_CERT_FILE = "bundle.pem"
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
@@ -152,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -340,6 +352,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
@@ -422,6 +435,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -567,6 +567,23 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -589,10 +606,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
rephrased_queries = [
|
||||
raw_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
451
backend/onyx/db/discord_bot.py
Normal file
451
backend/onyx/db/discord_bot.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""CRUD operations for Discord bot models."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DISCORD_SERVICE_API_KEY_NAME
|
||||
from onyx.db.api_key import insert_api_key
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import DiscordBotConfig
|
||||
from onyx.db.models import DiscordChannelConfig
|
||||
from onyx.db.models import DiscordGuildConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# === DiscordBotConfig ===
|
||||
|
||||
|
||||
def get_discord_bot_config(db_session: Session) -> DiscordBotConfig | None:
|
||||
"""Get the Discord bot config for this tenant (at most one)."""
|
||||
return db_session.scalar(select(DiscordBotConfig).limit(1))
|
||||
|
||||
|
||||
def create_discord_bot_config(
|
||||
db_session: Session,
|
||||
bot_token: str,
|
||||
) -> DiscordBotConfig:
|
||||
"""Create the Discord bot config. Raises ValueError if already exists.
|
||||
|
||||
The check constraint on id='SINGLETON' ensures only one config per tenant.
|
||||
"""
|
||||
existing = get_discord_bot_config(db_session)
|
||||
if existing:
|
||||
raise ValueError("Discord bot config already exists")
|
||||
|
||||
config = DiscordBotConfig(bot_token=bot_token)
|
||||
db_session.add(config)
|
||||
try:
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another request created the config concurrently
|
||||
db_session.rollback()
|
||||
raise ValueError("Discord bot config already exists")
|
||||
return config
|
||||
|
||||
|
||||
def delete_discord_bot_config(db_session: Session) -> bool:
|
||||
"""Delete the Discord bot config. Returns True if deleted."""
|
||||
result = db_session.execute(delete(DiscordBotConfig))
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# === Discord Service API Key ===
|
||||
|
||||
|
||||
def get_discord_service_api_key(db_session: Session) -> ApiKey | None:
|
||||
"""Get the Discord service API key if it exists."""
|
||||
return db_session.scalar(
|
||||
select(ApiKey).where(ApiKey.name == DISCORD_SERVICE_API_KEY_NAME)
|
||||
)
|
||||
|
||||
|
||||
def get_or_create_discord_service_api_key(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
) -> str:
|
||||
"""Get existing Discord service API key or create one.
|
||||
|
||||
The API key is used by the Discord bot to authenticate with the
|
||||
Onyx API pods when sending chat requests.
|
||||
|
||||
Args:
|
||||
db_session: Database session for the tenant.
|
||||
tenant_id: The tenant ID (used for logging/context).
|
||||
|
||||
Returns:
|
||||
The raw API key string (not hashed).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API key creation fails.
|
||||
"""
|
||||
# Check for existing key
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
# Database only stores the hash, so we must regenerate to get the raw key.
|
||||
# This is safe since the Discord bot is the only consumer of this key.
|
||||
logger.debug(
|
||||
f"Found existing Discord service API key for tenant {tenant_id} that isn't in cache, "
|
||||
"regenerating to update cache"
|
||||
)
|
||||
new_api_key = generate_api_key(tenant_id)
|
||||
existing.hashed_api_key = hash_api_key(new_api_key)
|
||||
existing.api_key_display = build_displayable_api_key(new_api_key)
|
||||
db_session.flush()
|
||||
return new_api_key
|
||||
|
||||
# Create new API key
|
||||
logger.info(f"Creating Discord service API key for tenant {tenant_id}")
|
||||
api_key_args = APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Limited role is sufficient for chat requests
|
||||
)
|
||||
api_key_descriptor = insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=api_key_args,
|
||||
user_id=None, # Service account, no owner
|
||||
)
|
||||
|
||||
if not api_key_descriptor.api_key:
|
||||
raise RuntimeError(
|
||||
f"Failed to create Discord service API key for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
return api_key_descriptor.api_key
|
||||
|
||||
|
||||
def delete_discord_service_api_key(db_session: Session) -> bool:
|
||||
"""Delete the Discord service API key for a tenant.
|
||||
|
||||
Called when:
|
||||
- Bot config is deleted (self-hosted)
|
||||
- All guild configs are deleted (Cloud)
|
||||
|
||||
Args:
|
||||
db_session: Database session for the tenant.
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if it didn't exist.
|
||||
"""
|
||||
existing_key = get_discord_service_api_key(db_session)
|
||||
if not existing_key:
|
||||
return False
|
||||
|
||||
# Also delete the associated user
|
||||
api_key_user = db_session.scalar(
|
||||
select(User).where(User.id == existing_key.user_id) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
db_session.delete(existing_key)
|
||||
if api_key_user:
|
||||
db_session.delete(api_key_user)
|
||||
|
||||
db_session.flush()
|
||||
logger.info("Deleted Discord service API key")
|
||||
return True
|
||||
|
||||
|
||||
# === DiscordGuildConfig ===
|
||||
|
||||
|
||||
def get_guild_configs(
|
||||
db_session: Session,
|
||||
include_channels: bool = False,
|
||||
) -> list[DiscordGuildConfig]:
|
||||
"""Get all guild configs for this tenant."""
|
||||
stmt = select(DiscordGuildConfig)
|
||||
if include_channels:
|
||||
stmt = stmt.options(joinedload(DiscordGuildConfig.channels))
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
def get_guild_config_by_internal_id(
|
||||
db_session: Session,
|
||||
internal_id: int,
|
||||
) -> DiscordGuildConfig | None:
|
||||
"""Get a specific guild config by its ID."""
|
||||
return db_session.scalar(
|
||||
select(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
|
||||
)
|
||||
|
||||
|
||||
def get_guild_config_by_discord_id(
|
||||
db_session: Session,
|
||||
guild_id: int,
|
||||
) -> DiscordGuildConfig | None:
|
||||
"""Get a guild config by Discord guild ID."""
|
||||
return db_session.scalar(
|
||||
select(DiscordGuildConfig).where(DiscordGuildConfig.guild_id == guild_id)
|
||||
)
|
||||
|
||||
|
||||
def get_guild_config_by_registration_key(
|
||||
db_session: Session,
|
||||
registration_key: str,
|
||||
) -> DiscordGuildConfig | None:
|
||||
"""Get a guild config by its registration key."""
|
||||
return db_session.scalar(
|
||||
select(DiscordGuildConfig).where(
|
||||
DiscordGuildConfig.registration_key == registration_key
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_guild_config(
|
||||
db_session: Session,
|
||||
registration_key: str,
|
||||
) -> DiscordGuildConfig:
|
||||
"""Create a new guild config with a registration key (guild_id=NULL)."""
|
||||
config = DiscordGuildConfig(registration_key=registration_key)
|
||||
db_session.add(config)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def register_guild(
|
||||
db_session: Session,
|
||||
config: DiscordGuildConfig,
|
||||
guild_id: int,
|
||||
guild_name: str,
|
||||
) -> DiscordGuildConfig:
|
||||
"""Complete registration by setting guild_id and guild_name."""
|
||||
config.guild_id = guild_id
|
||||
config.guild_name = guild_name
|
||||
config.registered_at = datetime.now(timezone.utc)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def update_guild_config(
|
||||
db_session: Session,
|
||||
config: DiscordGuildConfig,
|
||||
enabled: bool,
|
||||
default_persona_id: int | None = None,
|
||||
) -> DiscordGuildConfig:
|
||||
"""Update guild config fields."""
|
||||
config.enabled = enabled
|
||||
config.default_persona_id = default_persona_id
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def delete_guild_config(
|
||||
db_session: Session,
|
||||
internal_id: int,
|
||||
) -> bool:
|
||||
"""Delete guild config (cascades to channel configs). Returns True if deleted."""
|
||||
result = db_session.execute(
|
||||
delete(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
|
||||
)
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# === DiscordChannelConfig ===
|
||||
|
||||
|
||||
def get_channel_configs(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
) -> list[DiscordChannelConfig]:
|
||||
"""Get all channel configs for a guild."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(DiscordChannelConfig).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def get_channel_config_by_discord_ids(
|
||||
db_session: Session,
|
||||
guild_id: int,
|
||||
channel_id: int,
|
||||
) -> DiscordChannelConfig | None:
|
||||
"""Get a specific channel config by guild_id and channel_id."""
|
||||
return db_session.scalar(
|
||||
select(DiscordChannelConfig)
|
||||
.join(DiscordGuildConfig)
|
||||
.where(
|
||||
DiscordGuildConfig.guild_id == guild_id,
|
||||
DiscordChannelConfig.channel_id == channel_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_channel_config_by_internal_ids(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channel_config_id: int,
|
||||
) -> DiscordChannelConfig | None:
|
||||
"""Get a specific channel config by guild_config_id and channel_config_id"""
|
||||
return db_session.scalar(
|
||||
select(DiscordChannelConfig).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id,
|
||||
DiscordChannelConfig.id == channel_config_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def update_discord_channel_config(
|
||||
db_session: Session,
|
||||
config: DiscordChannelConfig,
|
||||
channel_name: str,
|
||||
thread_only_mode: bool,
|
||||
require_bot_invocation: bool,
|
||||
enabled: bool,
|
||||
persona_override_id: int | None = None,
|
||||
) -> DiscordChannelConfig:
|
||||
"""Update channel config fields."""
|
||||
config.channel_name = channel_name
|
||||
config.require_bot_invocation = require_bot_invocation
|
||||
config.persona_override_id = persona_override_id
|
||||
config.enabled = enabled
|
||||
config.thread_only_mode = thread_only_mode
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def delete_discord_channel_config(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channel_config_id: int,
|
||||
) -> bool:
|
||||
"""Delete a channel config. Returns True if deleted."""
|
||||
result = db_session.execute(
|
||||
delete(DiscordChannelConfig).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id,
|
||||
DiscordChannelConfig.id == channel_config_id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def create_channel_config(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channel_view: DiscordChannelView,
|
||||
) -> DiscordChannelConfig:
|
||||
"""Create a new channel config with default settings (disabled by default, admin enables via UI)."""
|
||||
config = DiscordChannelConfig(
|
||||
guild_config_id=guild_config_id,
|
||||
channel_id=channel_view.channel_id,
|
||||
channel_name=channel_view.channel_name,
|
||||
channel_type=channel_view.channel_type,
|
||||
is_private=channel_view.is_private,
|
||||
)
|
||||
db_session.add(config)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def bulk_create_channel_configs(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channels: list[DiscordChannelView],
|
||||
) -> list[DiscordChannelConfig]:
|
||||
"""Create multiple channel configs at once. Skips existing channels."""
|
||||
# Get existing channel IDs for this guild
|
||||
existing_channel_ids = set(
|
||||
db_session.scalars(
|
||||
select(DiscordChannelConfig.channel_id).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
# Create configs for new channels only
|
||||
new_configs = []
|
||||
for channel_view in channels:
|
||||
if channel_view.channel_id not in existing_channel_ids:
|
||||
config = DiscordChannelConfig(
|
||||
guild_config_id=guild_config_id,
|
||||
channel_id=channel_view.channel_id,
|
||||
channel_name=channel_view.channel_name,
|
||||
channel_type=channel_view.channel_type,
|
||||
is_private=channel_view.is_private,
|
||||
)
|
||||
db_session.add(config)
|
||||
new_configs.append(config)
|
||||
|
||||
db_session.flush()
|
||||
return new_configs
|
||||
|
||||
|
||||
def sync_channel_configs(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
current_channels: list[DiscordChannelView],
|
||||
) -> tuple[int, int, int]:
|
||||
"""Sync channel configs with current Discord channels.
|
||||
|
||||
- Creates configs for new channels (disabled by default)
|
||||
- Removes configs for deleted channels
|
||||
- Updates names and types for existing channels if changed
|
||||
|
||||
Returns: (added_count, removed_count, updated_count)
|
||||
"""
|
||||
current_channel_map = {
|
||||
channel_view.channel_id: channel_view for channel_view in current_channels
|
||||
}
|
||||
current_channel_ids = set(current_channel_map.keys())
|
||||
|
||||
# Get existing configs
|
||||
existing_configs = get_channel_configs(db_session, guild_config_id)
|
||||
existing_channel_ids = {c.channel_id for c in existing_configs}
|
||||
|
||||
# Find channels to add, remove, and potentially update
|
||||
to_add = current_channel_ids - existing_channel_ids
|
||||
to_remove = existing_channel_ids - current_channel_ids
|
||||
|
||||
# Add new channels
|
||||
added_count = 0
|
||||
for channel_id in to_add:
|
||||
channel_view = current_channel_map[channel_id]
|
||||
create_channel_config(db_session, guild_config_id, channel_view)
|
||||
added_count += 1
|
||||
|
||||
# Remove deleted channels
|
||||
removed_count = 0
|
||||
for config in existing_configs:
|
||||
if config.channel_id in to_remove:
|
||||
db_session.delete(config)
|
||||
removed_count += 1
|
||||
|
||||
# Update names, types, and privacy for existing channels if changed
|
||||
updated_count = 0
|
||||
for config in existing_configs:
|
||||
if config.channel_id in current_channel_ids:
|
||||
channel_view = current_channel_map[config.channel_id]
|
||||
changed = False
|
||||
if config.channel_name != channel_view.channel_name:
|
||||
config.channel_name = channel_view.channel_name
|
||||
changed = True
|
||||
if config.channel_type != channel_view.channel_type:
|
||||
config.channel_type = channel_view.channel_type
|
||||
changed = True
|
||||
if config.is_private != channel_view.is_private:
|
||||
config.is_private = channel_view.is_private
|
||||
changed = True
|
||||
if changed:
|
||||
updated_count += 1
|
||||
|
||||
db_session.flush()
|
||||
return added_count, removed_count, updated_count
|
||||
@@ -26,6 +26,7 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import BigInteger
|
||||
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
@@ -2931,8 +2932,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -3038,6 +3037,124 @@ class SlackBot(Base):
|
||||
)
|
||||
|
||||
|
||||
class DiscordBotConfig(Base):
|
||||
"""Global Discord bot configuration (one per tenant).
|
||||
|
||||
Stores the bot token when not provided via DISCORD_BOT_TOKEN env var.
|
||||
Uses a fixed ID with check constraint to enforce only one row per tenant.
|
||||
"""
|
||||
|
||||
__tablename__ = "discord_bot_config"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'SINGLETON'")
|
||||
)
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class DiscordGuildConfig(Base):
|
||||
"""Configuration for a Discord guild (server) connected to this tenant.
|
||||
|
||||
registration_key is a one-time key used to link a Discord server to this tenant.
|
||||
Format: discord_<tenant_id>.<random_token>
|
||||
guild_id is NULL until the Discord admin runs !register with the key.
|
||||
"""
|
||||
|
||||
__tablename__ = "discord_guild_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# Discord snowflake - NULL until registered via command in Discord
|
||||
guild_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True, unique=True)
|
||||
guild_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
|
||||
# One-time registration key: discord_<tenant_id>.<random_token>
|
||||
registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
|
||||
registered_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Configuration
|
||||
default_persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
default_persona: Mapped["Persona | None"] = relationship(
|
||||
"Persona", foreign_keys=[default_persona_id]
|
||||
)
|
||||
channels: Mapped[list["DiscordChannelConfig"]] = relationship(
|
||||
back_populates="guild_config", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class DiscordChannelConfig(Base):
|
||||
"""Per-channel configuration for Discord bot behavior.
|
||||
|
||||
Used to whitelist specific channels and configure per-channel behavior.
|
||||
"""
|
||||
|
||||
__tablename__ = "discord_channel_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
guild_config_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("discord_guild_config.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Discord snowflake
|
||||
channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
channel_name: Mapped[str] = mapped_column(String(), nullable=False)
|
||||
|
||||
# Channel type from Discord (text, forum)
|
||||
channel_type: Mapped[str] = mapped_column(
|
||||
String(20), server_default=text("'text'"), nullable=False
|
||||
)
|
||||
|
||||
# True if @everyone cannot view the channel
|
||||
is_private: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# If true, bot only responds to messages in threads
|
||||
# Otherwise, will reply in channel
|
||||
thread_only_mode: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# If true (default), bot only responds when @mentioned
|
||||
# If false, bot responds to ALL messages in this channel
|
||||
require_bot_invocation: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
|
||||
# Override the guild's default persona for this channel
|
||||
persona_override_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
guild_config: Mapped["DiscordGuildConfig"] = relationship(back_populates="channels")
|
||||
persona_override: Mapped["Persona | None"] = relationship()
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"guild_config_id", "channel_id", name="uq_discord_channel_guild_channel"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Milestone(Base):
|
||||
# This table is used to track significant events for a deployment towards finding value
|
||||
# The table is currently not used for features but it may be used in the future to inform
|
||||
@@ -3115,25 +3232,6 @@ class FileRecord(Base):
|
||||
)
|
||||
|
||||
|
||||
class AgentSearchMetrics(Base):
|
||||
__tablename__ = "agent__search_metrics"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
agent_type: Mapped[str] = mapped_column(String)
|
||||
start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
base_duration_s: Mapped[float] = mapped_column(Float)
|
||||
full_duration_s: Mapped[float] = mapped_column(Float)
|
||||
base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
|
||||
|
||||
"""
|
||||
************************************************************************
|
||||
Enterprise Edition Models
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -40,3 +40,10 @@ class DocumentRow(BaseModel):
|
||||
class SortOrder(str, Enum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class DiscordChannelView(BaseModel):
|
||||
channel_id: int
|
||||
channel_name: str
|
||||
channel_type: str = "text" # text, forum
|
||||
is_private: bool = False # True if @everyone cannot view the channel
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -301,6 +301,12 @@ class LitellmLLM(LLM):
|
||||
)
|
||||
is_ollama = self._model_provider == LlmProviderNames.OLLAMA_CHAT
|
||||
is_mistral = self._model_provider == LlmProviderNames.MISTRAL
|
||||
is_vertex_ai = self._model_provider == LlmProviderNames.VERTEX_AI
|
||||
# Vertex Anthropic Opus 4.5 rejects output_config (LiteLLM maps reasoning_effort).
|
||||
# Keep this guard until LiteLLM/Vertex accept the field for this model.
|
||||
is_vertex_opus_4_5 = (
|
||||
is_vertex_ai and "claude-opus-4-5" in self.config.model_name.lower()
|
||||
)
|
||||
|
||||
#########################
|
||||
# Build arguments
|
||||
@@ -331,12 +337,16 @@ class LitellmLLM(LLM):
|
||||
# Temperature
|
||||
temperature = 1 if is_reasoning else self._temperature
|
||||
|
||||
if stream:
|
||||
if stream and not is_vertex_opus_4_5:
|
||||
optional_kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
# Use configured default if not provided (if not set in env, low)
|
||||
reasoning_effort = reasoning_effort or ReasoningEffort(DEFAULT_REASONING_EFFORT)
|
||||
if is_reasoning and reasoning_effort != ReasoningEffort.OFF:
|
||||
if (
|
||||
is_reasoning
|
||||
and reasoning_effort != ReasoningEffort.OFF
|
||||
and not is_vertex_opus_4_5
|
||||
):
|
||||
if is_openai_model:
|
||||
# OpenAI API does not accept reasoning params for GPT 5 chat models
|
||||
# (neither reasoning nor reasoning_effort are accepted)
|
||||
|
||||
@@ -96,6 +96,7 @@ from onyx.server.long_term_logs.long_term_logs_api import (
|
||||
router as long_term_logs_router,
|
||||
)
|
||||
from onyx.server.manage.administrative import router as admin_router
|
||||
from onyx.server.manage.discord_bot.api import router as discord_bot_router
|
||||
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
|
||||
from onyx.server.manage.embedding.api import basic_router as embedding_router
|
||||
from onyx.server.manage.get_state import router as state_router
|
||||
@@ -380,6 +381,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, slack_bot_management_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, discord_bot_router)
|
||||
include_router_with_global_prefix_prepended(application, persona_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, agents_router)
|
||||
|
||||
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# Discord Bot Multitenant Architecture
|
||||
|
||||
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
|
||||
|
||||
## Overview
|
||||
|
||||
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
|
||||
|
||||
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
|
||||
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ OnyxDiscordClient │
|
||||
│ │
|
||||
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
|
||||
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
|
||||
│ │ tenant_id → api_key │ │ message, │ │
|
||||
│ │ │ │ api_key=<per-tenant>, │ │
|
||||
│ └─────────────────────────┘ │ persona_id=... │ │
|
||||
│ │ ) │ │
|
||||
│ └─────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Component Details
|
||||
|
||||
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
|
||||
|
||||
The `DiscordCacheManager` maintains two critical in-memory mappings:
|
||||
|
||||
```python
|
||||
class DiscordCacheManager:
|
||||
_guild_tenants: dict[int, str] # guild_id → tenant_id
|
||||
_api_keys: dict[str, str] # tenant_id → api_key
|
||||
_lock: asyncio.Lock # Concurrency control
|
||||
```
|
||||
|
||||
#### Key Responsibilities
|
||||
|
||||
| Function | Purpose |
|
||||
|----------|---------|
|
||||
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
|
||||
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
|
||||
| `refresh_all()` | Full cache rebuild from database |
|
||||
| `refresh_guild()` | Incremental update for single guild |
|
||||
|
||||
#### API Key Provisioning Strategy
|
||||
|
||||
API keys are **lazily provisioned** - only created when first needed:
|
||||
|
||||
```python
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
needs_key = tenant_id not in self._api_keys
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# Load guild configs
|
||||
configs = get_discord_bot_configs(db)
|
||||
guild_ids = [c.guild_id for c in configs if c.enabled]
|
||||
|
||||
# Only provision API key if not already cached
|
||||
api_key = None
|
||||
if needs_key:
|
||||
api_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
|
||||
return guild_ids, api_key
|
||||
```
|
||||
|
||||
This optimization avoids repeated database calls for API key generation.
|
||||
|
||||
#### Concurrency Control
|
||||
|
||||
All write operations acquire an async lock to prevent race conditions:
|
||||
|
||||
```python
|
||||
async def refresh_all(self) -> None:
|
||||
async with self._lock:
|
||||
# Safe to modify _guild_tenants and _api_keys
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
# Update mappings...
|
||||
```
|
||||
|
||||
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
|
||||
|
||||
---
|
||||
|
||||
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
|
||||
|
||||
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
|
||||
|
||||
#### Key Design: Per-Request API Key Injection
|
||||
|
||||
```python
|
||||
class OnyxAPIClient:
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str, # Injected per-request
|
||||
persona_id: int | None,
|
||||
...
|
||||
) -> ChatFullResponse:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
|
||||
}
|
||||
# Make request...
|
||||
```
|
||||
|
||||
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
|
||||
|
||||
```python
|
||||
# Same client, different tenants
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coordination Flow
|
||||
|
||||
### Message Processing Pipeline
|
||||
|
||||
When a Discord message arrives, the client coordinates cache and API client:
|
||||
|
||||
```python
|
||||
async def on_message(self, message: Message) -> None:
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Step 1: Cache lookup - guild → tenant
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
if not tenant_id:
|
||||
return # Guild not registered
|
||||
|
||||
# Step 2: Cache lookup - tenant → API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Step 3: API call with tenant-specific credentials
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key, # Tenant-specific
|
||||
persona_id=persona_id, # Tenant-specific
|
||||
api_client=self.api_client,
|
||||
)
|
||||
```
|
||||
|
||||
### Startup Sequence
|
||||
|
||||
```python
|
||||
async def setup_hook(self) -> None:
|
||||
# 1. Initialize API client (create aiohttp session)
|
||||
await self.api_client.initialize()
|
||||
|
||||
# 2. Populate cache with all tenants
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# 3. Start background refresh task
|
||||
self._cache_refresh_task = self.loop.create_task(
|
||||
self._periodic_cache_refresh() # Every 60 seconds
|
||||
)
|
||||
```
|
||||
|
||||
### Shutdown Sequence
|
||||
|
||||
```python
|
||||
async def close(self) -> None:
|
||||
# 1. Cancel background refresh
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
|
||||
# 2. Close Discord connection
|
||||
await super().close()
|
||||
|
||||
# 3. Close API client session
|
||||
await self.api_client.close()
|
||||
|
||||
# 4. Clear cache
|
||||
self.cache.clear()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tenant Isolation Mechanisms
|
||||
|
||||
### 1. Per-Tenant API Keys
|
||||
|
||||
Each tenant has a dedicated service API key:
|
||||
|
||||
```python
|
||||
# backend/onyx/db/discord_bot.py
|
||||
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
return regenerate_key(existing)
|
||||
|
||||
# Create LIMITED role key (chat-only permissions)
|
||||
return insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Minimal permissions
|
||||
),
|
||||
user_id=None, # Service account (system-owned)
|
||||
).api_key
|
||||
```
|
||||
|
||||
### 2. Database Context Variables
|
||||
|
||||
The cache uses context variables for proper tenant-scoped DB sessions:
|
||||
|
||||
```python
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# All DB operations scoped to this tenant
|
||||
...
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
```
|
||||
|
||||
### 3. Enterprise Gating Support
|
||||
|
||||
Gated tenants are filtered during cache refresh:
|
||||
|
||||
```python
|
||||
gated_tenants = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
if tenant_id in gated_tenants:
|
||||
continue # Skip gated tenants
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Refresh Strategy
|
||||
|
||||
| Trigger | Method | Scope |
|
||||
|---------|--------|-------|
|
||||
| Startup | `refresh_all()` | All tenants |
|
||||
| Periodic (60s) | `refresh_all()` | All tenants |
|
||||
| Guild registration | `refresh_guild()` | Single tenant |
|
||||
|
||||
### Error Handling
|
||||
|
||||
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
|
||||
- **Missing API key**: Bot silently ignores messages from that guild
|
||||
- **Network errors**: Logged, cache continues with stale data until next refresh
|
||||
|
||||
---
|
||||
|
||||
## Key Design Insights
|
||||
|
||||
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
|
||||
|
||||
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
|
||||
|
||||
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
|
||||
|
||||
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
|
||||
|
||||
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
|
||||
|
||||
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
|
||||
|
||||
---
|
||||
|
||||
## File References
|
||||
|
||||
| Component | Path |
|
||||
|-----------|------|
|
||||
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
|
||||
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
|
||||
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
|
||||
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
|
||||
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
|
||||
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |
|
||||
215
backend/onyx/onyxbot/discord/api_client.py
Normal file
215
backend/onyx/onyxbot/discord/api_client.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Async HTTP client for communicating with Onyx API pods."""
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.discord.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.discord.exceptions import APIResponseError
|
||||
from onyx.onyxbot.discord.exceptions import APITimeoutError
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxAPIClient:
|
||||
"""Async HTTP client for sending chat requests to Onyx API pods.
|
||||
|
||||
This client manages an aiohttp session for making non-blocking HTTP
|
||||
requests to the Onyx API server. It handles authentication with per-tenant
|
||||
API keys and multi-tenant routing.
|
||||
|
||||
Usage:
|
||||
client = OnyxAPIClient()
|
||||
await client.initialize()
|
||||
try:
|
||||
response = await client.send_chat_message(
|
||||
message="What is our deployment process?",
|
||||
tenant_id="tenant_123",
|
||||
api_key="dn_xxx...",
|
||||
persona_id=1,
|
||||
)
|
||||
print(response.answer)
|
||||
finally:
|
||||
await client.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = API_REQUEST_TIMEOUT,
|
||||
) -> None:
|
||||
"""Initialize the API client.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
# Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL
|
||||
# TODO: Ideally, this override is only used when someone is launching an Onyx service independently
|
||||
self._base_url = build_api_server_url_for_http_requests(
|
||||
respect_env_override_if_set=True
|
||||
).rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Create the aiohttp session.
|
||||
|
||||
Must be called before making any requests. The session is created
|
||||
with a total timeout and connection timeout.
|
||||
"""
|
||||
if self._session is not None:
|
||||
logger.warning("API client session already initialized")
|
||||
return
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=self._timeout,
|
||||
connect=30, # 30 seconds to establish connection
|
||||
)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
logger.info(f"API client initialized with base URL: {self._base_url}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the aiohttp session.
|
||||
|
||||
Should be called when shutting down the bot to properly release
|
||||
resources.
|
||||
"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
logger.info("API client session closed")
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the session is initialized."""
|
||||
return self._session is not None
|
||||
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str,
|
||||
persona_id: int | None = None,
|
||||
) -> ChatFullResponse:
|
||||
"""Send a chat message to the Onyx API server and get a response.
|
||||
|
||||
This method sends a non-streaming chat request to the API server. The response
|
||||
contains the complete answer with any citations and metadata.
|
||||
|
||||
Args:
|
||||
message: The user's message to process.
|
||||
api_key: The API key for authentication.
|
||||
persona_id: Optional persona ID to use for the response.
|
||||
|
||||
Returns:
|
||||
ChatFullResponse containing the answer, citations, and metadata.
|
||||
|
||||
Raises:
|
||||
APIConnectionError: If unable to connect to the API.
|
||||
APITimeoutError: If the request times out.
|
||||
APIResponseError: If the API returns an error response.
|
||||
"""
|
||||
if self._session is None:
|
||||
raise APIConnectionError(
|
||||
"API client not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
url = f"{self._base_url}/chat/send-chat-message"
|
||||
|
||||
# Build request payload
|
||||
request = SendMessageRequest(
|
||||
message=message,
|
||||
stream=False,
|
||||
origin=MessageOrigin.DISCORDBOT,
|
||||
chat_session_info=ChatSessionCreationRequest(
|
||||
persona_id=persona_id if persona_id is not None else 0,
|
||||
),
|
||||
)
|
||||
|
||||
# Build headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise APIResponseError(
|
||||
"Authentication failed - invalid API key",
|
||||
status_code=401,
|
||||
)
|
||||
elif response.status == 403:
|
||||
raise APIResponseError(
|
||||
"Access denied - insufficient permissions",
|
||||
status_code=403,
|
||||
)
|
||||
elif response.status == 404:
|
||||
raise APIResponseError(
|
||||
"API endpoint not found",
|
||||
status_code=404,
|
||||
)
|
||||
elif response.status >= 500:
|
||||
error_text = await response.text()
|
||||
raise APIResponseError(
|
||||
f"Server error: {error_text}",
|
||||
status_code=response.status,
|
||||
)
|
||||
elif response.status >= 400:
|
||||
error_text = await response.text()
|
||||
raise APIResponseError(
|
||||
f"Request error: {error_text}",
|
||||
status_code=response.status,
|
||||
)
|
||||
|
||||
# Parse successful response
|
||||
data = await response.json()
|
||||
response_obj = ChatFullResponse.model_validate(data)
|
||||
|
||||
if response_obj.error_msg:
|
||||
logger.warning(f"Chat API returned error: {response_obj.error_msg}")
|
||||
|
||||
return response_obj
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"Failed to connect to API: {e}")
|
||||
raise APIConnectionError(
|
||||
f"Failed to connect to API at {self._base_url}: {e}"
|
||||
) from e
|
||||
|
||||
except TimeoutError as e:
|
||||
logger.error(f"API request timed out after {self._timeout}s")
|
||||
raise APITimeoutError(
|
||||
f"Request timed out after {self._timeout} seconds"
|
||||
) from e
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"HTTP client error: {e}")
|
||||
raise APIConnectionError(f"HTTP client error: {e}") from e
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if the API server is healthy.
|
||||
|
||||
Returns:
|
||||
True if the API server is reachable and healthy, False otherwise.
|
||||
"""
|
||||
if self._session is None:
|
||||
logger.warning("API client not initialized. Call initialize() first.")
|
||||
return False
|
||||
|
||||
try:
|
||||
url = f"{self._base_url}/health"
|
||||
async with self._session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.warning(f"API server health check failed: {e}")
|
||||
return False
|
||||
154
backend/onyx/onyxbot/discord/cache.py
Normal file
154
backend/onyx/onyxbot/discord/cache.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""Multi-tenant cache for Discord bot guild-tenant mappings and API keys."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from onyx.db.discord_bot import get_guild_configs
|
||||
from onyx.db.discord_bot import get_or_create_discord_service_api_key
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.onyxbot.discord.exceptions import CacheError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DiscordCacheManager:
|
||||
"""Caches guild->tenant mappings and tenant->API key mappings.
|
||||
|
||||
Refreshed on startup, periodically (every 60s), and when guilds register.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._guild_tenants: dict[int, str] = {} # guild_id -> tenant_id
|
||||
self._api_keys: dict[str, str] = {} # tenant_id -> api_key
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
async def refresh_all(self) -> None:
|
||||
"""Full cache refresh from all tenants."""
|
||||
async with self._lock:
|
||||
logger.info("Starting Discord cache refresh")
|
||||
|
||||
new_guild_tenants: dict[int, str] = {}
|
||||
new_api_keys: dict[str, str] = {}
|
||||
|
||||
try:
|
||||
gated = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated:
|
||||
continue
|
||||
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
if not guild_ids:
|
||||
logger.debug(f"No guilds found for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"Discord service API key missing for tenant that has registered guilds. "
|
||||
f"{tenant_id} will not be handled in this refresh cycle."
|
||||
)
|
||||
continue
|
||||
|
||||
for guild_id in guild_ids:
|
||||
new_guild_tenants[guild_id] = tenant_id
|
||||
|
||||
new_api_keys[tenant_id] = api_key
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
|
||||
self._guild_tenants = new_guild_tenants
|
||||
self._api_keys = new_api_keys
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
f"Cache refresh complete: {len(new_guild_tenants)} guilds, "
|
||||
f"{len(new_api_keys)} tenants"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache refresh failed: {e}")
|
||||
raise CacheError(f"Failed to refresh cache: {e}") from e
|
||||
|
||||
async def refresh_guild(self, guild_id: int, tenant_id: str) -> None:
|
||||
"""Add a single guild to cache after registration."""
|
||||
async with self._lock:
|
||||
logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})")
|
||||
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
|
||||
if guild_id in guild_ids:
|
||||
self._guild_tenants[guild_id] = tenant_id
|
||||
if api_key:
|
||||
self._api_keys[tenant_id] = api_key
|
||||
logger.info(f"Cache updated for guild {guild_id}")
|
||||
else:
|
||||
logger.warning(f"Guild {guild_id} not found or disabled")
|
||||
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
"""Load guild IDs and provision API key if needed.
|
||||
|
||||
Returns:
|
||||
(active_guild_ids, api_key) - api_key is the cached key if available,
|
||||
otherwise a newly created key. Returns None if no guilds found.
|
||||
"""
|
||||
cached_key = self._api_keys.get(tenant_id)
|
||||
|
||||
def _sync() -> tuple[list[int], str | None]:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
configs = get_guild_configs(db)
|
||||
guild_ids = [
|
||||
config.guild_id
|
||||
for config in configs
|
||||
if config.enabled and config.guild_id is not None
|
||||
]
|
||||
|
||||
if not guild_ids:
|
||||
return [], None
|
||||
|
||||
if not cached_key:
|
||||
new_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
db.commit()
|
||||
return guild_ids, new_key
|
||||
|
||||
return guild_ids, cached_key
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
|
||||
def get_tenant(self, guild_id: int) -> str | None:
|
||||
"""Get tenant ID for a guild."""
|
||||
return self._guild_tenants.get(guild_id)
|
||||
|
||||
def get_api_key(self, tenant_id: str) -> str | None:
|
||||
"""Get API key for a tenant."""
|
||||
return self._api_keys.get(tenant_id)
|
||||
|
||||
def remove_guild(self, guild_id: int) -> None:
|
||||
"""Remove a guild from cache."""
|
||||
self._guild_tenants.pop(guild_id, None)
|
||||
|
||||
def get_all_guild_ids(self) -> list[int]:
|
||||
"""Get all cached guild IDs."""
|
||||
return list(self._guild_tenants.keys())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all caches."""
|
||||
self._guild_tenants.clear()
|
||||
self._api_keys.clear()
|
||||
self._initialized = False
|
||||
232
backend/onyx/onyxbot/discord/client.py
Normal file
232
backend/onyx/onyxbot/discord/client.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Discord bot client with integrated message handling."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL
|
||||
from onyx.onyxbot.discord.handle_commands import handle_dm
|
||||
from onyx.onyxbot.discord.handle_commands import handle_registration_command
|
||||
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
|
||||
from onyx.onyxbot.discord.handle_message import process_chat_message
|
||||
from onyx.onyxbot.discord.handle_message import should_respond
|
||||
from onyx.onyxbot.discord.utils import get_bot_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxDiscordClient(commands.Bot):
|
||||
"""Discord bot client with integrated cache, API client, and message handling.
|
||||
|
||||
This client handles:
|
||||
- Guild registration via !register command
|
||||
- Message processing with persona-based responses
|
||||
- Thread context for conversation continuity
|
||||
- Multi-tenant support via cached API keys
|
||||
"""
|
||||
|
||||
def __init__(self, command_prefix: str = DISCORD_BOT_INVOKE_CHAR) -> None:
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.members = True
|
||||
|
||||
super().__init__(command_prefix=command_prefix, intents=intents)
|
||||
|
||||
self.ready = False
|
||||
self.cache = DiscordCacheManager()
|
||||
self.api_client = OnyxAPIClient()
|
||||
self._cache_refresh_task: asyncio.Task | None = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Lifecycle Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Called before on_ready. Initialize components."""
|
||||
logger.info("Initializing Discord bot components...")
|
||||
|
||||
# Initialize API client
|
||||
await self.api_client.initialize()
|
||||
|
||||
# Initial cache load
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# Start periodic cache refresh
|
||||
self._cache_refresh_task = self.loop.create_task(self._periodic_cache_refresh())
|
||||
|
||||
logger.info("Discord bot components initialized")
|
||||
|
||||
async def _periodic_cache_refresh(self) -> None:
|
||||
"""Background task to refresh cache periodically."""
|
||||
while not self.is_closed():
|
||||
await asyncio.sleep(CACHE_REFRESH_INTERVAL)
|
||||
try:
|
||||
await self.cache.refresh_all()
|
||||
except Exception as e:
|
||||
logger.error(f"Cache refresh failed: {e}")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Bot connected and ready."""
|
||||
if self.ready:
|
||||
return
|
||||
|
||||
if not self.user:
|
||||
raise RuntimeError("Critical error: Discord Bot user not found")
|
||||
|
||||
logger.info(f"Discord Bot connected as {self.user} (ID: {self.user.id})")
|
||||
logger.info(f"Connected to {len(self.guilds)} guild(s)")
|
||||
logger.info(f"Cached {len(self.cache.get_all_guild_ids())} registered guild(s)")
|
||||
|
||||
self.ready = True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Graceful shutdown."""
|
||||
logger.info("Shutting down Discord bot...")
|
||||
|
||||
# Cancel cache refresh task
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
try:
|
||||
await self._cache_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close Discord connection first - stops new commands from triggering cache ops
|
||||
if not self.is_closed():
|
||||
await super().close()
|
||||
|
||||
# Close API client
|
||||
await self.api_client.close()
|
||||
|
||||
# Clear cache (safe now - no concurrent operations possible)
|
||||
self.cache.clear()
|
||||
|
||||
self.ready = False
|
||||
logger.info("Discord bot shutdown complete")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Message Handling
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Main message handler."""
|
||||
# mypy
|
||||
if not self.user:
|
||||
raise RuntimeError("Critical error: Discord Bot user not found")
|
||||
|
||||
try:
|
||||
# Ignore bot messages
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Ignore thread starter messages (empty reference nodes that don't contain content)
|
||||
if message.type == discord.MessageType.thread_starter_message:
|
||||
return
|
||||
|
||||
# Handle DMs
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
await handle_dm(message)
|
||||
return
|
||||
|
||||
# Must have a guild
|
||||
if not message.guild or not message.guild.id:
|
||||
return
|
||||
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Check for registration command first
|
||||
if await handle_registration_command(message, self.cache):
|
||||
return
|
||||
|
||||
# Look up guild in cache
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
|
||||
# Check for sync-channels command (requires registered guild)
|
||||
if await handle_sync_channels_command(message, tenant_id, self):
|
||||
return
|
||||
|
||||
if not tenant_id:
|
||||
# Guild not registered, ignore
|
||||
return
|
||||
|
||||
# Get API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key cached for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Check if bot should respond
|
||||
should_respond_context = await should_respond(message, tenant_id, self.user)
|
||||
|
||||
if not should_respond_context.should_respond:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Processing message: '{message.content[:50]}' in "
|
||||
f"#{getattr(message.channel, 'name', 'unknown')} ({message.guild.name}), "
|
||||
f"persona_id={should_respond_context.persona_id}"
|
||||
)
|
||||
|
||||
# Process the message
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key,
|
||||
persona_id=should_respond_context.persona_id,
|
||||
thread_only_mode=should_respond_context.thread_only_mode,
|
||||
api_client=self.api_client,
|
||||
bot_user=self.user,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message: {e}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Entry Point
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for Discord bot."""
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger.info("Starting Onyx Discord Bot...")
|
||||
|
||||
# Initialize the database engine (required before any DB operations)
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
# Initialize EE features based on environment
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
token = get_bot_token()
|
||||
if not token:
|
||||
if counter % 180 == 0:
|
||||
logger.info(
|
||||
"Discord bot is dormant. Waiting for token configuration..."
|
||||
)
|
||||
counter += 1
|
||||
time.sleep(5)
|
||||
continue
|
||||
counter = 0
|
||||
bot = OnyxDiscordClient()
|
||||
|
||||
try:
|
||||
# bot.run() handles SIGINT/SIGTERM and calls close() automatically
|
||||
bot.run(token)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Fatal error in Discord bot")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19
backend/onyx/onyxbot/discord/constants.py
Normal file
19
backend/onyx/onyxbot/discord/constants.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Discord bot constants."""
|
||||
|
||||
# API settings
|
||||
API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes
|
||||
|
||||
# Cache settings
|
||||
CACHE_REFRESH_INTERVAL: int = 60 # 1 minute
|
||||
|
||||
# Message settings
|
||||
MAX_MESSAGE_LENGTH: int = 2000 # Discord's character limit
|
||||
MAX_CONTEXT_MESSAGES: int = 10 # Max messages to include in conversation context
|
||||
# Note: Discord.py's add_reaction() requires unicode emoji, not :name: format
|
||||
THINKING_EMOJI: str = "🤔" # U+1F914 - Thinking Face
|
||||
SUCCESS_EMOJI: str = "✅" # U+2705 - White Heavy Check Mark
|
||||
ERROR_EMOJI: str = "❌" # U+274C - Cross Mark
|
||||
|
||||
# Command prefix
|
||||
REGISTER_COMMAND: str = "register"
|
||||
SYNC_CHANNELS_COMMAND: str = "sync-channels"
|
||||
37
backend/onyx/onyxbot/discord/exceptions.py
Normal file
37
backend/onyx/onyxbot/discord/exceptions.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Custom exception classes for Discord bot."""
|
||||
|
||||
|
||||
class DiscordBotError(Exception):
|
||||
"""Base exception for Discord bot errors."""
|
||||
|
||||
|
||||
class RegistrationError(DiscordBotError):
|
||||
"""Error during guild registration."""
|
||||
|
||||
|
||||
class SyncChannelsError(DiscordBotError):
|
||||
"""Error during channel sync."""
|
||||
|
||||
|
||||
class APIError(DiscordBotError):
|
||||
"""Base API error."""
|
||||
|
||||
|
||||
class CacheError(DiscordBotError):
|
||||
"""Error during cache operations."""
|
||||
|
||||
|
||||
class APIConnectionError(APIError):
|
||||
"""Failed to connect to API."""
|
||||
|
||||
|
||||
class APITimeoutError(APIError):
|
||||
"""Request timed out."""
|
||||
|
||||
|
||||
class APIResponseError(APIError):
|
||||
"""API returned an error response."""
|
||||
|
||||
def __init__(self, message: str, status_code: int | None = None):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
437
backend/onyx/onyxbot/discord/handle_commands.py
Normal file
437
backend/onyx/onyxbot/discord/handle_commands.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""Discord bot command handlers for registration and channel sync."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import discord
|
||||
|
||||
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
|
||||
from onyx.configs.constants import ONYX_DISCORD_URL
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import get_guild_config_by_discord_id
|
||||
from onyx.db.discord_bot import get_guild_config_by_internal_id
|
||||
from onyx.db.discord_bot import get_guild_config_by_registration_key
|
||||
from onyx.db.discord_bot import sync_channel_configs
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.onyxbot.discord.constants import REGISTER_COMMAND
|
||||
from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND
|
||||
from onyx.onyxbot.discord.exceptions import RegistrationError
|
||||
from onyx.onyxbot.discord.exceptions import SyncChannelsError
|
||||
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
async def handle_dm(message: discord.Message) -> None:
|
||||
"""Handle direct messages."""
|
||||
dm_response = (
|
||||
"**I can't respond to DMs** :sweat:\n\n"
|
||||
f"Please chat with me in a server channel, or join the official "
|
||||
f"[Onyx Discord]({ONYX_DISCORD_URL}) for help!"
|
||||
)
|
||||
await message.channel.send(dm_response)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helper functions for error handling
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _try_dm_author(message: discord.Message, content: str) -> bool:
|
||||
"""Attempt to DM the message author. Returns True if successful."""
|
||||
logger.debug(f"Responding in Discord DM with {content}")
|
||||
try:
|
||||
await message.author.send(content)
|
||||
return True
|
||||
except (discord.Forbidden, discord.HTTPException) as e:
|
||||
# User has DMs disabled or other error
|
||||
logger.warning(f"Failed to DM author {message.author.id}: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error DMing author {message.author.id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _try_delete_message(message: discord.Message) -> bool:
|
||||
"""Attempt to delete a message. Returns True if successful."""
|
||||
logger.debug(f"Deleting potentially sensitive message {message.id}")
|
||||
try:
|
||||
await message.delete()
|
||||
return True
|
||||
except (discord.Forbidden, discord.HTTPException) as e:
|
||||
# Bot lacks permission or other error
|
||||
logger.warning(f"Failed to delete message {message.id}: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error deleting message {message.id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _try_react_x(message: discord.Message) -> bool:
|
||||
"""Attempt to react to a message with ❌. Returns True if successful."""
|
||||
try:
|
||||
await message.add_reaction("❌")
|
||||
return True
|
||||
except (discord.Forbidden, discord.HTTPException) as e:
|
||||
# Bot lacks permission or other error
|
||||
logger.warning(f"Failed to react to message {message.id}: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error reacting to message {message.id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Registration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_registration_command(
|
||||
message: discord.Message,
|
||||
cache: DiscordCacheManager,
|
||||
) -> bool:
|
||||
"""Handle !register command. Returns True if command was handled."""
|
||||
content = message.content.strip()
|
||||
|
||||
# Check for !register command
|
||||
if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{REGISTER_COMMAND}"):
|
||||
return False
|
||||
|
||||
# Must be in a server
|
||||
if not message.guild:
|
||||
await _try_dm_author(
|
||||
message, "This command can only be used in a server channel."
|
||||
)
|
||||
return True
|
||||
|
||||
guild_name = message.guild.name
|
||||
logger.info(f"Registration command received: {guild_name}")
|
||||
|
||||
try:
|
||||
# Parse the registration key
|
||||
parts = content.split(maxsplit=1)
|
||||
if len(parts) < 2:
|
||||
raise RegistrationError(
|
||||
"Invalid registration key format. Please check the key and try again."
|
||||
)
|
||||
|
||||
registration_key = parts[1].strip()
|
||||
|
||||
if not message.author or not isinstance(message.author, discord.Member):
|
||||
raise RegistrationError(
|
||||
"You need to be a server administrator to register the bot."
|
||||
)
|
||||
|
||||
# Check permissions - require admin or manage_guild
|
||||
if not message.author.guild_permissions.administrator:
|
||||
if not message.author.guild_permissions.manage_guild:
|
||||
raise RegistrationError(
|
||||
"You need **Administrator** or **Manage Server** permissions "
|
||||
"to register this bot."
|
||||
)
|
||||
|
||||
await _register_guild(message, registration_key, cache)
|
||||
logger.info(f"Registration successful: {guild_name}")
|
||||
await message.reply(
|
||||
":white_check_mark: **Successfully registered!**\n\n"
|
||||
"This server is now connected to Onyx. "
|
||||
"I'll respond to messages based on your server and channel settings set in Onyx."
|
||||
)
|
||||
except RegistrationError as e:
|
||||
logger.debug(f"Registration failed: {guild_name}, error={e}")
|
||||
await _try_dm_author(message, f":x: **Registration failed.**\n\n{e}")
|
||||
await _try_delete_message(message)
|
||||
except Exception:
|
||||
logger.exception(f"Registration failed unexpectedly: {guild_name}")
|
||||
await _try_dm_author(
|
||||
message,
|
||||
":x: **Registration failed.**\n\n"
|
||||
"An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
await _try_delete_message(message)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _register_guild(
|
||||
message: discord.Message,
|
||||
registration_key: str,
|
||||
cache: DiscordCacheManager,
|
||||
) -> None:
|
||||
"""Register a guild with a registration key."""
|
||||
if not message.guild:
|
||||
# mypy, even though we already know that message.guild is not None
|
||||
raise RegistrationError("This command can only be used in a server.")
|
||||
|
||||
logger.info(f"Guild '{message.guild.name}' attempting to register Discord bot")
|
||||
registration_key = registration_key.strip()
|
||||
|
||||
# Parse tenant_id from registration key
|
||||
parsed = parse_discord_registration_key(registration_key)
|
||||
if parsed is None:
|
||||
raise RegistrationError(
|
||||
"Invalid registration key format. Please check the key and try again."
|
||||
)
|
||||
|
||||
tenant_id = parsed
|
||||
|
||||
logger.info(f"Parsed tenant_id {tenant_id} from registration key")
|
||||
|
||||
# Check if this guild is already registered to any tenant
|
||||
guild_id = message.guild.id
|
||||
existing_tenant = cache.get_tenant(guild_id)
|
||||
if existing_tenant is not None:
|
||||
logger.warning(
|
||||
f"Guild {guild_id} is already registered to tenant {existing_tenant}"
|
||||
)
|
||||
raise RegistrationError(
|
||||
"This server is already registered.\n\n"
|
||||
"OnyxBot can only connect one Discord server to one Onyx workspace."
|
||||
)
|
||||
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
guild = message.guild
|
||||
guild_name = guild.name
|
||||
|
||||
# Collect all text channels from the guild
|
||||
channels = get_text_channels(guild)
|
||||
logger.info(f"Found {len(channels)} text channels in guild '{guild_name}'")
|
||||
|
||||
# Validate and update in database
|
||||
def _sync_register() -> int:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
# Find the guild config by registration key
|
||||
config = get_guild_config_by_registration_key(db, registration_key)
|
||||
if not config:
|
||||
raise RegistrationError(
|
||||
"Registration key not found.\n\n"
|
||||
"The key may have expired or been deleted. "
|
||||
"Please generate a new one from the Onyx admin panel."
|
||||
)
|
||||
|
||||
# Check if already used
|
||||
if config.guild_id is not None:
|
||||
raise RegistrationError(
|
||||
"This registration key has already been used.\n\n"
|
||||
"Each key can only be used once. "
|
||||
"Please generate a new key from the Onyx admin panel."
|
||||
)
|
||||
|
||||
# Update the guild config
|
||||
config.guild_id = guild_id
|
||||
config.guild_name = guild_name
|
||||
config.registered_at = datetime.now(timezone.utc)
|
||||
|
||||
# Create channel configs for all text channels
|
||||
bulk_create_channel_configs(db, config.id, channels)
|
||||
|
||||
db.commit()
|
||||
return config.id
|
||||
|
||||
await asyncio.to_thread(_sync_register)
|
||||
|
||||
# Refresh cache for this guild
|
||||
await cache.refresh_guild(guild_id, tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"Guild '{guild_name}' registered with {len(channels)} channel configs"
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
|
||||
|
||||
def get_text_channels(guild: discord.Guild) -> list[DiscordChannelView]:
|
||||
"""Get all text channels from a guild as DiscordChannelView objects."""
|
||||
channels: list[DiscordChannelView] = []
|
||||
for channel in guild.channels:
|
||||
# Include text channels and forum channels (where threads can be created)
|
||||
if isinstance(channel, (discord.TextChannel, discord.ForumChannel)):
|
||||
# Check if channel is private (not visible to @everyone)
|
||||
everyone_perms = channel.permissions_for(guild.default_role)
|
||||
is_private = not everyone_perms.view_channel
|
||||
|
||||
logger.debug(
|
||||
f"Found channel: #{channel.name}, "
|
||||
f"type={channel.type.name}, is_private={is_private}"
|
||||
)
|
||||
|
||||
channels.append(
|
||||
DiscordChannelView(
|
||||
channel_id=channel.id,
|
||||
channel_name=channel.name,
|
||||
channel_type=channel.type.name, # "text" or "forum"
|
||||
is_private=is_private,
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Retrieved {len(channels)} channels from guild '{guild.name}'")
|
||||
return channels
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sync Channels
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_sync_channels_command(
|
||||
message: discord.Message,
|
||||
tenant_id: str | None,
|
||||
bot: discord.Client,
|
||||
) -> bool:
|
||||
"""Handle !sync-channels command. Returns True if command was handled."""
|
||||
content = message.content.strip()
|
||||
|
||||
# Check for !sync-channels command
|
||||
if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{SYNC_CHANNELS_COMMAND}"):
|
||||
return False
|
||||
|
||||
# Must be in a server
|
||||
if not message.guild:
|
||||
await _try_dm_author(
|
||||
message, "This command can only be used in a server channel."
|
||||
)
|
||||
return True
|
||||
|
||||
guild_name = message.guild.name
|
||||
logger.info(f"Sync-channels command received: {guild_name}")
|
||||
|
||||
try:
|
||||
# Must be registered
|
||||
if not tenant_id:
|
||||
raise SyncChannelsError(
|
||||
"This server is not registered. Please register it first."
|
||||
)
|
||||
|
||||
# Check permissions - require admin or manage_guild
|
||||
if not message.author or not isinstance(message.author, discord.Member):
|
||||
raise SyncChannelsError(
|
||||
"You need to be a server administrator to sync channels."
|
||||
)
|
||||
|
||||
if not message.author.guild_permissions.administrator:
|
||||
if not message.author.guild_permissions.manage_guild:
|
||||
raise SyncChannelsError(
|
||||
"You need **Administrator** or **Manage Server** permissions "
|
||||
"to sync channels."
|
||||
)
|
||||
|
||||
# Get guild config ID
|
||||
def _get_guild_config_id() -> int | None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
if not message.guild:
|
||||
raise SyncChannelsError(
|
||||
"Server not found. This shouldn't happen. Please contact Onyx support."
|
||||
)
|
||||
config = get_guild_config_by_discord_id(db, message.guild.id)
|
||||
return config.id if config else None
|
||||
|
||||
guild_config_id = await asyncio.to_thread(_get_guild_config_id)
|
||||
|
||||
if not guild_config_id:
|
||||
raise SyncChannelsError(
|
||||
"Server config not found. This shouldn't happen. Please contact Onyx support."
|
||||
)
|
||||
|
||||
# Perform the sync
|
||||
added, removed, updated = await sync_guild_channels(
|
||||
guild_config_id, tenant_id, bot
|
||||
)
|
||||
logger.info(
|
||||
f"Sync-channels successful: {guild_name}, "
|
||||
f"added={added}, removed={removed}, updated={updated}"
|
||||
)
|
||||
await message.reply(
|
||||
f":white_check_mark: **Channel sync complete!**\n\n"
|
||||
f"* **{added}** new channel(s) added\n"
|
||||
f"* **{removed}** deleted channel(s) removed\n"
|
||||
f"* **{updated}** channel name(s) updated\n\n"
|
||||
"New channels are disabled by default. Enable them in the Onyx admin panel."
|
||||
)
|
||||
except SyncChannelsError as e:
|
||||
logger.debug(f"Sync-channels failed: {guild_name}, error={e}")
|
||||
await _try_dm_author(message, f":x: **Channel sync failed.**\n\n{e}")
|
||||
await _try_react_x(message)
|
||||
except Exception:
|
||||
logger.exception(f"Sync-channels failed unexpectedly: {guild_name}")
|
||||
await _try_dm_author(
|
||||
message,
|
||||
":x: **Channel sync failed.**\n\n"
|
||||
"An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
await _try_react_x(message)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def sync_guild_channels(
|
||||
guild_config_id: int,
|
||||
tenant_id: str,
|
||||
bot: discord.Client,
|
||||
) -> tuple[int, int, int]:
|
||||
"""Sync channel configs with current Discord channels for a guild.
|
||||
|
||||
Fetches current channels from Discord and syncs with database:
|
||||
- Creates configs for new channels (disabled by default)
|
||||
- Removes configs for deleted channels
|
||||
- Updates names for existing channels if changed
|
||||
|
||||
Args:
|
||||
guild_config_id: Internal ID of the guild config
|
||||
tenant_id: Tenant ID for database access
|
||||
bot: Discord bot client
|
||||
|
||||
Returns:
|
||||
(added_count, removed_count, updated_count)
|
||||
|
||||
Raises:
|
||||
ValueError: If guild config not found or guild not registered
|
||||
"""
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
# Get guild_id from config
|
||||
def _get_guild_id() -> int | None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
config = get_guild_config_by_internal_id(db, guild_config_id)
|
||||
if not config:
|
||||
return None
|
||||
return config.guild_id
|
||||
|
||||
guild_id = await asyncio.to_thread(_get_guild_id)
|
||||
|
||||
if guild_id is None:
|
||||
raise ValueError(
|
||||
f"Guild config {guild_config_id} not found or not registered"
|
||||
)
|
||||
|
||||
# Get the guild from Discord
|
||||
guild = bot.get_guild(guild_id)
|
||||
if not guild:
|
||||
raise ValueError(f"Guild {guild_id} not found in Discord cache")
|
||||
|
||||
# Get current channels from Discord
|
||||
channels = get_text_channels(guild)
|
||||
logger.info(f"Syncing {len(channels)} channels for guild '{guild.name}'")
|
||||
|
||||
# Sync with database
|
||||
def _sync() -> tuple[int, int, int]:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
added, removed, updated = sync_channel_configs(
|
||||
db, guild_config_id, channels
|
||||
)
|
||||
db.commit()
|
||||
return added, removed, updated
|
||||
|
||||
added, removed, updated = await asyncio.to_thread(_sync)
|
||||
|
||||
logger.info(
|
||||
f"Channel sync complete for guild '{guild.name}': "
|
||||
f"added={added}, removed={removed}, updated={updated}"
|
||||
)
|
||||
|
||||
return added, removed, updated
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
493
backend/onyx/onyxbot/discord/handle_message.py
Normal file
493
backend/onyx/onyxbot/discord/handle_message.py
Normal file
@@ -0,0 +1,493 @@
|
||||
"""Discord bot message handling and response logic."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.db.discord_bot import get_channel_config_by_discord_ids
|
||||
from onyx.db.discord_bot import get_guild_config_by_discord_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import DiscordChannelConfig
|
||||
from onyx.db.models import DiscordGuildConfig
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
|
||||
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
|
||||
from onyx.onyxbot.discord.constants import THINKING_EMOJI
|
||||
from onyx.onyxbot.discord.exceptions import APIError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Message types with actual content (excludes system notifications like "user joined")
|
||||
CONTENT_MESSAGE_TYPES = (
|
||||
discord.MessageType.default,
|
||||
discord.MessageType.reply,
|
||||
discord.MessageType.thread_starter_message,
|
||||
)
|
||||
|
||||
|
||||
class ShouldRespondContext(BaseModel):
|
||||
"""Context for whether the bot should respond to a message."""
|
||||
|
||||
should_respond: bool
|
||||
persona_id: int | None
|
||||
thread_only_mode: bool
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Response Logic
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def should_respond(
|
||||
message: discord.Message,
|
||||
tenant_id: str,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> ShouldRespondContext:
|
||||
"""Determine if bot should respond and which persona to use."""
|
||||
if not message.guild:
|
||||
logger.warning("Received a message that isn't in a server.")
|
||||
return ShouldRespondContext(
|
||||
should_respond=False, persona_id=None, thread_only_mode=False
|
||||
)
|
||||
|
||||
guild_id = message.guild.id
|
||||
channel_id = message.channel.id
|
||||
bot_mentioned = bot_user in message.mentions
|
||||
|
||||
def _get_configs() -> tuple[DiscordGuildConfig | None, DiscordChannelConfig | None]:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
guild_config = get_guild_config_by_discord_id(db, guild_id)
|
||||
if not guild_config or not guild_config.enabled:
|
||||
return None, None
|
||||
|
||||
# For threads, use parent channel ID
|
||||
actual_channel_id = channel_id
|
||||
if isinstance(message.channel, discord.Thread) and message.channel.parent:
|
||||
actual_channel_id = message.channel.parent.id
|
||||
|
||||
channel_config = get_channel_config_by_discord_ids(
|
||||
db, guild_id, actual_channel_id
|
||||
)
|
||||
return guild_config, channel_config
|
||||
|
||||
guild_config, channel_config = await asyncio.to_thread(_get_configs)
|
||||
|
||||
if not guild_config or not channel_config or not channel_config.enabled:
|
||||
return ShouldRespondContext(
|
||||
should_respond=False, persona_id=None, thread_only_mode=False
|
||||
)
|
||||
|
||||
# Determine persona (channel override or guild default)
|
||||
persona_id = channel_config.persona_override_id or guild_config.default_persona_id
|
||||
|
||||
# Check mention requirement (with exceptions for implicit invocation)
|
||||
if channel_config.require_bot_invocation and not bot_mentioned:
|
||||
if not await check_implicit_invocation(message, bot_user):
|
||||
return ShouldRespondContext(
|
||||
should_respond=False, persona_id=None, thread_only_mode=False
|
||||
)
|
||||
|
||||
return ShouldRespondContext(
|
||||
should_respond=True,
|
||||
persona_id=persona_id,
|
||||
thread_only_mode=channel_config.thread_only_mode,
|
||||
)
|
||||
|
||||
|
||||
async def check_implicit_invocation(
|
||||
message: discord.Message,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> bool:
|
||||
"""Check if the bot should respond without explicit mention.
|
||||
|
||||
Returns True if:
|
||||
1. User is replying to a bot message
|
||||
2. User is in a thread owned by the bot
|
||||
3. User is in a thread created from a bot message
|
||||
"""
|
||||
# Check if replying to a bot message
|
||||
if message.reference and message.reference.message_id:
|
||||
try:
|
||||
referenced_msg = await message.channel.fetch_message(
|
||||
message.reference.message_id
|
||||
)
|
||||
if referenced_msg.author.id == bot_user.id:
|
||||
logger.debug(
|
||||
f"Implicit invocation via reply: '{message.content[:50]}...'"
|
||||
)
|
||||
return True
|
||||
except (discord.NotFound, discord.HTTPException):
|
||||
pass
|
||||
|
||||
# Check thread-related conditions
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
thread = message.channel
|
||||
|
||||
# Bot owns the thread
|
||||
if thread.owner_id == bot_user.id:
|
||||
logger.debug(
|
||||
f"Implicit invocation via bot-owned thread: "
|
||||
f"'{message.content[:50]}...' in #{thread.name}"
|
||||
)
|
||||
return True
|
||||
|
||||
# Thread was created from a bot message
|
||||
if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
|
||||
try:
|
||||
starter = await thread.parent.fetch_message(thread.id)
|
||||
if starter.author.id == bot_user.id:
|
||||
logger.debug(
|
||||
f"Implicit invocation via bot-started thread: "
|
||||
f"'{message.content[:50]}...' in #{thread.name}"
|
||||
)
|
||||
return True
|
||||
except (discord.NotFound, discord.HTTPException):
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Message Processing
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def process_chat_message(
|
||||
message: discord.Message,
|
||||
api_key: str,
|
||||
persona_id: int | None,
|
||||
thread_only_mode: bool,
|
||||
api_client: OnyxAPIClient,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> None:
|
||||
"""Process a message and send response."""
|
||||
try:
|
||||
await message.add_reaction(THINKING_EMOJI)
|
||||
except discord.DiscordException:
|
||||
logger.warning(
|
||||
f"Failed to add thinking reaction to message: '{message.content[:50]}...'"
|
||||
)
|
||||
|
||||
try:
|
||||
# Build conversation context
|
||||
context = await _build_conversation_context(message, bot_user)
|
||||
|
||||
# Prepare full message content
|
||||
parts = []
|
||||
if context:
|
||||
parts.append(context)
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
if isinstance(message.channel.parent, discord.ForumChannel):
|
||||
parts.append(f"Forum post title: {message.channel.name}")
|
||||
parts.append(
|
||||
f"Current message from @{message.author.display_name}: {format_message_content(message)}"
|
||||
)
|
||||
|
||||
# Send to API
|
||||
response = await api_client.send_chat_message(
|
||||
message="\n\n".join(parts),
|
||||
api_key=api_key,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
# Format response with citations
|
||||
answer = response.answer or "I couldn't generate a response."
|
||||
answer = _append_citations(answer, response)
|
||||
|
||||
await send_response(message, answer, thread_only_mode)
|
||||
|
||||
try:
|
||||
await message.remove_reaction(THINKING_EMOJI, bot_user)
|
||||
except discord.DiscordException:
|
||||
pass
|
||||
|
||||
except APIError as e:
|
||||
logger.error(f"API error processing message: {e}")
|
||||
await send_error_response(message, bot_user)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing chat message: {e}")
|
||||
await send_error_response(message, bot_user)
|
||||
|
||||
|
||||
async def _build_conversation_context(
|
||||
message: discord.Message,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> str | None:
|
||||
"""Build conversation context from thread history or reply chain."""
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
return await _build_thread_context(message, bot_user)
|
||||
elif message.reference:
|
||||
return await _build_reply_chain_context(message, bot_user)
|
||||
return None
|
||||
|
||||
|
||||
def _append_citations(answer: str, response: ChatFullResponse) -> str:
|
||||
"""Append citation sources to the answer if present."""
|
||||
if not response.citation_info or not response.top_documents:
|
||||
return answer
|
||||
|
||||
cited_docs: list[tuple[int, str, str | None]] = []
|
||||
for citation in response.citation_info:
|
||||
doc = next(
|
||||
(
|
||||
d
|
||||
for d in response.top_documents
|
||||
if d.document_id == citation.document_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
if doc:
|
||||
cited_docs.append(
|
||||
(
|
||||
citation.citation_number,
|
||||
doc.semantic_identifier or "Source",
|
||||
doc.link,
|
||||
)
|
||||
)
|
||||
|
||||
if not cited_docs:
|
||||
return answer
|
||||
|
||||
cited_docs.sort(key=lambda x: x[0])
|
||||
citations = "\n\n**Sources:**\n"
|
||||
for num, name, link in cited_docs[:5]:
|
||||
if link:
|
||||
citations += f"{num}. [{name}](<{link}>)\n"
|
||||
else:
|
||||
citations += f"{num}. {name}\n"
|
||||
|
||||
return answer + citations
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Context Building
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _build_reply_chain_context(
|
||||
message: discord.Message,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> str | None:
|
||||
"""Build context by following the reply chain backwards."""
|
||||
if not message.reference or not message.reference.message_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
messages: list[discord.Message] = []
|
||||
current = message
|
||||
|
||||
# Follow reply chain backwards up to MAX_CONTEXT_MESSAGES
|
||||
while (
|
||||
current.reference
|
||||
and current.reference.message_id
|
||||
and len(messages) < MAX_CONTEXT_MESSAGES
|
||||
):
|
||||
try:
|
||||
parent = await message.channel.fetch_message(
|
||||
current.reference.message_id
|
||||
)
|
||||
messages.append(parent)
|
||||
current = parent
|
||||
except (discord.NotFound, discord.HTTPException):
|
||||
break
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
messages.reverse() # Chronological order
|
||||
|
||||
logger.debug(
|
||||
f"Built reply chain context: {len(messages)} messages in #{getattr(message.channel, 'name', 'unknown')}"
|
||||
)
|
||||
|
||||
return _format_messages_as_context(messages, bot_user)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build reply chain context: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def _build_thread_context(
|
||||
message: discord.Message,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> str | None:
|
||||
"""Build context from thread message history."""
|
||||
if not isinstance(message.channel, discord.Thread):
|
||||
return None
|
||||
|
||||
try:
|
||||
thread = message.channel
|
||||
messages: list[discord.Message] = []
|
||||
|
||||
# Fetch recent messages (excluding current)
|
||||
async for msg in thread.history(limit=MAX_CONTEXT_MESSAGES, oldest_first=False):
|
||||
if msg.id != message.id:
|
||||
messages.append(msg)
|
||||
|
||||
# Include thread starter message and its reply chain if not already present
|
||||
if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
|
||||
try:
|
||||
starter = await thread.parent.fetch_message(thread.id)
|
||||
if starter.id != message.id and not any(
|
||||
m.id == starter.id for m in messages
|
||||
):
|
||||
messages.append(starter)
|
||||
|
||||
# Trace back through the starter's reply chain for more context
|
||||
current = starter
|
||||
while (
|
||||
current.reference
|
||||
and current.reference.message_id
|
||||
and len(messages) < MAX_CONTEXT_MESSAGES
|
||||
):
|
||||
try:
|
||||
parent = await thread.parent.fetch_message(
|
||||
current.reference.message_id
|
||||
)
|
||||
if not any(m.id == parent.id for m in messages):
|
||||
messages.append(parent)
|
||||
current = parent
|
||||
except (discord.NotFound, discord.HTTPException):
|
||||
break
|
||||
except (discord.NotFound, discord.HTTPException):
|
||||
pass
|
||||
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
messages.sort(key=lambda m: m.id) # Chronological order
|
||||
logger.debug(
|
||||
f"Built thread context: {len(messages)} messages in #{thread.name}"
|
||||
)
|
||||
|
||||
return _format_messages_as_context(messages, bot_user)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to build thread context: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _format_messages_as_context(
|
||||
messages: list[discord.Message],
|
||||
bot_user: discord.ClientUser,
|
||||
) -> str | None:
|
||||
"""Format a list of messages into a conversation context string."""
|
||||
formatted = []
|
||||
for msg in messages:
|
||||
if msg.type not in CONTENT_MESSAGE_TYPES:
|
||||
continue
|
||||
|
||||
sender = (
|
||||
"OnyxBot" if msg.author.id == bot_user.id else f"@{msg.author.display_name}"
|
||||
)
|
||||
formatted.append(f"{sender}: {format_message_content(msg)}")
|
||||
|
||||
if not formatted:
|
||||
return None
|
||||
|
||||
return (
|
||||
"You are a Discord bot named OnyxBot.\n"
|
||||
'Always assume that [user] is the same as the "Current message" author.'
|
||||
"Conversation history:\n"
|
||||
"---\n" + "\n".join(formatted) + "\n---"
|
||||
)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Message Formatting
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_message_content(message: discord.Message) -> str:
|
||||
"""Format message content with readable mentions."""
|
||||
content = message.content
|
||||
|
||||
for user in message.mentions:
|
||||
content = content.replace(f"<@{user.id}>", f"@{user.display_name}")
|
||||
content = content.replace(f"<@!{user.id}>", f"@{user.display_name}")
|
||||
|
||||
for role in message.role_mentions:
|
||||
content = content.replace(f"<@&{role.id}>", f"@{role.name}")
|
||||
|
||||
for channel in message.channel_mentions:
|
||||
content = content.replace(f"<#{channel.id}>", f"#{channel.name}")
|
||||
|
||||
return content
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Response Sending
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def send_response(
|
||||
message: discord.Message,
|
||||
content: str,
|
||||
thread_only_mode: bool,
|
||||
) -> None:
|
||||
"""Send response based on thread_only_mode setting."""
|
||||
chunks = _split_message(content)
|
||||
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
for chunk in chunks:
|
||||
await message.channel.send(chunk)
|
||||
elif thread_only_mode:
|
||||
thread_name = f"OnyxBot <> {message.author.display_name}"[:100]
|
||||
thread = await message.create_thread(name=thread_name)
|
||||
for chunk in chunks:
|
||||
await thread.send(chunk)
|
||||
else:
|
||||
for i, chunk in enumerate(chunks):
|
||||
if i == 0:
|
||||
await message.reply(chunk)
|
||||
else:
|
||||
await message.channel.send(chunk)
|
||||
|
||||
|
||||
def _split_message(content: str) -> list[str]:
|
||||
"""Split content into chunks that fit Discord's message limit."""
|
||||
chunks = []
|
||||
while content:
|
||||
if len(content) <= MAX_MESSAGE_LENGTH:
|
||||
chunks.append(content)
|
||||
break
|
||||
|
||||
# Find a good split point
|
||||
split_at = MAX_MESSAGE_LENGTH
|
||||
for sep in ["\n\n", "\n", ". ", " "]:
|
||||
idx = content.rfind(sep, 0, MAX_MESSAGE_LENGTH)
|
||||
if idx > MAX_MESSAGE_LENGTH // 2:
|
||||
split_at = idx + len(sep)
|
||||
break
|
||||
|
||||
chunks.append(content[:split_at])
|
||||
content = content[split_at:]
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
async def send_error_response(
|
||||
message: discord.Message,
|
||||
bot_user: discord.ClientUser,
|
||||
) -> None:
|
||||
"""Send error response and clean up reaction."""
|
||||
try:
|
||||
await message.remove_reaction(THINKING_EMOJI, bot_user)
|
||||
except discord.DiscordException:
|
||||
pass
|
||||
|
||||
error_msg = "Sorry, I encountered an error processing your message. You may want to contact Onyx for support :sweat_smile:"
|
||||
|
||||
try:
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
await message.channel.send(error_msg)
|
||||
else:
|
||||
thread = await message.create_thread(
|
||||
name=f"Response to {message.author.display_name}"[:100]
|
||||
)
|
||||
await thread.send(error_msg)
|
||||
except discord.DiscordException:
|
||||
pass
|
||||
39
backend/onyx/onyxbot/discord/utils.py
Normal file
39
backend/onyx/onyxbot/discord/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISCORD_BOT_TOKEN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.discord_bot import get_discord_bot_config
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_bot_token() -> str | None:
|
||||
"""Get Discord bot token from env var or database.
|
||||
|
||||
Priority:
|
||||
1. DISCORD_BOT_TOKEN env var (always takes precedence)
|
||||
2. For self-hosted: DiscordBotConfig in database (default tenant)
|
||||
3. For Cloud: should always have env var set
|
||||
|
||||
Returns:
|
||||
Bot token string, or None if not configured.
|
||||
"""
|
||||
# Environment variable takes precedence
|
||||
if DISCORD_BOT_TOKEN:
|
||||
return DISCORD_BOT_TOKEN
|
||||
|
||||
# Cloud should always have env var; if not, return None
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
logger.warning("Cloud deployment missing DISCORD_BOT_TOKEN env var")
|
||||
return None
|
||||
|
||||
# Self-hosted: check database for bot config
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db:
|
||||
config = get_discord_bot_config(db)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get bot token from database: {e}")
|
||||
return None
|
||||
return config.bot_token if config else None
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,12 +1,149 @@
|
||||
from mistune import Markdown # type: ignore[import-untyped]
|
||||
from mistune import Renderer
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
return Markdown(renderer=SlackRenderer()).render(message)
|
||||
if message is None:
|
||||
return ""
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(Renderer):
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def header(self, text: str, level: int, raw: str | None = None) -> str:
|
||||
return f"*{text}*\n"
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
|
||||
def double_emphasis(self, text: str) -> str:
|
||||
def strong(self, text: str) -> str:
|
||||
return f"*{text}*"
|
||||
|
||||
def strikethrough(self, text: str) -> str:
|
||||
return f"~{text}~"
|
||||
|
||||
def list(self, body: str, ordered: bool = True) -> str:
|
||||
lines = body.split("\n")
|
||||
def list(self, text: str, ordered: bool, **attrs: Any) -> str:
|
||||
lines = text.split("\n")
|
||||
count = 0
|
||||
for i, line in enumerate(lines):
|
||||
if line.startswith("li: "):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
|
||||
def link(self, link: str, title: str | None, content: str | None) -> str:
|
||||
escaped_link = self.escape_special(link)
|
||||
if content:
|
||||
return f"<{escaped_link}|{content}>"
|
||||
def link(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
if text:
|
||||
return f"<{escaped_url}|{text}>"
|
||||
if title:
|
||||
return f"<{escaped_link}|{title}>"
|
||||
return f"<{escaped_link}>"
|
||||
return f"<{escaped_url}|{title}>"
|
||||
return f"<{escaped_url}>"
|
||||
|
||||
def image(self, src: str, title: str | None, text: str | None) -> str:
|
||||
escaped_src = self.escape_special(src)
|
||||
def image(self, text: str, url: str, title: str | None = None) -> str:
|
||||
escaped_url = self.escape_special(url)
|
||||
display_text = title or text
|
||||
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
|
||||
return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"
|
||||
|
||||
def codespan(self, text: str) -> str:
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, text: str, lang: str | None) -> str:
|
||||
return f"```\n{text}\n```\n"
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
|
||||
def autolink(self, link: str, is_email: bool) -> str:
|
||||
return link if is_email else self.link(link, None, None)
|
||||
return f"{text}\n\n"
|
||||
|
||||
38
backend/onyx/prompts/basic_memory.py
Normal file
38
backend/onyx/prompts/basic_memory.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
|
||||
# Note that the user_basic_information is only included if we have at least 1 of the following: user_name, user_email, user_role
|
||||
# This is included because sometimes we need to know the user's name or basic info to best generate the memory.
|
||||
FULL_MEMORY_UPDATE_PROMPT = """
|
||||
You are a memory update agent that helps the user add or update memories. You are given a list of existing memories and a new memory to add. \
|
||||
Just as context, you are also given the last few user messages from the conversation which generated the new memory. You must determine if the memory is brand new or if it is related to an existing memory. \
|
||||
If the new memory is an update to an existing memory or contradicts an existing memory, it should be treated as an update and you should reference the existing memory by memory_id (see below). \
|
||||
The memory should omit the user's name and direct reference to the user - for example, a memory like "Yuhong prefers dark mode." should be modified to "Prefers dark mode." (if the user's name is Yuhong).
|
||||
|
||||
# Truncated chat history
|
||||
{chat_history}{user_basic_information}
|
||||
|
||||
# User's existing memories
|
||||
{existing_memories}
|
||||
|
||||
# New memory the user wants to insert
|
||||
{new_memory}
|
||||
|
||||
# Response Style
|
||||
You MUST respond in a json which follows the following format and keys:
|
||||
```json
|
||||
{{
|
||||
"operation": "add or update",
|
||||
"memory_id": "if the operation is update, the id of the memory to update, otherwise null",
|
||||
"memory_text": "the text of the memory to add or update"
|
||||
}}
|
||||
```
|
||||
""".strip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
MEMORY_USER_BASIC_INFORMATION_PROMPT = """
|
||||
|
||||
# User Basic Information
|
||||
User name: {user_name}
|
||||
User email: {user_email}
|
||||
User role: {user_role}
|
||||
"""
|
||||
@@ -1,30 +1,39 @@
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
|
||||
SLACK_QUERY_EXPANSION_PROMPT = f"""
|
||||
Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
|
||||
keyword-only queries, so that Slack's keyword search yields the best matches.
|
||||
Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.
|
||||
|
||||
Keep in mind the Slack's search behavior:
|
||||
- Pure keyword AND search (no semantics).
|
||||
- Word order matters.
|
||||
- More words = fewer matches, so keep each query concise.
|
||||
- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.
|
||||
Slack search behavior:
|
||||
- Pure keyword AND search (no semantics)
|
||||
- More words = fewer matches, so keep queries concise (1-3 words)
|
||||
|
||||
Critical: Extract ONLY keywords that would actually appear in Slack message content.
|
||||
ALWAYS include:
|
||||
- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
|
||||
- Project/product names, technical terms, proper nouns
|
||||
- Actual content words: "performance", "bug", "deployment", "API", "error"
|
||||
|
||||
DO NOT include:
|
||||
- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
|
||||
- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
|
||||
- Channels/Users: "general", "eng-general", "engineering", "@username"
|
||||
|
||||
DO include:
|
||||
- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"
|
||||
- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
|
||||
- Temporal: "today", "yesterday", "week", "month", "recent", "last"
|
||||
- Channel names: "general", "eng-general", "random"
|
||||
|
||||
Examples:
|
||||
|
||||
Query: "what are the big topics in eng-general this week?"
|
||||
Output:
|
||||
|
||||
Query: "messages with Sarah about the deployment"
|
||||
Output:
|
||||
Sarah deployment
|
||||
Sarah
|
||||
deployment
|
||||
|
||||
Query: "what did Mike say about the budget?"
|
||||
Output:
|
||||
Mike budget
|
||||
Mike
|
||||
budget
|
||||
|
||||
Query: "performance issues in eng-general"
|
||||
Output:
|
||||
performance issues
|
||||
@@ -41,7 +50,7 @@ Now process this query:
|
||||
|
||||
{{query}}
|
||||
|
||||
Output:
|
||||
Output (keywords only, one per line, NO explanations or commentary):
|
||||
"""
|
||||
|
||||
SLACK_DATE_EXTRACTION_PROMPT = """
|
||||
|
||||
@@ -71,6 +71,14 @@ GENERATE_IMAGE_GUIDANCE = """
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
"""
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
## add_memory
|
||||
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
|
||||
Only add memories that are specific, likely to remain true, and likely to be useful later. \
|
||||
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
|
||||
"""
|
||||
|
||||
TOOL_CALL_FAILURE_PROMPT = """
|
||||
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.
|
||||
""".strip()
|
||||
|
||||
40
backend/onyx/prompts/user_info.py
Normal file
40
backend/onyx/prompts/user_info.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
USER_INFORMATION_HEADER = "\n\n# User Information\n"
|
||||
|
||||
BASIC_INFORMATION_PROMPT = """
|
||||
|
||||
## Basic Information
|
||||
User name: {user_name}
|
||||
User email: {user_email}{user_role}
|
||||
"""
|
||||
|
||||
# This line only shows up if the user has configured their role.
|
||||
USER_ROLE_PROMPT = """
|
||||
User role: {user_role}
|
||||
"""
|
||||
|
||||
# Team information should be a paragraph style description of the user's team.
|
||||
TEAM_INFORMATION_PROMPT = """
|
||||
|
||||
## Team Information
|
||||
{team_information}
|
||||
"""
|
||||
|
||||
# User preferences should be a paragraph style description of the user's preferences.
|
||||
USER_PREFERENCES_PROMPT = """
|
||||
|
||||
## User Preferences
|
||||
{user_preferences}
|
||||
"""
|
||||
|
||||
# User memories should look something like:
|
||||
# - Memory 1
|
||||
# - Memory 2
|
||||
# - Memory 3
|
||||
USER_MEMORIES_PROMPT = """
|
||||
|
||||
## User Memories
|
||||
{user_memories}
|
||||
"""
|
||||
|
||||
# ruff: noqa: E501, W605 end
|
||||
@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
184
backend/onyx/secondary_llm_flows/memory_update.py
Normal file
184
backend/onyx/secondary_llm_flows/memory_update.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.prompts.basic_memory import FULL_MEMORY_UPDATE_PROMPT
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import parse_llm_json_response
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Maximum number of user messages to include
|
||||
MAX_USER_MESSAGES = 3
|
||||
MAX_CHARS_PER_MESSAGE = 500
|
||||
|
||||
|
||||
def _format_chat_history(chat_history: list[ChatMinimalTextMessage]) -> str:
|
||||
user_messages = [
|
||||
msg for msg in chat_history if msg.message_type == MessageType.USER
|
||||
]
|
||||
|
||||
if not user_messages:
|
||||
return "No chat history available."
|
||||
|
||||
# Take the last N user messages
|
||||
recent_user_messages = user_messages[-MAX_USER_MESSAGES:]
|
||||
|
||||
formatted_parts = []
|
||||
for i, msg in enumerate(recent_user_messages, start=1):
|
||||
if len(msg.message) > MAX_CHARS_PER_MESSAGE:
|
||||
truncated_message = msg.message[:MAX_CHARS_PER_MESSAGE] + "[...truncated]"
|
||||
else:
|
||||
truncated_message = msg.message
|
||||
formatted_parts.append(f"\nUser message:\n{truncated_message}\n")
|
||||
|
||||
return "".join(formatted_parts).strip()
|
||||
|
||||
|
||||
def _format_existing_memories(existing_memories: list[str]) -> str:
|
||||
"""Format existing memories as a numbered list (1-indexed for readability)."""
|
||||
if not existing_memories:
|
||||
return "No existing memories."
|
||||
|
||||
formatted_lines = []
|
||||
for i, memory in enumerate(existing_memories, start=1):
|
||||
formatted_lines.append(f"{i}. {memory}")
|
||||
|
||||
return "\n".join(formatted_lines)
|
||||
|
||||
|
||||
def _format_user_basic_information(
|
||||
user_name: str | None,
|
||||
user_email: str | None,
|
||||
user_role: str | None,
|
||||
) -> str:
|
||||
"""Format user basic information, only including fields that have values."""
|
||||
lines = []
|
||||
if user_name:
|
||||
lines.append(f"User name: {user_name}")
|
||||
if user_email:
|
||||
lines.append(f"User email: {user_email}")
|
||||
if user_role:
|
||||
lines.append(f"User role: {user_role}")
|
||||
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
return "\n\n# User Basic Information\n" + "\n".join(lines)
|
||||
|
||||
|
||||
def process_memory_update(
|
||||
new_memory: str,
|
||||
existing_memories: list[str],
|
||||
chat_history: list[ChatMinimalTextMessage],
|
||||
llm: LLM,
|
||||
user_name: str | None = None,
|
||||
user_email: str | None = None,
|
||||
user_role: str | None = None,
|
||||
) -> tuple[str, int | None]:
|
||||
"""
|
||||
Determine if a memory should be added or updated.
|
||||
|
||||
Uses the LLM to analyze the new memory against existing memories and
|
||||
determine whether to add it as new or update an existing memory.
|
||||
|
||||
Args:
|
||||
new_memory: The new memory text from the memory tool
|
||||
existing_memories: List of existing memory strings
|
||||
chat_history: Recent chat history for context
|
||||
llm: LLM instance to use for the decision
|
||||
user_name: Optional user name for context
|
||||
user_email: Optional user email for context
|
||||
user_role: Optional user role for context
|
||||
|
||||
Returns:
|
||||
Tuple of (memory_text, index_to_replace)
|
||||
- memory_text: The final memory text to store
|
||||
- index_to_replace: Index in existing_memories to replace, or None if adding new
|
||||
"""
|
||||
# Format inputs for the prompt
|
||||
formatted_chat_history = _format_chat_history(chat_history)
|
||||
formatted_memories = _format_existing_memories(existing_memories)
|
||||
formatted_user_info = _format_user_basic_information(
|
||||
user_name, user_email, user_role
|
||||
)
|
||||
|
||||
# Build the prompt
|
||||
prompt = FULL_MEMORY_UPDATE_PROMPT.format(
|
||||
chat_history=formatted_chat_history,
|
||||
user_basic_information=formatted_user_info,
|
||||
existing_memories=formatted_memories,
|
||||
new_memory=new_memory,
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
try:
|
||||
response = llm.invoke(
|
||||
prompt=UserMessage(content=prompt), reasoning_effort=ReasoningEffort.OFF
|
||||
)
|
||||
content = response.choice.message.content
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM invocation failed for memory update: {e}")
|
||||
return (new_memory, None)
|
||||
|
||||
# Handle empty response
|
||||
if not content:
|
||||
logger.warning(
|
||||
"LLM returned empty response for memory update, defaulting to add"
|
||||
)
|
||||
return (new_memory, None)
|
||||
|
||||
# Parse JSON response
|
||||
parsed_response = parse_llm_json_response(content)
|
||||
|
||||
if not parsed_response:
|
||||
logger.warning(
|
||||
f"Failed to parse JSON from LLM response: {content[:200]}..., defaulting to add"
|
||||
)
|
||||
return (new_memory, None)
|
||||
|
||||
# Extract fields from response
|
||||
operation = parsed_response.get("operation", "add").lower()
|
||||
memory_id = parsed_response.get("memory_id")
|
||||
memory_text = parsed_response.get("memory_text", new_memory)
|
||||
|
||||
# Ensure memory_text is valid
|
||||
if not memory_text or not isinstance(memory_text, str):
|
||||
memory_text = new_memory
|
||||
|
||||
# Handle add operation
|
||||
if operation == "add":
|
||||
logger.debug("Memory update operation: add")
|
||||
return (memory_text, None)
|
||||
|
||||
# Handle update operation
|
||||
if operation == "update":
|
||||
# Validate memory_id
|
||||
if memory_id is None:
|
||||
logger.warning("Update operation specified but no memory_id provided")
|
||||
return (memory_text, None)
|
||||
|
||||
# Convert memory_id to integer if it's a string
|
||||
try:
|
||||
memory_id_int = int(memory_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"Invalid memory_id format: {memory_id}")
|
||||
return (memory_text, None)
|
||||
|
||||
# Convert from 1-indexed (LLM response) to 0-indexed (internal)
|
||||
index_to_replace = memory_id_int - 1
|
||||
|
||||
# Validate index is in range
|
||||
if index_to_replace < 0 or index_to_replace >= len(existing_memories):
|
||||
logger.warning(
|
||||
f"memory_id {memory_id_int} out of range (1-{len(existing_memories)}), defaulting to add"
|
||||
)
|
||||
return (memory_text, None)
|
||||
|
||||
logger.debug(f"Memory update operation: update at index {index_to_replace}")
|
||||
return (memory_text, index_to_replace)
|
||||
|
||||
# Unknown operation, default to add
|
||||
logger.warning(f"Unknown operation '{operation}', defaulting to add")
|
||||
return (memory_text, None)
|
||||
294
backend/onyx/server/manage/discord_bot/api.py
Normal file
294
backend/onyx/server/manage/discord_bot/api.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Discord bot admin API endpoints."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISCORD_BOT_TOKEN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
from onyx.db.discord_bot import delete_discord_bot_config
|
||||
from onyx.db.discord_bot import delete_discord_service_api_key
|
||||
from onyx.db.discord_bot import delete_guild_config
|
||||
from onyx.db.discord_bot import get_channel_config_by_internal_ids
|
||||
from onyx.db.discord_bot import get_channel_configs
|
||||
from onyx.db.discord_bot import get_discord_bot_config
|
||||
from onyx.db.discord_bot import get_guild_config_by_internal_id
|
||||
from onyx.db.discord_bot import get_guild_configs
|
||||
from onyx.db.discord_bot import update_discord_channel_config
|
||||
from onyx.db.discord_bot import update_guild_config
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.manage.discord_bot.models import DiscordBotConfigCreateRequest
|
||||
from onyx.server.manage.discord_bot.models import DiscordBotConfigResponse
|
||||
from onyx.server.manage.discord_bot.models import DiscordChannelConfigResponse
|
||||
from onyx.server.manage.discord_bot.models import DiscordChannelConfigUpdateRequest
|
||||
from onyx.server.manage.discord_bot.models import DiscordGuildConfigCreateResponse
|
||||
from onyx.server.manage.discord_bot.models import DiscordGuildConfigResponse
|
||||
from onyx.server.manage.discord_bot.models import DiscordGuildConfigUpdateRequest
|
||||
from onyx.server.manage.discord_bot.utils import (
|
||||
generate_discord_registration_key,
|
||||
)
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
router = APIRouter(prefix="/manage/admin/discord-bot")
|
||||
|
||||
|
||||
def _check_bot_config_api_access() -> None:
|
||||
"""Raise 403 if bot config cannot be managed via API.
|
||||
|
||||
Bot config endpoints are disabled:
|
||||
- On Cloud (managed by Onyx)
|
||||
- When DISCORD_BOT_TOKEN env var is set (managed via env)
|
||||
"""
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Discord bot configuration is managed by Onyx on Cloud.",
|
||||
)
|
||||
if DISCORD_BOT_TOKEN:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Discord bot is configured via environment variables. API access disabled.",
|
||||
)
|
||||
|
||||
|
||||
# === Bot Config ===
|
||||
|
||||
|
||||
@router.get("/config", response_model=DiscordBotConfigResponse)
|
||||
def get_bot_config(
|
||||
_: None = Depends(_check_bot_config_api_access),
|
||||
__: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DiscordBotConfigResponse:
|
||||
"""Get Discord bot config. Returns 403 on Cloud or if env vars set."""
|
||||
config = get_discord_bot_config(db_session)
|
||||
if not config:
|
||||
return DiscordBotConfigResponse(configured=False)
|
||||
|
||||
return DiscordBotConfigResponse(
|
||||
configured=True,
|
||||
created_at=config.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config", response_model=DiscordBotConfigResponse)
|
||||
def create_bot_request(
|
||||
request: DiscordBotConfigCreateRequest,
|
||||
_: None = Depends(_check_bot_config_api_access),
|
||||
__: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DiscordBotConfigResponse:
|
||||
"""Create Discord bot config. Returns 403 on Cloud or if env vars set."""
|
||||
try:
|
||||
config = create_discord_bot_config(
|
||||
db_session,
|
||||
bot_token=request.bot_token,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Discord bot config already exists. Delete it first to create a new one.",
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return DiscordBotConfigResponse(
|
||||
configured=True,
|
||||
created_at=config.created_at,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/config")
|
||||
def delete_bot_config_endpoint(
|
||||
_: None = Depends(_check_bot_config_api_access),
|
||||
__: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Delete Discord bot config.
|
||||
|
||||
Also deletes the Discord service API key since the bot is being removed.
|
||||
"""
|
||||
deleted = delete_discord_bot_config(db_session)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Bot config not found")
|
||||
|
||||
# Also delete the service API key used by the Discord bot
|
||||
delete_discord_service_api_key(db_session)
|
||||
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# === Service API Key ===
|
||||
|
||||
|
||||
@router.delete("/service-api-key")
|
||||
def delete_service_api_key_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Delete the Discord service API key.
|
||||
|
||||
This endpoint allows manual deletion of the service API key used by the
|
||||
Discord bot to authenticate with the Onyx API. The key is also automatically
|
||||
deleted when:
|
||||
- Bot config is deleted (self-hosted)
|
||||
- All guild configs are deleted (Cloud)
|
||||
"""
|
||||
deleted = delete_discord_service_api_key(db_session)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Service API key not found")
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# === Guild Config ===
|
||||
|
||||
|
||||
@router.get("/guilds", response_model=list[DiscordGuildConfigResponse])
|
||||
def list_guild_configs(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[DiscordGuildConfigResponse]:
|
||||
"""List all guild configs (pending and registered)."""
|
||||
configs = get_guild_configs(db_session)
|
||||
return [DiscordGuildConfigResponse.model_validate(c) for c in configs]
|
||||
|
||||
|
||||
@router.post("/guilds", response_model=DiscordGuildConfigCreateResponse)
|
||||
def create_guild_request(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DiscordGuildConfigCreateResponse:
|
||||
"""Create new guild config with registration key. Key shown once."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
registration_key = generate_discord_registration_key(tenant_id)
|
||||
|
||||
config = create_guild_config(db_session, registration_key)
|
||||
db_session.commit()
|
||||
|
||||
return DiscordGuildConfigCreateResponse(
|
||||
id=config.id,
|
||||
registration_key=registration_key, # Shown once!
|
||||
)
|
||||
|
||||
|
||||
@router.get("/guilds/{config_id}", response_model=DiscordGuildConfigResponse)
|
||||
def get_guild_config(
|
||||
config_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DiscordGuildConfigResponse:
|
||||
"""Get specific guild config."""
|
||||
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
return DiscordGuildConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@router.patch("/guilds/{config_id}", response_model=DiscordGuildConfigResponse)
|
||||
def update_guild_request(
|
||||
config_id: int,
|
||||
request: DiscordGuildConfigUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DiscordGuildConfigResponse:
|
||||
"""Update guild config."""
|
||||
config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
|
||||
config = update_guild_config(
|
||||
db_session,
|
||||
config,
|
||||
enabled=request.enabled,
|
||||
default_persona_id=request.default_persona_id,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return DiscordGuildConfigResponse.model_validate(config)
|
||||
|
||||
|
||||
@router.delete("/guilds/{config_id}")
|
||||
def delete_guild_request(
|
||||
config_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Delete guild config (invalidates registration key).
|
||||
|
||||
On Cloud, if this was the last guild config, also deletes the service API key.
|
||||
"""
|
||||
deleted = delete_guild_config(db_session, config_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
|
||||
# On Cloud, delete service API key when all guilds are removed
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
remaining_guilds = get_guild_configs(db_session)
|
||||
if not remaining_guilds:
|
||||
delete_discord_service_api_key(db_session)
|
||||
|
||||
db_session.commit()
|
||||
return {"deleted": True}
|
||||
|
||||
|
||||
# === Channel Config ===
|
||||
|
||||
|
||||
@router.get(
|
||||
"/guilds/{config_id}/channels", response_model=list[DiscordChannelConfigResponse]
|
||||
)
|
||||
def list_channel_configs(
|
||||
config_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[DiscordChannelConfigResponse]:
|
||||
"""List whitelisted channels for a guild."""
|
||||
guild_config = get_guild_config_by_internal_id(db_session, internal_id=config_id)
|
||||
if not guild_config:
|
||||
raise HTTPException(status_code=404, detail="Guild config not found")
|
||||
if not guild_config.guild_id:
|
||||
raise HTTPException(status_code=400, detail="Guild not yet registered")
|
||||
|
||||
configs = get_channel_configs(db_session, config_id)
|
||||
return [DiscordChannelConfigResponse.model_validate(c) for c in configs]
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/guilds/{guild_config_id}/channels/{channel_config_id}",
|
||||
response_model=DiscordChannelConfigResponse,
|
||||
)
|
||||
def update_channel_request(
|
||||
guild_config_id: int,
|
||||
channel_config_id: int,
|
||||
request: DiscordChannelConfigUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DiscordChannelConfigResponse:
|
||||
"""Update channel config."""
|
||||
config = get_channel_config_by_internal_ids(
|
||||
db_session, guild_config_id, channel_config_id
|
||||
)
|
||||
if not config:
|
||||
raise HTTPException(status_code=404, detail="Channel config not found")
|
||||
|
||||
config = update_discord_channel_config(
|
||||
db_session,
|
||||
config,
|
||||
channel_name=config.channel_name, # Keep existing name, only Discord can update
|
||||
thread_only_mode=request.thread_only_mode,
|
||||
require_bot_invocation=request.require_bot_invocation,
|
||||
persona_override_id=request.persona_override_id,
|
||||
enabled=request.enabled,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
return DiscordChannelConfigResponse.model_validate(config)
|
||||
71
backend/onyx/server/manage/discord_bot/models.py
Normal file
71
backend/onyx/server/manage/discord_bot/models.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Pydantic models for Discord bot API."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# === Bot Config ===
|
||||
|
||||
|
||||
class DiscordBotConfigResponse(BaseModel):
|
||||
configured: bool
|
||||
created_at: datetime | None = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DiscordBotConfigCreateRequest(BaseModel):
|
||||
bot_token: str
|
||||
|
||||
|
||||
# === Guild Config ===
|
||||
|
||||
|
||||
class DiscordGuildConfigResponse(BaseModel):
|
||||
id: int
|
||||
guild_id: int | None
|
||||
guild_name: str | None
|
||||
registered_at: datetime | None
|
||||
default_persona_id: int | None
|
||||
enabled: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DiscordGuildConfigCreateResponse(BaseModel):
|
||||
id: int
|
||||
registration_key: str # Shown once!
|
||||
|
||||
|
||||
class DiscordGuildConfigUpdateRequest(BaseModel):
|
||||
enabled: bool
|
||||
default_persona_id: int | None
|
||||
|
||||
|
||||
# === Channel Config ===
|
||||
|
||||
|
||||
class DiscordChannelConfigResponse(BaseModel):
|
||||
id: int
|
||||
guild_config_id: int
|
||||
channel_id: int
|
||||
channel_name: str
|
||||
channel_type: str
|
||||
is_private: bool
|
||||
require_bot_invocation: bool
|
||||
thread_only_mode: bool
|
||||
persona_override_id: int | None
|
||||
enabled: bool
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class DiscordChannelConfigUpdateRequest(BaseModel):
|
||||
require_bot_invocation: bool
|
||||
persona_override_id: int | None
|
||||
enabled: bool
|
||||
thread_only_mode: bool
|
||||
46
backend/onyx/server/manage/discord_bot/utils.py
Normal file
46
backend/onyx/server/manage/discord_bot/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Discord registration key generation and parsing."""
|
||||
|
||||
import secrets
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import unquote
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
REGISTRATION_KEY_PREFIX: str = "discord_"
|
||||
|
||||
|
||||
def generate_discord_registration_key(tenant_id: str) -> str:
|
||||
"""Generate a one-time registration key with embedded tenant_id.
|
||||
|
||||
Format: discord_<url_encoded_tenant_id>.<random_token>
|
||||
|
||||
Follows the same pattern as API keys for consistency.
|
||||
"""
|
||||
encoded_tenant = quote(tenant_id)
|
||||
random_token = secrets.token_urlsafe(16)
|
||||
|
||||
logger.info(f"Generated Discord registration key for tenant {tenant_id}")
|
||||
return f"{REGISTRATION_KEY_PREFIX}{encoded_tenant}.{random_token}"
|
||||
|
||||
|
||||
def parse_discord_registration_key(key: str) -> str | None:
|
||||
"""Parse registration key to extract tenant_id.
|
||||
|
||||
Returns tenant_id or None if invalid format.
|
||||
"""
|
||||
if not key.startswith(REGISTRATION_KEY_PREFIX):
|
||||
return None
|
||||
|
||||
try:
|
||||
key_body = key.removeprefix(REGISTRATION_KEY_PREFIX)
|
||||
parts = key_body.split(".", 1)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
encoded_tenant = parts[0]
|
||||
tenant_id = unquote(encoded_tenant)
|
||||
return tenant_id
|
||||
except Exception:
|
||||
return None
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import remove_chat_message_feedback
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -266,7 +267,35 @@ def get_chat_session(
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
try:
|
||||
# If we failed to get a chat session, try to retrieve the session with
|
||||
# less restrictive filters in order to identify what exactly mismatched
|
||||
# so we can bubble up an accurate error code andmessage.
|
||||
existing_chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
is_shared=False,
|
||||
include_deleted=True,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
if not include_deleted and existing_chat_session.deleted:
|
||||
raise HTTPException(status_code=404, detail="Chat session has been deleted")
|
||||
|
||||
if is_shared:
|
||||
if existing_chat_session.shared_status != ChatSessionSharedStatus.PUBLIC:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Chat session is not shared"
|
||||
)
|
||||
elif user_id is not None and existing_chat_session.user_id not in (
|
||||
user_id,
|
||||
None,
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
# for chat-seeding: if the session is unassigned, assign it now. This is done here
|
||||
# to avoid another back and forth between FE -> BE before starting the first
|
||||
|
||||
@@ -35,6 +35,8 @@ class MessageOrigin(str, Enum):
|
||||
CHROME_EXTENSION = "chrome_extension"
|
||||
API = "api"
|
||||
SLACKBOT = "slackbot"
|
||||
WIDGET = "widget"
|
||||
DISCORDBOT = "discordbot"
|
||||
UNKNOWN = "unknown"
|
||||
UNSET = "unset"
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ from onyx.server.settings.models import UserSettings
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -37,6 +40,11 @@ def admin_put_settings(
|
||||
store_settings(settings)
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""MIT version: no-op, returns settings unchanged."""
|
||||
return settings
|
||||
|
||||
|
||||
@basic_router.get("")
|
||||
def fetch_settings(
|
||||
user: User | None = Depends(current_user),
|
||||
@@ -53,6 +61,13 @@ def fetch_settings(
|
||||
except KvKeyNotFoundError:
|
||||
needs_reindexing = False
|
||||
|
||||
apply_fn = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.server.settings.api",
|
||||
"apply_license_status_to_settings",
|
||||
apply_license_status_to_settings,
|
||||
)
|
||||
general_settings = apply_fn(general_settings)
|
||||
|
||||
return UserSettings(
|
||||
**general_settings.model_dump(),
|
||||
notifications=settings_notifications,
|
||||
|
||||
@@ -80,6 +80,8 @@ class ToolResponse(BaseModel):
|
||||
# | WebContentResponse
|
||||
# This comes from custom tools, tool result needs to be saved
|
||||
| CustomToolCallSummary
|
||||
# If the rich response is a string, this is what's saved to the tool call in the DB
|
||||
| str
|
||||
| None # If nothing needs to be persisted outside of the string value passed to the LLM
|
||||
)
|
||||
# This is the final string that needs to be wrapped in a tool call response message and concatenated to the history
|
||||
|
||||
135
backend/onyx/tools/tool_implementations/memory/memory_tool.py
Normal file
135
backend/onyx/tools/tool_implementations/memory/memory_tool.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Memory Tool for storing user-specific information.
|
||||
|
||||
This tool allows the LLM to save memories about the user for future conversations.
|
||||
The memories are passed in via override_kwargs which contains the current list of
|
||||
memories that exist for the user.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.secondary_llm_flows.memory_update import process_memory_update
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
MEMORY_FIELD = "memory"
|
||||
|
||||
|
||||
class MemoryToolOverrideKwargs(BaseModel):
|
||||
# Not including the Team Information or User Preferences because these are less likely to contribute to building the memory
|
||||
# Things like the user's name is important because the LLM may create a memory like "Dave prefers light mode." instead of
|
||||
# User prefers light mode.
|
||||
user_name: str | None
|
||||
user_email: str | None
|
||||
user_role: str | None
|
||||
existing_memories: list[str]
|
||||
chat_history: list[ChatMinimalTextMessage]
|
||||
|
||||
|
||||
class MemoryTool(Tool[MemoryToolOverrideKwargs]):
|
||||
NAME = "add_memory"
|
||||
DISPLAY_NAME = "Add Memory"
|
||||
DESCRIPTION = "Save memories about the user for future conversations."
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_id: int,
|
||||
emitter: Emitter,
|
||||
llm: LLM,
|
||||
) -> None:
|
||||
super().__init__(emitter=emitter)
|
||||
self._id = tool_id
|
||||
self.llm = llm
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self.DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self.DISPLAY_NAME
|
||||
|
||||
@override
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
MEMORY_FIELD: {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The text of the memory to add or update. "
|
||||
"Should be a concise, standalone statement that "
|
||||
"captures the key information. For example: "
|
||||
"'User prefers dark mode' or 'User's favorite frontend framework is React'."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [MEMORY_FIELD],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@override
|
||||
def emit_start(self, placement: Placement) -> None:
|
||||
# TODO
|
||||
pass
|
||||
|
||||
@override
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
override_kwargs: MemoryToolOverrideKwargs,
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
memory = cast(str, llm_kwargs[MEMORY_FIELD])
|
||||
|
||||
existing_memories = override_kwargs.existing_memories
|
||||
chat_history = override_kwargs.chat_history
|
||||
|
||||
# Determine if this should be an add or update operation
|
||||
memory_text, index_to_replace = process_memory_update(
|
||||
new_memory=memory,
|
||||
existing_memories=existing_memories,
|
||||
chat_history=chat_history,
|
||||
llm=self.llm,
|
||||
user_name=override_kwargs.user_name,
|
||||
user_email=override_kwargs.user_email,
|
||||
user_role=override_kwargs.user_role,
|
||||
)
|
||||
|
||||
# TODO: the data should be return and processed outside of the tool
|
||||
# Persisted to the db for future conversations
|
||||
# The actual persistence of the memory will be handled by the caller
|
||||
# This tool just returns the memory to be saved
|
||||
logger.info(f"New memory to be added: {memory_text}")
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=memory_text,
|
||||
llm_facing_response=f"New memory added: {memory_text}",
|
||||
)
|
||||
@@ -17,6 +17,8 @@ from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
@@ -297,6 +299,7 @@ def run_tool_calls(
|
||||
SearchToolOverrideKwargs
|
||||
| WebSearchToolOverrideKwargs
|
||||
| OpenURLToolOverrideKwargs
|
||||
| MemoryToolOverrideKwargs
|
||||
| None
|
||||
) = None
|
||||
|
||||
@@ -330,6 +333,9 @@ def run_tool_calls(
|
||||
)
|
||||
starting_citation_num += 100
|
||||
|
||||
elif isinstance(tool, MemoryTool):
|
||||
raise NotImplementedError("MemoryTool is not implemented")
|
||||
|
||||
tool_run_params.append((tool, tool_call, override_kwargs))
|
||||
|
||||
# Run all tools in parallel
|
||||
|
||||
@@ -21,6 +21,28 @@ ESCAPE_SEQUENCE_RE = re.compile(
|
||||
re.UNICODE | re.VERBOSE,
|
||||
)
|
||||
|
||||
_INITIAL_FILTER = re.compile(
|
||||
"["
|
||||
"\U0000fff0-\U0000ffff" # Specials
|
||||
"\U0001f000-\U0001f9ff" # Emoticons
|
||||
"\U00002000-\U0000206f" # General Punctuation
|
||||
"\U00002190-\U000021ff" # Arrows
|
||||
"\U00002700-\U000027bf" # Dingbats
|
||||
"]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
|
||||
# Regex to match invalid Unicode characters that cause UTF-8 encoding errors:
|
||||
# - \x00-\x08: Control characters (except tab \x09)
|
||||
# - \x0b-\x0c: Vertical tab and form feed
|
||||
# - \x0e-\x1f: More control characters (except newline \x0a, carriage return \x0d)
|
||||
# - \ud800-\udfff: Surrogate pairs (invalid when unpaired, causes "surrogates not allowed" errors)
|
||||
# - \ufdd0-\ufdef: Non-characters
|
||||
# - \ufffe-\uffff: Non-characters
|
||||
_INVALID_UNICODE_CHARS_RE = re.compile(
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufdd0-\ufdef\ufffe\uffff]"
|
||||
)
|
||||
|
||||
|
||||
def decode_escapes(s: str) -> str:
|
||||
def decode_match(match: re.Match) -> str:
|
||||
@@ -76,27 +98,98 @@ def escape_quotes(original_json_str: str) -> str:
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def extract_embedded_json(s: str) -> dict:
|
||||
first_brace_index = s.find("{")
|
||||
last_brace_index = s.rfind("}")
|
||||
def find_all_json_objects(text: str) -> list[dict]:
|
||||
"""Find all JSON objects in text using balanced brace matching.
|
||||
|
||||
if first_brace_index == -1 or last_brace_index == -1:
|
||||
logger.warning("No valid json found, assuming answer is entire string")
|
||||
return {"answer": s, "quotes": []}
|
||||
Iterates through the text, and for each '{' found, attempts to find its
|
||||
matching '}' by counting brace depth. Each balanced substring is then
|
||||
validated as JSON. This includes nested JSON objects within other objects.
|
||||
|
||||
json_str = s[first_brace_index : last_brace_index + 1]
|
||||
try:
|
||||
return json.loads(json_str, strict=False)
|
||||
Use case: Parsing LLM output that may contain multiple JSON objects, or when
|
||||
the LLM/serving layer outputs function calls in non-standard formats
|
||||
(e.g. OpenAI's function.open_url style).
|
||||
|
||||
except json.JSONDecodeError:
|
||||
Args:
|
||||
text: The text to search for JSON objects.
|
||||
|
||||
Returns:
|
||||
A list of all successfully parsed JSON objects (dicts only).
|
||||
"""
|
||||
json_objects: list[dict] = []
|
||||
i = 0
|
||||
|
||||
while i < len(text):
|
||||
if text[i] == "{":
|
||||
# Try to find a matching closing brace
|
||||
brace_count = 0
|
||||
start = i
|
||||
for j in range(i, len(text)):
|
||||
if text[j] == "{":
|
||||
brace_count += 1
|
||||
elif text[j] == "}":
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
# Found potential JSON object
|
||||
candidate = text[start : j + 1]
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
if isinstance(parsed, dict):
|
||||
json_objects.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
break
|
||||
i += 1
|
||||
|
||||
return json_objects
|
||||
|
||||
|
||||
def parse_llm_json_response(content: str) -> dict | None:
|
||||
"""Parse a single JSON object from LLM output, handling markdown code blocks.
|
||||
|
||||
Designed for LLM responses that typically contain exactly one JSON object,
|
||||
possibly wrapped in markdown formatting.
|
||||
|
||||
Tries extraction in order:
|
||||
1. JSON inside markdown code block (```json ... ``` or ``` ... ```)
|
||||
2. Entire content as raw JSON
|
||||
3. First '{' to last '}' in content (greedy match)
|
||||
|
||||
Args:
|
||||
content: The LLM response text to parse.
|
||||
|
||||
Returns:
|
||||
The parsed JSON dict if found, None otherwise.
|
||||
"""
|
||||
# Try to find JSON in markdown code block first
|
||||
# Use greedy .* (not .*?) to match nested objects correctly within code block bounds
|
||||
json_match = re.search(r"```(?:json)?\s*(\{.*\})\s*```", content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(escape_quotes(json_str), strict=False)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError("Failed to parse JSON, even after escaping quotes") from e
|
||||
result = json.loads(json_match.group(1))
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Try to parse the entire content as JSON
|
||||
try:
|
||||
result = json.loads(content)
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
def clean_up_code_blocks(model_out_raw: str) -> str:
|
||||
return model_out_raw.strip().strip("```").strip().replace("\\xa0", "")
|
||||
# Try to find any JSON object in the content
|
||||
json_match = re.search(r"\{.*\}", content, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
result = json.loads(json_match.group(0))
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def clean_model_quote(quote: str, trim_length: int) -> str:
|
||||
@@ -126,18 +219,6 @@ def shared_precompare_cleanup(text: str) -> str:
|
||||
return text
|
||||
|
||||
|
||||
_INITIAL_FILTER = re.compile(
|
||||
"["
|
||||
"\U0000fff0-\U0000ffff" # Specials
|
||||
"\U0001f000-\U0001f9ff" # Emoticons
|
||||
"\U00002000-\U0000206f" # General Punctuation
|
||||
"\U00002190-\U000021ff" # Arrows
|
||||
"\U00002700-\U000027bf" # Dingbats
|
||||
"]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
# Remove specific Unicode ranges that might cause issues
|
||||
cleaned = _INITIAL_FILTER.sub("", text)
|
||||
@@ -167,18 +248,6 @@ def remove_markdown_image_references(text: str) -> str:
|
||||
return re.sub(r"!\[[^\]]*\]\([^\)]+\)", "", text)
|
||||
|
||||
|
||||
# Regex to match invalid Unicode characters that cause UTF-8 encoding errors:
|
||||
# - \x00-\x08: Control characters (except tab \x09)
|
||||
# - \x0b-\x0c: Vertical tab and form feed
|
||||
# - \x0e-\x1f: More control characters (except newline \x0a, carriage return \x0d)
|
||||
# - \ud800-\udfff: Surrogate pairs (invalid when unpaired, causes "surrogates not allowed" errors)
|
||||
# - \ufdd0-\ufdef: Non-characters
|
||||
# - \ufffe-\uffff: Non-characters
|
||||
_INVALID_UNICODE_CHARS_RE = re.compile(
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufdd0-\ufdef\ufffe\uffff]"
|
||||
)
|
||||
|
||||
|
||||
def remove_invalid_unicode_chars(text: str) -> str:
|
||||
"""Remove Unicode characters that are invalid in UTF-8 or cause encoding issues.
|
||||
|
||||
|
||||
@@ -573,7 +573,7 @@ mcp==1.25.0
|
||||
# onyx
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistune==0.8.4
|
||||
mistune==3.2.0
|
||||
# via onyx
|
||||
more-itertools==10.8.0
|
||||
# via
|
||||
|
||||
@@ -11,6 +11,7 @@ aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
# litellm
|
||||
# voyageai
|
||||
aioitertools==0.13.0
|
||||
@@ -94,6 +95,8 @@ decorator==5.2.1
|
||||
# via
|
||||
# ipython
|
||||
# retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
distro==1.9.0
|
||||
@@ -295,7 +298,7 @@ numpy==2.4.1
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.3.2
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
|
||||
@@ -11,6 +11,7 @@ aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
# litellm
|
||||
# voyageai
|
||||
aioitertools==0.13.0
|
||||
@@ -67,6 +68,8 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# tqdm
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
|
||||
@@ -13,6 +13,7 @@ aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.3
|
||||
# via
|
||||
# aiobotocore
|
||||
# discord-py
|
||||
# litellm
|
||||
# voyageai
|
||||
aioitertools==0.13.0
|
||||
@@ -83,6 +84,8 @@ colorama==0.4.6 ; sys_platform == 'win32'
|
||||
# tqdm
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
discord-py==2.4.0
|
||||
# via onyx
|
||||
distro==1.9.0
|
||||
# via openai
|
||||
docstring-parser==0.17.0
|
||||
|
||||
@@ -191,6 +191,18 @@ autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Listens for Discord messages and responds with answers
|
||||
# for all guilds/channels that the OnyxBot has been added to.
|
||||
# If not configured, will continue to probe every 3 minutes for a Discord bot token.
|
||||
[program:discord_bot]
|
||||
command=python onyx/onyxbot/discord/client.py
|
||||
stdout_logfile=/var/log/discord_bot.log
|
||||
stdout_logfile_maxbytes=16MB
|
||||
redirect_stderr=true
|
||||
autorestart=true
|
||||
startretries=5
|
||||
startsecs=60
|
||||
|
||||
# Pushes all logs from the above programs to stdout
|
||||
# No log rotation here, since it's stdout it's handled by the Docker container logging
|
||||
[program:log-redirect-handler]
|
||||
@@ -206,6 +218,7 @@ command=tail -qF
|
||||
/var/log/celery_worker_user_file_processing.log
|
||||
/var/log/celery_worker_docfetching.log
|
||||
/var/log/slack_bot.log
|
||||
/var/log/discord_bot.log
|
||||
/var/log/supervisord_watchdog_celery_beat.log
|
||||
/var/log/mcp_server.log
|
||||
/var/log/mcp_server.err.log
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
External dependency unit tests for user file processing queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_user_file_processing work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in PROCESSING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that process_single_user_file clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_user_file_processing,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in PROCESSING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.PROCESSING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "bp_user")
|
||||
_create_processing_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(
|
||||
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
|
||||
),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestPerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "guard_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "guard_set_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
|
||||
f"Guard key TTL {ttl}s is outside the expected range "
|
||||
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "expires_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires")
|
||||
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), (
|
||||
"Task must be submitted with the correct expires value to prevent "
|
||||
"stale task accumulation"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestWorkerClearsGuardKey:
|
||||
"""process_single_user_file removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so process_single_user_file returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file processing lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_lock_key(user_file_id)
|
||||
processing_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = processing_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the processing lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if processing_lock.owned():
|
||||
processing_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -90,12 +90,12 @@ _EXPECTED_CONFLUENCE_GROUPS = [
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="bitbucket-users-onyxai",
|
||||
user_emails={"oauth@onyx.app"},
|
||||
user_emails={"founders@onyx.app", "oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="bitbucket-admins-onyxai",
|
||||
user_emails={"oauth@onyx.app"},
|
||||
user_emails={"founders@onyx.app", "oauth@onyx.app"},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
|
||||
@@ -77,6 +77,16 @@ _EXPECTED_JIRA_GROUPS = [
|
||||
},
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="bitbucket-admins-onyxai",
|
||||
user_emails={"founders@onyx.app"}, # no Oauth, we skip "app" account in jira
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
ExternalUserGroupSet(
|
||||
id="bitbucket-users-onyxai",
|
||||
user_emails={"founders@onyx.app"}, # no Oauth, we skip "app" account in jira
|
||||
gives_anyone_access=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
162
backend/tests/external_dependency_unit/discord_bot/conftest.py
Normal file
162
backend/tests/external_dependency_unit/discord_bot/conftest.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Fixtures for Discord bot external dependency tests."""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
TEST_TENANT_ID: str = "public"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
"""Create a database session for testing."""
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
with get_session_with_current_tenant() as session:
|
||||
yield session
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def tenant_context() -> Generator[None, None, None]:
|
||||
"""Set up tenant context for testing."""
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cache_manager() -> MagicMock:
|
||||
"""Mock DiscordCacheManager."""
|
||||
cache = MagicMock()
|
||||
cache.get_tenant.return_value = TEST_TENANT_ID
|
||||
cache.get_api_key.return_value = "test_api_key"
|
||||
cache.refresh_all = AsyncMock()
|
||||
cache.refresh_guild = AsyncMock()
|
||||
cache.is_initialized = True
|
||||
return cache
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_client() -> MagicMock:
|
||||
"""Mock OnyxAPIClient."""
|
||||
client = MagicMock()
|
||||
client.initialize = AsyncMock()
|
||||
client.close = AsyncMock()
|
||||
client.is_initialized = True
|
||||
|
||||
# Mock successful response
|
||||
mock_response = MagicMock()
|
||||
mock_response.answer = "Test response from bot"
|
||||
mock_response.citation_info = None
|
||||
mock_response.top_documents = None
|
||||
mock_response.error_msg = None
|
||||
|
||||
client.send_chat_message = AsyncMock(return_value=mock_response)
|
||||
client.health_check = AsyncMock(return_value=True)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_guild() -> MagicMock:
|
||||
"""Mock Discord guild with channels."""
|
||||
guild = MagicMock(spec=discord.Guild)
|
||||
guild.id = 123456789
|
||||
guild.name = "Test Server"
|
||||
guild.default_role = MagicMock()
|
||||
|
||||
# Create some mock channels
|
||||
text_channel = MagicMock(spec=discord.TextChannel)
|
||||
text_channel.id = 111111111
|
||||
text_channel.name = "general"
|
||||
text_channel.type = discord.ChannelType.text
|
||||
perms = MagicMock()
|
||||
perms.view_channel = True
|
||||
text_channel.permissions_for.return_value = perms
|
||||
|
||||
forum_channel = MagicMock(spec=discord.ForumChannel)
|
||||
forum_channel.id = 222222222
|
||||
forum_channel.name = "forum"
|
||||
forum_channel.type = discord.ChannelType.forum
|
||||
forum_channel.permissions_for.return_value = perms
|
||||
|
||||
private_channel = MagicMock(spec=discord.TextChannel)
|
||||
private_channel.id = 333333333
|
||||
private_channel.name = "private"
|
||||
private_channel.type = discord.ChannelType.text
|
||||
private_perms = MagicMock()
|
||||
private_perms.view_channel = False
|
||||
private_channel.permissions_for.return_value = private_perms
|
||||
|
||||
guild.channels = [text_channel, forum_channel, private_channel]
|
||||
guild.text_channels = [text_channel, private_channel]
|
||||
guild.forum_channels = [forum_channel]
|
||||
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_message(mock_discord_guild: MagicMock) -> MagicMock:
|
||||
"""Mock Discord message for testing."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 555555555
|
||||
msg.author = MagicMock(spec=discord.Member)
|
||||
msg.author.id = 444444444
|
||||
msg.author.bot = False
|
||||
msg.author.display_name = "TestUser"
|
||||
msg.author.guild_permissions = MagicMock()
|
||||
msg.author.guild_permissions.administrator = True
|
||||
msg.author.guild_permissions.manage_guild = True
|
||||
msg.content = "Hello bot"
|
||||
msg.guild = mock_discord_guild
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 111111111
|
||||
msg.channel.name = "general"
|
||||
msg.channel.send = AsyncMock()
|
||||
msg.type = discord.MessageType.default
|
||||
msg.mentions = []
|
||||
msg.role_mentions = []
|
||||
msg.channel_mentions = []
|
||||
msg.reference = None
|
||||
msg.add_reaction = AsyncMock()
|
||||
msg.remove_reaction = AsyncMock()
|
||||
msg.reply = AsyncMock()
|
||||
msg.create_thread = AsyncMock()
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_user() -> MagicMock:
|
||||
"""Mock Discord bot user."""
|
||||
user = MagicMock(spec=discord.ClientUser)
|
||||
user.id = 987654321
|
||||
user.display_name = "OnyxBot"
|
||||
user.bot = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_bot(
|
||||
mock_cache_manager: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> MagicMock:
|
||||
"""Mock OnyxDiscordClient."""
|
||||
bot = MagicMock()
|
||||
bot.user = mock_bot_user
|
||||
bot.cache = mock_cache_manager
|
||||
bot.api_client = mock_api_client
|
||||
bot.ready = True
|
||||
bot.loop = MagicMock()
|
||||
bot.is_closed.return_value = False
|
||||
bot.guilds = []
|
||||
return bot
|
||||
@@ -0,0 +1,616 @@
|
||||
"""Tests for Discord bot event handling with mocked Discord API.
|
||||
|
||||
These tests mock the Discord API to test event handling logic.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
from onyx.onyxbot.discord.handle_commands import get_text_channels
|
||||
from onyx.onyxbot.discord.handle_commands import handle_dm
|
||||
from onyx.onyxbot.discord.handle_commands import handle_registration_command
|
||||
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
|
||||
from onyx.onyxbot.discord.handle_message import process_chat_message
|
||||
from onyx.onyxbot.discord.handle_message import send_error_response
|
||||
from onyx.onyxbot.discord.handle_message import send_response
|
||||
|
||||
|
||||
class TestGuildRegistrationCommand:
|
||||
"""Tests for !register command handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_guild_success(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Valid registration key with admin perms succeeds."""
|
||||
mock_discord_message.content = "!register discord_public.valid_token"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
|
||||
return_value="public",
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
|
||||
) as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_registration_key"
|
||||
) as mock_get_config,
|
||||
patch("onyx.onyxbot.discord.handle_commands.bulk_create_channel_configs"),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.id = 1
|
||||
mock_config.guild_id = None # Not yet registered
|
||||
mock_get_config.return_value = mock_config
|
||||
|
||||
mock_cache_manager.get_tenant.return_value = None # Not in cache yet
|
||||
|
||||
result = await handle_registration_command(
|
||||
mock_discord_message, mock_cache_manager
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_discord_message.reply.assert_called()
|
||||
# Check that success message was sent
|
||||
call_args = mock_discord_message.reply.call_args
|
||||
assert "Successfully registered" in str(call_args)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_invalid_key_format(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Malformed key DMs user and deletes message."""
|
||||
mock_discord_message.content = "!register abc" # Malformed
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
|
||||
return_value=None, # Invalid format
|
||||
):
|
||||
result = await handle_registration_command(
|
||||
mock_discord_message, mock_cache_manager
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and delete the message
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "Invalid" in str(call_args)
|
||||
mock_discord_message.delete.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_key_not_found(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Key not in database DMs user and deletes message."""
|
||||
mock_discord_message.content = "!register discord_public.notexist"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
|
||||
return_value="public",
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
|
||||
) as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_registration_key",
|
||||
return_value=None, # Not found
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
# Must return False so exceptions are not suppressed
|
||||
mock_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_cache_manager.get_tenant.return_value = None
|
||||
|
||||
result = await handle_registration_command(
|
||||
mock_discord_message, mock_cache_manager
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and delete the message
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "not found" in str(call_args).lower()
|
||||
mock_discord_message.delete.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_key_already_used(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Previously used key DMs user and deletes message."""
|
||||
mock_discord_message.content = "!register discord_public.used_key"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
|
||||
return_value="public",
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
|
||||
) as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_registration_key"
|
||||
) as mock_get_config,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
# Must return False so exceptions are not suppressed
|
||||
mock_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 999999 # Already registered!
|
||||
mock_get_config.return_value = mock_config
|
||||
|
||||
mock_cache_manager.get_tenant.return_value = None
|
||||
|
||||
result = await handle_registration_command(
|
||||
mock_discord_message, mock_cache_manager
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and delete the message
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "already" in str(call_args).lower()
|
||||
mock_discord_message.delete.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_guild_already_registered(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Guild already in cache DMs user and deletes message."""
|
||||
mock_discord_message.content = "!register discord_public.valid_token"
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_commands.parse_discord_registration_key",
|
||||
return_value="public",
|
||||
):
|
||||
# Guild already in cache
|
||||
mock_cache_manager.get_tenant.return_value = "existing_tenant"
|
||||
|
||||
result = await handle_registration_command(
|
||||
mock_discord_message, mock_cache_manager
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and delete the message
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "already registered" in str(call_args).lower()
|
||||
mock_discord_message.delete.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_no_permission(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""User without admin perms gets DM and message deleted."""
|
||||
mock_discord_message.content = "!register discord_public.valid_token"
|
||||
mock_discord_message.author.guild_permissions.administrator = False
|
||||
mock_discord_message.author.guild_permissions.manage_guild = False
|
||||
|
||||
result = await handle_registration_command(
|
||||
mock_discord_message, mock_cache_manager
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and delete the message
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "permission" in str(call_args).lower()
|
||||
mock_discord_message.delete.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_in_dm(
|
||||
self,
|
||||
mock_cache_manager: MagicMock,
|
||||
) -> None:
|
||||
"""Registration in DM sends DM and returns True."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.guild = None # DM
|
||||
msg.content = "!register discord_public.token"
|
||||
msg.author = MagicMock()
|
||||
msg.author.send = AsyncMock()
|
||||
|
||||
result = await handle_registration_command(msg, mock_cache_manager)
|
||||
|
||||
assert result is True
|
||||
msg.author.send.assert_called()
|
||||
call_args = msg.author.send.call_args
|
||||
assert "server" in str(call_args).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_syncs_forum_channels(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_discord_guild: MagicMock,
|
||||
) -> None:
|
||||
"""Forum channels are included in sync."""
|
||||
channels = get_text_channels(mock_discord_guild)
|
||||
|
||||
channel_types = [c.channel_type for c in channels]
|
||||
assert "forum" in channel_types
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_private_channel_detection(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_discord_guild: MagicMock,
|
||||
) -> None:
|
||||
"""Private channels are marked correctly."""
|
||||
channels = get_text_channels(mock_discord_guild)
|
||||
|
||||
private_channels = [c for c in channels if c.is_private]
|
||||
assert len(private_channels) >= 1
|
||||
|
||||
|
||||
class TestSyncChannelsCommand:
|
||||
"""Tests for !sync-channels command handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_channels_adds_new(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_discord_bot: MagicMock,
|
||||
) -> None:
|
||||
"""New channel in Discord creates channel config."""
|
||||
mock_discord_message.content = "!sync-channels"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_session_with_tenant"
|
||||
) as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_discord_id"
|
||||
) as mock_get_guild,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.get_guild_config_by_internal_id"
|
||||
) as mock_get_guild_internal,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_commands.sync_channel_configs"
|
||||
) as mock_sync,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.id = 1
|
||||
mock_config.guild_id = 123456789
|
||||
mock_get_guild.return_value = mock_config
|
||||
mock_get_guild_internal.return_value = mock_config
|
||||
|
||||
mock_sync.return_value = (1, 0, 0) # 1 added, 0 removed, 0 updated
|
||||
|
||||
mock_discord_bot.get_guild.return_value = mock_discord_message.guild
|
||||
|
||||
result = await handle_sync_channels_command(
|
||||
mock_discord_message, "public", mock_discord_bot
|
||||
)
|
||||
|
||||
assert result is True
|
||||
mock_discord_message.reply.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_channels_no_permission(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_discord_bot: MagicMock,
|
||||
) -> None:
|
||||
"""User without admin perms gets DM and reaction."""
|
||||
mock_discord_message.content = "!sync-channels"
|
||||
mock_discord_message.author.guild_permissions.administrator = False
|
||||
mock_discord_message.author.guild_permissions.manage_guild = False
|
||||
|
||||
result = await handle_sync_channels_command(
|
||||
mock_discord_message, "public", mock_discord_bot
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and react with ❌
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "permission" in str(call_args).lower()
|
||||
mock_discord_message.add_reaction.assert_called_with("❌")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_channels_unregistered_guild(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_discord_bot: MagicMock,
|
||||
) -> None:
|
||||
"""Sync in unregistered guild gets DM and reaction."""
|
||||
mock_discord_message.content = "!sync-channels"
|
||||
|
||||
# tenant_id is None = not registered
|
||||
result = await handle_sync_channels_command(
|
||||
mock_discord_message, None, mock_discord_bot
|
||||
)
|
||||
|
||||
assert result is True
|
||||
# On failure: DM the author and react with ❌
|
||||
mock_discord_message.author.send.assert_called()
|
||||
call_args = mock_discord_message.author.send.call_args
|
||||
assert "not registered" in str(call_args).lower()
|
||||
mock_discord_message.add_reaction.assert_called_with("❌")
|
||||
|
||||
|
||||
class TestMessageHandling:
|
||||
"""Tests for message handling behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_adds_thinking_emoji(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""Thinking emoji is added during processing."""
|
||||
await process_chat_message(
|
||||
message=mock_discord_message,
|
||||
api_key="test_key",
|
||||
persona_id=None,
|
||||
thread_only_mode=False,
|
||||
api_client=mock_api_client,
|
||||
bot_user=mock_bot_user,
|
||||
)
|
||||
|
||||
mock_discord_message.add_reaction.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_removes_thinking_emoji(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""Thinking emoji is removed after response."""
|
||||
await process_chat_message(
|
||||
message=mock_discord_message,
|
||||
api_key="test_key",
|
||||
persona_id=None,
|
||||
thread_only_mode=False,
|
||||
api_client=mock_api_client,
|
||||
bot_user=mock_bot_user,
|
||||
)
|
||||
|
||||
mock_discord_message.remove_reaction.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_reaction_failure_non_blocking(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""add_reaction failure doesn't block processing."""
|
||||
mock_discord_message.add_reaction = AsyncMock(
|
||||
side_effect=discord.DiscordException("Cannot add reaction")
|
||||
)
|
||||
|
||||
# Should not raise - just log warning and continue
|
||||
await process_chat_message(
|
||||
message=mock_discord_message,
|
||||
api_key="test_key",
|
||||
persona_id=None,
|
||||
thread_only_mode=False,
|
||||
api_client=mock_api_client,
|
||||
bot_user=mock_bot_user,
|
||||
)
|
||||
|
||||
# Should still complete and send reply
|
||||
mock_discord_message.reply.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_response(self) -> None:
|
||||
"""DM to bot sends redirect message."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.channel = MagicMock(spec=discord.DMChannel)
|
||||
msg.channel.send = AsyncMock()
|
||||
|
||||
await handle_dm(msg)
|
||||
|
||||
msg.channel.send.assert_called_once()
|
||||
call_args = msg.channel.send.call_args
|
||||
assert "DM" in str(call_args) or "server" in str(call_args).lower()
|
||||
|
||||
|
||||
class TestThreadCreationAndResponseRouting:
|
||||
"""Tests for thread creation and response routing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_in_existing_thread(
|
||||
self,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""Message in thread - response appended to thread."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.send = AsyncMock()
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.channel = thread
|
||||
msg.reply = AsyncMock()
|
||||
msg.create_thread = AsyncMock()
|
||||
|
||||
await send_response(msg, "Test response", thread_only_mode=False)
|
||||
|
||||
# Should send to thread, not create new thread
|
||||
thread.send.assert_called()
|
||||
msg.create_thread.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_creates_thread_thread_only_mode(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""thread_only_mode=true creates new thread for response."""
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.send = AsyncMock()
|
||||
mock_discord_message.create_thread = AsyncMock(return_value=mock_thread)
|
||||
|
||||
# Make sure it's not a thread
|
||||
mock_discord_message.channel = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
await send_response(
|
||||
mock_discord_message, "Test response", thread_only_mode=True
|
||||
)
|
||||
|
||||
mock_discord_message.create_thread.assert_called()
|
||||
mock_thread.send.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_replies_inline(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""thread_only_mode=false uses message.reply()."""
|
||||
# Make sure it's not a thread
|
||||
mock_discord_message.channel = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
await send_response(
|
||||
mock_discord_message, "Test response", thread_only_mode=False
|
||||
)
|
||||
|
||||
mock_discord_message.reply.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_name_truncation(
|
||||
self,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""Thread name is truncated to 100 chars."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.channel = MagicMock(spec=discord.TextChannel)
|
||||
msg.author = MagicMock()
|
||||
msg.author.display_name = "A" * 200 # Very long name
|
||||
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.send = AsyncMock()
|
||||
msg.create_thread = AsyncMock(return_value=mock_thread)
|
||||
|
||||
await send_response(msg, "Test", thread_only_mode=True)
|
||||
|
||||
call_args = msg.create_thread.call_args
|
||||
thread_name = call_args.kwargs.get("name") or call_args[1].get("name")
|
||||
assert len(thread_name) <= 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_response_creates_thread(
|
||||
self,
|
||||
mock_discord_message: MagicMock,
|
||||
mock_bot_user: MagicMock,
|
||||
) -> None:
|
||||
"""Error response in channel creates thread."""
|
||||
mock_discord_message.channel = MagicMock(spec=discord.TextChannel)
|
||||
mock_thread = MagicMock()
|
||||
mock_thread.send = AsyncMock()
|
||||
mock_discord_message.create_thread = AsyncMock(return_value=mock_thread)
|
||||
|
||||
await send_error_response(mock_discord_message, mock_bot_user)
|
||||
|
||||
mock_discord_message.create_thread.assert_called()
|
||||
|
||||
|
||||
class TestBotLifecycle:
|
||||
"""Tests for bot lifecycle management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_hook_initializes_cache(
|
||||
self,
|
||||
mock_cache_manager: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
) -> None:
|
||||
"""setup_hook calls cache.refresh_all()."""
|
||||
from onyx.onyxbot.discord.client import OnyxDiscordClient
|
||||
|
||||
with (
|
||||
patch.object(OnyxDiscordClient, "__init__", lambda self: None),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.client.DiscordCacheManager",
|
||||
return_value=mock_cache_manager,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.client.OnyxAPIClient",
|
||||
return_value=mock_api_client,
|
||||
),
|
||||
):
|
||||
bot = OnyxDiscordClient()
|
||||
bot.cache = mock_cache_manager
|
||||
bot.api_client = mock_api_client
|
||||
bot.loop = MagicMock()
|
||||
bot.loop.create_task = MagicMock()
|
||||
|
||||
await bot.setup_hook()
|
||||
|
||||
mock_cache_manager.refresh_all.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_hook_initializes_api_client(
|
||||
self,
|
||||
mock_cache_manager: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
) -> None:
|
||||
"""setup_hook calls api_client.initialize()."""
|
||||
from onyx.onyxbot.discord.client import OnyxDiscordClient
|
||||
|
||||
with (patch.object(OnyxDiscordClient, "__init__", lambda self: None),):
|
||||
bot = OnyxDiscordClient()
|
||||
bot.cache = mock_cache_manager
|
||||
bot.api_client = mock_api_client
|
||||
bot.loop = MagicMock()
|
||||
bot.loop.create_task = MagicMock()
|
||||
|
||||
await bot.setup_hook()
|
||||
|
||||
mock_api_client.initialize.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_closes_api_client(
|
||||
self,
|
||||
mock_cache_manager: MagicMock,
|
||||
mock_api_client: MagicMock,
|
||||
) -> None:
|
||||
"""close() calls api_client.close()."""
|
||||
from onyx.onyxbot.discord.client import OnyxDiscordClient
|
||||
|
||||
with (
|
||||
patch.object(OnyxDiscordClient, "__init__", lambda self: None),
|
||||
patch.object(OnyxDiscordClient, "is_closed", return_value=True),
|
||||
):
|
||||
bot = OnyxDiscordClient()
|
||||
bot.cache = mock_cache_manager
|
||||
bot.api_client = mock_api_client
|
||||
bot._cache_refresh_task = None
|
||||
bot.ready = True
|
||||
|
||||
# Mock parent close
|
||||
async def mock_super_close() -> None:
|
||||
pass
|
||||
|
||||
with patch("discord.ext.commands.Bot.close", mock_super_close):
|
||||
await bot.close()
|
||||
|
||||
mock_api_client.close.assert_called()
|
||||
mock_cache_manager.clear.assert_called()
|
||||
@@ -476,8 +476,8 @@ class ChatSessionManager:
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
# Chat session should return 404 if it doesn't exist or is deleted
|
||||
return response.status_code == 404
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
|
||||
310
backend/tests/integration/common_utils/managers/discord_bot.py
Normal file
310
backend/tests/integration/common_utils/managers/discord_bot.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Manager for Discord bot API integration tests."""
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.db.discord_bot import create_channel_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
from onyx.db.discord_bot import register_guild
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.test_models import DATestDiscordChannelConfig
|
||||
from tests.integration.common_utils.test_models import DATestDiscordGuildConfig
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
DISCORD_BOT_API_URL = f"{API_SERVER_URL}/manage/admin/discord-bot"
|
||||
|
||||
|
||||
class DiscordBotManager:
|
||||
"""Manager for Discord bot API operations."""
|
||||
|
||||
# === Bot Config ===
|
||||
|
||||
@staticmethod
|
||||
def get_bot_config(
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Get Discord bot config."""
|
||||
response = requests.get(
|
||||
url=f"{DISCORD_BOT_API_URL}/config",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def create_bot_config(
|
||||
bot_token: str,
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Create Discord bot config."""
|
||||
response = requests.post(
|
||||
url=f"{DISCORD_BOT_API_URL}/config",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
json={"bot_token": bot_token},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def delete_bot_config(
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Delete Discord bot config."""
|
||||
response = requests.delete(
|
||||
url=f"{DISCORD_BOT_API_URL}/config",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
# === Guild Config ===
|
||||
|
||||
@staticmethod
|
||||
def list_guilds(
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[dict]:
|
||||
"""List all guild configs."""
|
||||
response = requests.get(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def create_guild(
|
||||
user_performing_action: DATestUser,
|
||||
) -> DATestDiscordGuildConfig:
|
||||
"""Create a new guild config with registration key."""
|
||||
response = requests.post(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return DATestDiscordGuildConfig(
|
||||
id=data["id"],
|
||||
registration_key=data["registration_key"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_guild(
|
||||
config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Get a specific guild config."""
|
||||
response = requests.get(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def update_guild(
|
||||
config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
enabled: bool | None = None,
|
||||
default_persona_id: int | None = None,
|
||||
) -> dict:
|
||||
"""Update a guild config."""
|
||||
# Fetch current guild config to get existing values
|
||||
current_guild = DiscordBotManager.get_guild(config_id, user_performing_action)
|
||||
|
||||
# Build request body with required fields
|
||||
body: dict = {
|
||||
"enabled": enabled if enabled is not None else current_guild["enabled"],
|
||||
"default_persona_id": (
|
||||
default_persona_id
|
||||
if default_persona_id is not None
|
||||
else current_guild.get("default_persona_id")
|
||||
),
|
||||
}
|
||||
|
||||
response = requests.patch(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
json=body,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def delete_guild(
|
||||
config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict:
|
||||
"""Delete a guild config."""
|
||||
response = requests.delete(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
# === Channel Config ===
|
||||
|
||||
@staticmethod
|
||||
def list_channels(
|
||||
guild_config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[DATestDiscordChannelConfig]:
|
||||
"""List all channel configs for a guild."""
|
||||
response = requests.get(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{guild_config_id}/channels",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [DATestDiscordChannelConfig(**c) for c in response.json()]
|
||||
|
||||
@staticmethod
|
||||
def update_channel(
|
||||
guild_config_id: int,
|
||||
channel_config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
enabled: bool = False,
|
||||
thread_only_mode: bool = False,
|
||||
require_bot_invocation: bool = True,
|
||||
persona_override_id: int | None = None,
|
||||
) -> DATestDiscordChannelConfig:
|
||||
"""Update a channel config.
|
||||
|
||||
All fields are required by the API. Default values match the channel
|
||||
config defaults from create_channel_config.
|
||||
"""
|
||||
body: dict = {
|
||||
"enabled": enabled,
|
||||
"thread_only_mode": thread_only_mode,
|
||||
"require_bot_invocation": require_bot_invocation,
|
||||
"persona_override_id": persona_override_id,
|
||||
}
|
||||
|
||||
response = requests.patch(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{guild_config_id}/channels/{channel_config_id}",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
json=body,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return DATestDiscordChannelConfig(**response.json())
|
||||
|
||||
# === Utility methods for testing ===
|
||||
|
||||
@staticmethod
|
||||
def create_registered_guild_in_db(
|
||||
guild_id: int,
|
||||
guild_name: str,
|
||||
) -> DATestDiscordGuildConfig:
|
||||
"""Create a registered guild config directly in the database.
|
||||
|
||||
This creates a guild that has already completed registration,
|
||||
with guild_id and guild_name set. Use this for testing channel
|
||||
endpoints which require a registered guild.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
tenant_id = get_current_tenant_id()
|
||||
registration_key = generate_discord_registration_key(tenant_id)
|
||||
config = create_guild_config(db_session, registration_key)
|
||||
config = register_guild(db_session, config, guild_id, guild_name)
|
||||
db_session.commit()
|
||||
|
||||
return DATestDiscordGuildConfig(
|
||||
id=config.id,
|
||||
registration_key=registration_key,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_guild_or_none(
|
||||
config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> dict | None:
|
||||
"""Get a guild config, returning None if not found."""
|
||||
response = requests.get(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def delete_guild_if_exists(
|
||||
config_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""Delete a guild config if it exists. Returns True if deleted."""
|
||||
response = requests.delete(
|
||||
url=f"{DISCORD_BOT_API_URL}/guilds/{config_id}",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return False
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def delete_bot_config_if_exists(
|
||||
user_performing_action: DATestUser,
|
||||
) -> bool:
|
||||
"""Delete bot config if it exists. Returns True if deleted."""
|
||||
response = requests.delete(
|
||||
url=f"{DISCORD_BOT_API_URL}/config",
|
||||
headers=user_performing_action.headers,
|
||||
cookies=user_performing_action.cookies,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
return False
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def create_test_channel_in_db(
|
||||
guild_config_id: int,
|
||||
channel_id: int,
|
||||
channel_name: str,
|
||||
channel_type: str = "text",
|
||||
is_private: bool = False,
|
||||
) -> DATestDiscordChannelConfig:
|
||||
"""Create a test channel config directly in the database.
|
||||
|
||||
This is needed because channels are normally synced from Discord,
|
||||
not created via API. For testing the channel API endpoints,
|
||||
we need to populate test data directly.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
channel_view = DiscordChannelView(
|
||||
channel_id=channel_id,
|
||||
channel_name=channel_name,
|
||||
channel_type=channel_type,
|
||||
is_private=is_private,
|
||||
)
|
||||
config = create_channel_config(db_session, guild_config_id, channel_view)
|
||||
db_session.commit()
|
||||
|
||||
return DATestDiscordChannelConfig(
|
||||
id=config.id,
|
||||
guild_config_id=config.guild_config_id,
|
||||
channel_id=config.channel_id,
|
||||
channel_name=config.channel_name,
|
||||
channel_type=config.channel_type,
|
||||
is_private=config.is_private,
|
||||
enabled=config.enabled,
|
||||
thread_only_mode=config.thread_only_mode,
|
||||
require_bot_invocation=config.require_bot_invocation,
|
||||
persona_override_id=config.persona_override_id,
|
||||
)
|
||||
@@ -273,3 +273,30 @@ class DATestTool(BaseModel):
|
||||
description: str
|
||||
display_name: str
|
||||
in_code_tool_id: str | None
|
||||
|
||||
|
||||
# Discord Bot Models
|
||||
class DATestDiscordGuildConfig(BaseModel):
|
||||
"""Discord guild config model for testing."""
|
||||
|
||||
id: int
|
||||
registration_key: str | None = None # Only present on creation
|
||||
guild_id: int | None = None
|
||||
guild_name: str | None = None
|
||||
enabled: bool = True
|
||||
default_persona_id: int | None = None
|
||||
|
||||
|
||||
class DATestDiscordChannelConfig(BaseModel):
|
||||
"""Discord channel config model for testing."""
|
||||
|
||||
id: int
|
||||
guild_config_id: int
|
||||
channel_id: int
|
||||
channel_name: str
|
||||
channel_type: str
|
||||
is_private: bool
|
||||
enabled: bool = False
|
||||
thread_only_mode: bool = False
|
||||
require_bot_invocation: bool = True
|
||||
persona_override_id: int | None = None
|
||||
|
||||
@@ -0,0 +1,456 @@
|
||||
"""Multi-tenant isolation tests for Discord bot.
|
||||
|
||||
These tests ensure tenant isolation and prevent data leakage between tenants.
|
||||
Tests follow the multi-tenant integration test pattern using API requests.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.discord_bot import get_guild_config_by_registration_key
|
||||
from onyx.db.discord_bot import register_guild
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
|
||||
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
|
||||
from onyx.server.manage.discord_bot.utils import REGISTRATION_KEY_PREFIX
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class TestBotConfigIsolationCloudMode:
|
||||
"""Tests for bot config isolation in cloud mode."""
|
||||
|
||||
def test_cannot_create_bot_config_in_cloud_mode(self) -> None:
|
||||
"""Bot config creation is blocked in cloud mode."""
|
||||
with patch("onyx.configs.app_configs.AUTH_TYPE", AuthType.CLOUD):
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.manage.discord_bot.api import _check_bot_config_api_access
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_check_bot_config_api_access()
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "Cloud" in str(exc_info.value.detail)
|
||||
|
||||
def test_bot_token_from_env_only_in_cloud(self) -> None:
|
||||
"""Bot token comes from env var in cloud mode, ignores DB."""
|
||||
from onyx.onyxbot.discord.utils import get_bot_token
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", "env_token"),
|
||||
patch("onyx.onyxbot.discord.utils.AUTH_TYPE", AuthType.CLOUD),
|
||||
):
|
||||
result = get_bot_token()
|
||||
|
||||
assert result == "env_token"
|
||||
|
||||
|
||||
class TestGuildRegistrationIsolation:
|
||||
"""Tests for guild registration isolation between tenants."""
|
||||
|
||||
def test_guild_can_only_register_to_one_tenant(self) -> None:
|
||||
"""Guild registered to tenant 1 cannot be registered to tenant 2."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
# Register guild to tenant 1
|
||||
cache._guild_tenants[123456789] = "tenant1"
|
||||
|
||||
# Check if guild is already registered
|
||||
existing = cache.get_tenant(123456789)
|
||||
|
||||
assert existing is not None
|
||||
assert existing == "tenant1"
|
||||
|
||||
def test_registration_key_tenant_mismatch(self) -> None:
|
||||
"""Key created in tenant 1 cannot be used in tenant 2 context."""
|
||||
key = generate_discord_registration_key("tenant1")
|
||||
|
||||
# Parse the key to get tenant
|
||||
parsed_tenant = parse_discord_registration_key(key)
|
||||
|
||||
assert parsed_tenant == "tenant1"
|
||||
assert parsed_tenant != "tenant2"
|
||||
|
||||
def test_registration_key_encodes_correct_tenant(self) -> None:
|
||||
"""Key format discord_<tenant_id>.<token> encodes correct tenant."""
|
||||
tenant_id = "my_tenant_123"
|
||||
key = generate_discord_registration_key(tenant_id)
|
||||
|
||||
assert key.startswith(REGISTRATION_KEY_PREFIX)
|
||||
assert "my_tenant_123" in key or "my%5Ftenant%5F123" in key
|
||||
|
||||
parsed = parse_discord_registration_key(key)
|
||||
assert parsed == tenant_id
|
||||
|
||||
|
||||
class TestGuildDataIsolation:
|
||||
"""Tests for guild data isolation between tenants via API."""
|
||||
|
||||
def test_tenant_cannot_see_other_tenant_guilds(
|
||||
self, reset_multitenant: None
|
||||
) -> None:
|
||||
"""Guilds created in tenant 1 are not visible from tenant 2.
|
||||
|
||||
Creates guilds via API in tenant 1, then queries from tenant 2
|
||||
context to verify the guilds are not visible.
|
||||
"""
|
||||
unique = uuid4().hex
|
||||
|
||||
# Create admin user for tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_admin1+{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create admin user for tenant 2
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_admin2+{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
# Create a guild registration key in tenant 1
|
||||
response1 = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
|
||||
# If Discord bot feature is not enabled, skip the test
|
||||
if response1.status_code == 404:
|
||||
pytest.skip("Discord bot feature not enabled")
|
||||
|
||||
assert response1.ok, f"Failed to create guild in tenant 1: {response1.text}"
|
||||
guild1_data = response1.json()
|
||||
guild1_id = guild1_data["id"]
|
||||
|
||||
try:
|
||||
# List guilds from tenant 1 - should see the guild
|
||||
list_response1 = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
assert list_response1.ok
|
||||
tenant1_guilds = list_response1.json()
|
||||
tenant1_guild_ids = [g["id"] for g in tenant1_guilds]
|
||||
assert guild1_id in tenant1_guild_ids
|
||||
|
||||
# List guilds from tenant 2 - should NOT see tenant 1's guild
|
||||
list_response2 = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user2.headers,
|
||||
)
|
||||
assert list_response2.ok
|
||||
tenant2_guilds = list_response2.json()
|
||||
tenant2_guild_ids = [g["id"] for g in tenant2_guilds]
|
||||
assert guild1_id not in tenant2_guild_ids
|
||||
|
||||
finally:
|
||||
# Cleanup - delete guild from tenant 1
|
||||
requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
|
||||
def test_guild_list_returns_only_own_tenant(self, reset_multitenant: None) -> None:
|
||||
"""List guilds returns exactly the guilds for that tenant.
|
||||
|
||||
Creates 1 guild in each tenant, registers them with different data,
|
||||
and verifies each tenant only sees their own guild.
|
||||
"""
|
||||
unique = uuid4().hex
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_list1+{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_list2+{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create 1 guild in tenant 1
|
||||
response1 = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
if response1.status_code == 404:
|
||||
pytest.skip("Discord bot feature not enabled")
|
||||
assert response1.ok, f"Failed to create guild in tenant 1: {response1.text}"
|
||||
guild1_data = response1.json()
|
||||
guild1_id = guild1_data["id"]
|
||||
registration_key1 = guild1_data["registration_key"]
|
||||
tenant1_id = parse_discord_registration_key(registration_key1)
|
||||
assert (
|
||||
tenant1_id is not None
|
||||
), "Failed to parse tenant ID from registration key 1"
|
||||
|
||||
# Create 1 guild in tenant 2
|
||||
response2 = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user2.headers,
|
||||
)
|
||||
assert response2.ok, f"Failed to create guild in tenant 2: {response2.text}"
|
||||
guild2_data = response2.json()
|
||||
guild2_id = guild2_data["id"]
|
||||
registration_key2 = guild2_data["registration_key"]
|
||||
tenant2_id = parse_discord_registration_key(registration_key2)
|
||||
assert (
|
||||
tenant2_id is not None
|
||||
), "Failed to parse tenant ID from registration key 2"
|
||||
|
||||
# Verify tenant IDs are different
|
||||
assert (
|
||||
tenant1_id != tenant2_id
|
||||
), "Tenant 1 and tenant 2 should have different tenant IDs"
|
||||
|
||||
# Register guild 1 with tenant 1's context - populate with different data
|
||||
with get_session_with_tenant(tenant_id=tenant1_id) as db_session:
|
||||
config1 = get_guild_config_by_registration_key(
|
||||
db_session, registration_key1
|
||||
)
|
||||
assert config1 is not None, "Guild config 1 should exist"
|
||||
register_guild(
|
||||
db_session=db_session,
|
||||
config=config1,
|
||||
guild_id=111111111111111111, # Different Discord guild ID
|
||||
guild_name="Tenant 1 Server", # Different guild name
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Register guild 2 with tenant 2's context - populate with different data
|
||||
with get_session_with_tenant(tenant_id=tenant2_id) as db_session:
|
||||
config2 = get_guild_config_by_registration_key(
|
||||
db_session, registration_key2
|
||||
)
|
||||
assert config2 is not None, "Guild config 2 should exist"
|
||||
register_guild(
|
||||
db_session=db_session,
|
||||
config=config2,
|
||||
guild_id=222222222222222222, # Different Discord guild ID
|
||||
guild_name="Tenant 2 Server", # Different guild name
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
# Verify tenant 1 sees only their guild
|
||||
list_response1 = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
assert list_response1.ok
|
||||
tenant1_guilds = list_response1.json()
|
||||
|
||||
# Tenant 1 should see exactly 1 guild
|
||||
assert (
|
||||
len(tenant1_guilds) == 1
|
||||
), f"Tenant 1 should see 1 guild, got {len(tenant1_guilds)}"
|
||||
|
||||
# Verify tenant 1's guild has the correct data
|
||||
tenant1_guild = tenant1_guilds[0]
|
||||
assert (
|
||||
tenant1_guild["id"] == guild1_id
|
||||
), "Tenant 1 should see their own guild"
|
||||
assert tenant1_guild["guild_id"] == 111111111111111111, (
|
||||
f"Tenant 1's guild should have guild_id 111111111111111111, "
|
||||
f"got {tenant1_guild['guild_id']}"
|
||||
)
|
||||
assert tenant1_guild["guild_name"] == "Tenant 1 Server", (
|
||||
f"Tenant 1's guild should have name 'Tenant 1 Server', "
|
||||
f"got {tenant1_guild['guild_name']}"
|
||||
)
|
||||
assert (
|
||||
tenant1_guild["registered_at"] is not None
|
||||
), "Tenant 1's guild should be registered"
|
||||
|
||||
# Tenant 1 should NOT see tenant 2's guild
|
||||
assert (
|
||||
tenant1_guild["guild_id"] != 222222222222222222
|
||||
), "Tenant 1 should not see tenant 2's guild_id"
|
||||
assert (
|
||||
tenant1_guild["guild_name"] != "Tenant 2 Server"
|
||||
), "Tenant 1 should not see tenant 2's guild_name"
|
||||
|
||||
# Verify tenant 2 sees only their guild
|
||||
list_response2 = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user2.headers,
|
||||
)
|
||||
assert list_response2.ok
|
||||
tenant2_guilds = list_response2.json()
|
||||
|
||||
# Tenant 2 should see exactly 1 guild
|
||||
assert (
|
||||
len(tenant2_guilds) == 1
|
||||
), f"Tenant 2 should see 1 guild, got {len(tenant2_guilds)}"
|
||||
|
||||
# Verify tenant 2's guild has the correct data
|
||||
tenant2_guild = tenant2_guilds[0]
|
||||
assert (
|
||||
tenant2_guild["id"] == guild2_id
|
||||
), "Tenant 2 should see their own guild"
|
||||
assert tenant2_guild["guild_id"] == 222222222222222222, (
|
||||
f"Tenant 2's guild should have guild_id 222222222222222222, "
|
||||
f"got {tenant2_guild['guild_id']}"
|
||||
)
|
||||
assert tenant2_guild["guild_name"] == "Tenant 2 Server", (
|
||||
f"Tenant 2's guild should have name 'Tenant 2 Server', "
|
||||
f"got {tenant2_guild['guild_name']}"
|
||||
)
|
||||
assert (
|
||||
tenant2_guild["registered_at"] is not None
|
||||
), "Tenant 2's guild should be registered"
|
||||
|
||||
# Tenant 2 should NOT see tenant 1's guild
|
||||
assert (
|
||||
tenant2_guild["guild_id"] != 111111111111111111
|
||||
), "Tenant 2 should not see tenant 1's guild_id"
|
||||
assert (
|
||||
tenant2_guild["guild_name"] != "Tenant 1 Server"
|
||||
), "Tenant 2 should not see tenant 1's guild_name"
|
||||
|
||||
# Verify the guilds are different (different data)
|
||||
assert (
|
||||
tenant1_guild["guild_id"] != tenant2_guild["guild_id"]
|
||||
), "Guilds should have different Discord guild IDs"
|
||||
assert (
|
||||
tenant1_guild["guild_name"] != tenant2_guild["guild_name"]
|
||||
), "Guilds should have different names"
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild2_id}",
|
||||
headers=admin_user2.headers,
|
||||
)
|
||||
|
||||
|
||||
class TestGuildAccessIsolation:
|
||||
"""Tests for guild access isolation between tenants."""
|
||||
|
||||
def test_tenant_cannot_access_other_tenant_guild(
|
||||
self, reset_multitenant: None
|
||||
) -> None:
|
||||
"""Tenant 2 cannot access or modify tenant 1's guild by ID.
|
||||
|
||||
Creates a guild in tenant 1, then attempts to access it from tenant 2.
|
||||
"""
|
||||
unique = uuid4().hex
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_access1+{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_access2+{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create a guild in tenant 1
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
if response.status_code == 404:
|
||||
pytest.skip("Discord bot feature not enabled")
|
||||
assert response.ok
|
||||
guild1_id = response.json()["id"]
|
||||
|
||||
try:
|
||||
# Tenant 2 tries to get the guild - should fail (404 or 403)
|
||||
get_response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
|
||||
headers=admin_user2.headers,
|
||||
)
|
||||
# Should either return 404 (not found) or 403 (forbidden)
|
||||
assert get_response.status_code in [
|
||||
403,
|
||||
404,
|
||||
], f"Expected 403 or 404, got {get_response.status_code}"
|
||||
|
||||
# Tenant 2 tries to delete the guild - should fail
|
||||
delete_response = requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
|
||||
headers=admin_user2.headers,
|
||||
)
|
||||
assert delete_response.status_code in [403, 404]
|
||||
|
||||
finally:
|
||||
# Cleanup - delete from tenant 1
|
||||
requests.delete(
|
||||
f"{API_SERVER_URL}/manage/admin/discord-bot/guilds/{guild1_id}",
|
||||
headers=admin_user1.headers,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheManagerIsolation:
|
||||
"""Tests for cache manager tenant isolation."""
|
||||
|
||||
def test_cache_maps_guild_to_correct_tenant(self) -> None:
|
||||
"""Cache correctly maps guild_id to tenant_id."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
# Set up mappings
|
||||
cache._guild_tenants[111] = "tenant1"
|
||||
cache._guild_tenants[222] = "tenant2"
|
||||
cache._guild_tenants[333] = "tenant1"
|
||||
|
||||
assert cache.get_tenant(111) == "tenant1"
|
||||
assert cache.get_tenant(222) == "tenant2"
|
||||
assert cache.get_tenant(333) == "tenant1"
|
||||
assert cache.get_tenant(444) is None
|
||||
|
||||
def test_api_key_per_tenant_isolation(self) -> None:
|
||||
"""Each tenant has unique API key."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
cache._api_keys["tenant1"] = "key_for_tenant1"
|
||||
cache._api_keys["tenant2"] = "key_for_tenant2"
|
||||
|
||||
assert cache.get_api_key("tenant1") == "key_for_tenant1"
|
||||
assert cache.get_api_key("tenant2") == "key_for_tenant2"
|
||||
assert cache.get_api_key("tenant1") != cache.get_api_key("tenant2")
|
||||
|
||||
|
||||
class TestAPIRequestIsolation:
|
||||
"""Tests for API request isolation between tenants."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_bot_uses_tenant_specific_api_key(self) -> None:
|
||||
"""Message from guild in tenant 1 uses tenant 1's API key."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[123456] = "tenant1"
|
||||
cache._api_keys["tenant1"] = "tenant1_api_key"
|
||||
cache._api_keys["tenant2"] = "tenant2_api_key"
|
||||
|
||||
# When processing message from guild 123456
|
||||
tenant = cache.get_tenant(123456)
|
||||
assert tenant is not None
|
||||
api_key = cache.get_api_key(tenant)
|
||||
|
||||
assert tenant == "tenant1"
|
||||
assert api_key == "tenant1_api_key"
|
||||
assert api_key != "tenant2_api_key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guild_message_routes_to_correct_tenant(self) -> None:
|
||||
"""Message from registered guild routes to correct tenant context."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[999] = "target_tenant"
|
||||
cache._api_keys["target_tenant"] = "target_key"
|
||||
|
||||
# Simulate message routing
|
||||
guild_id = 999
|
||||
tenant = cache.get_tenant(guild_id)
|
||||
api_key = cache.get_api_key(tenant) if tenant else None
|
||||
|
||||
assert tenant == "target_tenant"
|
||||
assert api_key == "target_key"
|
||||
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
185
backend/tests/integration/tests/chat/test_chat_session_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def second_user(admin_user: DATestUser) -> DATestUser:
|
||||
# Ensure admin exists so this new user is created with BASIC role.
|
||||
try:
|
||||
return UserManager.create(name="second_basic_user")
|
||||
except HTTPError as e:
|
||||
response = e.response
|
||||
if response is None:
|
||||
raise
|
||||
if response.status_code not in (400, 409):
|
||||
raise
|
||||
try:
|
||||
payload = response.json()
|
||||
except ValueError:
|
||||
raise
|
||||
detail = payload.get("detail")
|
||||
if not _is_user_already_exists_detail(detail):
|
||||
raise
|
||||
print("Second basic user already exists; logging in instead.")
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("second_basic_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _is_user_already_exists_detail(detail: object) -> bool:
|
||||
if isinstance(detail, str):
|
||||
normalized = detail.lower()
|
||||
return (
|
||||
"already exists" in normalized
|
||||
or "register_user_already_exists" in normalized
|
||||
)
|
||||
if isinstance(detail, dict):
|
||||
code = detail.get("code")
|
||||
if isinstance(code, str) and code.lower() == "register_user_already_exists":
|
||||
return True
|
||||
message = detail.get("message")
|
||||
if isinstance(message, str) and "already exists" in message.lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_chat_session(
|
||||
chat_session_id: str,
|
||||
user: DATestUser,
|
||||
is_shared: bool | None = None,
|
||||
include_deleted: bool | None = None,
|
||||
) -> requests.Response:
|
||||
params: dict[str, str] = {}
|
||||
if is_shared is not None:
|
||||
params["is_shared"] = str(is_shared).lower()
|
||||
if include_deleted is not None:
|
||||
params["include_deleted"] = str(include_deleted).lower()
|
||||
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session_id}",
|
||||
params=params,
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def _set_sharing_status(
|
||||
chat_session_id: str, sharing_status: str, user: DATestUser
|
||||
) -> requests.Response:
|
||||
return requests.patch(
|
||||
f"{API_SERVER_URL}/chat/chat-session/{chat_session_id}",
|
||||
json={"sharing_status": sharing_status},
|
||||
headers=user.headers,
|
||||
cookies=user.cookies,
|
||||
)
|
||||
|
||||
|
||||
def test_private_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify private sessions are only accessible by the owner and never via share link."""
|
||||
# Create a private chat session owned by basic_user.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
# Owner can access the private session normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Share link should be forbidden when the session is private.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users cannot access private sessions directly.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Other users also cannot access private sessions via share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_public_shared_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify shared sessions are accessible only via share link for non-owners."""
|
||||
# Create a private session, then mark it public.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
response = _set_sharing_status(str(chat_session.id), "public", basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can access normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Owner can also access via share link.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Non-owner cannot access without share link.
|
||||
response = _get_chat_session(str(chat_session.id), second_user)
|
||||
assert response.status_code == 403
|
||||
|
||||
# Non-owner can access with share link for public sessions.
|
||||
response = _get_chat_session(str(chat_session.id), second_user, is_shared=True)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_deleted_chat_session_access(
|
||||
basic_user: DATestUser, second_user: DATestUser
|
||||
) -> None:
|
||||
"""Verify deleted sessions return 404, with include_deleted gated by access checks."""
|
||||
# Create and soft-delete a session.
|
||||
chat_session = ChatSessionManager.create(user_performing_action=basic_user)
|
||||
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=chat_session, user_performing_action=basic_user
|
||||
)
|
||||
assert deletion_success is True
|
||||
|
||||
# Deleted sessions are not accessible normally.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Owner can fetch deleted session only with include_deleted.
|
||||
response = _get_chat_session(str(chat_session.id), basic_user, include_deleted=True)
|
||||
assert response.status_code == 200
|
||||
assert response.json().get("deleted") is True
|
||||
|
||||
# Non-owner should be blocked even with include_deleted.
|
||||
response = _get_chat_session(
|
||||
str(chat_session.id), second_user, include_deleted=True
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def test_chat_session_not_found_returns_404(basic_user: DATestUser) -> None:
|
||||
"""Verify unknown IDs return 404."""
|
||||
response = _get_chat_session(str(uuid4()), basic_user)
|
||||
assert response.status_code == 404
|
||||
@@ -0,0 +1,443 @@
|
||||
"""Integration tests for Discord bot API endpoints.
|
||||
|
||||
These tests hit actual API endpoints via HTTP requests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.discord_bot import get_discord_service_api_key
|
||||
from onyx.db.discord_bot import get_or_create_discord_service_api_key
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from tests.integration.common_utils.managers.discord_bot import DiscordBotManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
class TestBotConfigEndpoints:
|
||||
"""Tests for /manage/admin/discord-bot/config endpoints."""
|
||||
|
||||
def test_get_bot_config_not_configured(self, reset: None) -> None:
|
||||
"""GET /config returns configured=False when no config exists."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Ensure no config exists
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
|
||||
config = DiscordBotManager.get_bot_config(admin_user)
|
||||
|
||||
assert config["configured"] is False
|
||||
assert "created_at" not in config or config.get("created_at") is None
|
||||
|
||||
def test_create_bot_config(self, reset: None) -> None:
|
||||
"""POST /config creates a new bot config."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Ensure no config exists
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
|
||||
config = DiscordBotManager.create_bot_config(
|
||||
bot_token="test_token_123",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert config["configured"] is True
|
||||
assert "created_at" in config
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
|
||||
def test_create_bot_config_already_exists(self, reset: None) -> None:
|
||||
"""POST /config returns 409 if config already exists."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Ensure no config exists, then create one
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
DiscordBotManager.create_bot_config(
|
||||
bot_token="token1",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Try to create another - should fail
|
||||
with pytest.raises(requests.HTTPError) as exc_info:
|
||||
DiscordBotManager.create_bot_config(
|
||||
bot_token="token2",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 409
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
|
||||
def test_delete_bot_config(self, reset: None) -> None:
|
||||
"""DELETE /config removes the bot config."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Ensure no config exists, then create one
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
DiscordBotManager.create_bot_config(
|
||||
bot_token="test_token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Delete it
|
||||
result = DiscordBotManager.delete_bot_config(admin_user)
|
||||
assert result["deleted"] is True
|
||||
|
||||
# Verify it's gone
|
||||
config = DiscordBotManager.get_bot_config(admin_user)
|
||||
assert config["configured"] is False
|
||||
|
||||
def test_delete_bot_config_not_found(self, reset: None) -> None:
|
||||
"""DELETE /config returns 404 if no config exists."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Ensure no config exists
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
|
||||
# Try to delete - should fail
|
||||
with pytest.raises(requests.HTTPError) as exc_info:
|
||||
DiscordBotManager.delete_bot_config(admin_user)
|
||||
|
||||
assert exc_info.value.response.status_code == 404
|
||||
|
||||
|
||||
class TestGuildConfigEndpoints:
|
||||
"""Tests for /manage/admin/discord-bot/guilds endpoints."""
|
||||
|
||||
def test_create_guild_config(self, reset: None) -> None:
|
||||
"""POST /guilds creates a new guild config with registration key."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
guild = DiscordBotManager.create_guild(admin_user)
|
||||
|
||||
assert guild.id is not None
|
||||
assert guild.registration_key is not None
|
||||
assert guild.registration_key.startswith("discord_")
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_list_guilds(self, reset: None) -> None:
|
||||
"""GET /guilds returns all guild configs."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create some guilds
|
||||
guild1 = DiscordBotManager.create_guild(admin_user)
|
||||
guild2 = DiscordBotManager.create_guild(admin_user)
|
||||
|
||||
guilds = DiscordBotManager.list_guilds(admin_user)
|
||||
|
||||
guild_ids = [g["id"] for g in guilds]
|
||||
assert guild1.id in guild_ids
|
||||
assert guild2.id in guild_ids
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild1.id, admin_user)
|
||||
DiscordBotManager.delete_guild_if_exists(guild2.id, admin_user)
|
||||
|
||||
def test_get_guild_config(self, reset: None) -> None:
|
||||
"""GET /guilds/{config_id} returns the specific guild config."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
guild = DiscordBotManager.create_guild(admin_user)
|
||||
|
||||
fetched = DiscordBotManager.get_guild(guild.id, admin_user)
|
||||
|
||||
assert fetched["id"] == guild.id
|
||||
assert fetched["enabled"] is True # Default
|
||||
assert fetched["guild_id"] is None # Not registered yet
|
||||
assert fetched["guild_name"] is None
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_get_guild_config_not_found(self, reset: None) -> None:
|
||||
"""GET /guilds/{config_id} returns 404 for non-existent guild."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
result = DiscordBotManager.get_guild_or_none(999999, admin_user)
|
||||
assert result is None
|
||||
|
||||
def test_update_guild_config(self, reset: None) -> None:
|
||||
"""PATCH /guilds/{config_id} updates the guild config."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
guild = DiscordBotManager.create_guild(admin_user)
|
||||
|
||||
# Update enabled status
|
||||
updated = DiscordBotManager.update_guild(
|
||||
guild.id,
|
||||
admin_user,
|
||||
enabled=False,
|
||||
)
|
||||
|
||||
assert updated["enabled"] is False
|
||||
|
||||
# Verify persistence
|
||||
fetched = DiscordBotManager.get_guild(guild.id, admin_user)
|
||||
assert fetched["enabled"] is False
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_delete_guild_config(self, reset: None) -> None:
|
||||
"""DELETE /guilds/{config_id} removes the guild config."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
guild = DiscordBotManager.create_guild(admin_user)
|
||||
|
||||
# Delete it
|
||||
result = DiscordBotManager.delete_guild(guild.id, admin_user)
|
||||
assert result["deleted"] is True
|
||||
|
||||
# Verify it's gone
|
||||
assert DiscordBotManager.get_guild_or_none(guild.id, admin_user) is None
|
||||
|
||||
def test_delete_guild_config_not_found(self, reset: None) -> None:
|
||||
"""DELETE /guilds/{config_id} returns 404 for non-existent guild."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
with pytest.raises(requests.HTTPError) as exc_info:
|
||||
DiscordBotManager.delete_guild(999999, admin_user)
|
||||
|
||||
assert exc_info.value.response.status_code == 404
|
||||
|
||||
def test_registration_key_format(self, reset: None) -> None:
|
||||
"""Registration key has proper format with tenant encoded."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
guild = DiscordBotManager.create_guild(admin_user)
|
||||
|
||||
# Key should be: discord_{encoded_tenant}.{random}
|
||||
key = guild.registration_key
|
||||
assert key is not None
|
||||
assert key.startswith("discord_")
|
||||
|
||||
# Should have two parts separated by dot
|
||||
key_body = key.removeprefix("discord_")
|
||||
parts = key_body.split(".", 1)
|
||||
assert len(parts) == 2
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_each_registration_key_is_unique(self, reset: None) -> None:
|
||||
"""Each created guild gets a unique registration key."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
guilds = [DiscordBotManager.create_guild(admin_user) for _ in range(5)]
|
||||
keys = [g.registration_key for g in guilds]
|
||||
|
||||
assert len(set(keys)) == 5 # All unique
|
||||
|
||||
# Cleanup
|
||||
for guild in guilds:
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
|
||||
class TestChannelConfigEndpoints:
|
||||
"""Tests for /manage/admin/discord-bot/guilds/{id}/channels endpoints."""
|
||||
|
||||
def test_list_channels_empty(self, reset: None) -> None:
|
||||
"""GET /guilds/{id}/channels returns empty list when no channels exist."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a registered guild (has guild_id set)
|
||||
guild = DiscordBotManager.create_registered_guild_in_db(
|
||||
guild_id=111111111,
|
||||
guild_name="Test Guild",
|
||||
)
|
||||
|
||||
channels = DiscordBotManager.list_channels(guild.id, admin_user)
|
||||
|
||||
assert channels == []
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_list_channels_with_data(self, reset: None) -> None:
|
||||
"""GET /guilds/{id}/channels returns channel configs."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a registered guild (has guild_id set)
|
||||
guild = DiscordBotManager.create_registered_guild_in_db(
|
||||
guild_id=222222222,
|
||||
guild_name="Test Guild",
|
||||
)
|
||||
|
||||
# Create test channels directly in DB
|
||||
channel1 = DiscordBotManager.create_test_channel_in_db(
|
||||
guild_config_id=guild.id,
|
||||
channel_id=123456789,
|
||||
channel_name="general",
|
||||
)
|
||||
channel2 = DiscordBotManager.create_test_channel_in_db(
|
||||
guild_config_id=guild.id,
|
||||
channel_id=987654321,
|
||||
channel_name="help",
|
||||
channel_type="forum",
|
||||
)
|
||||
|
||||
channels = DiscordBotManager.list_channels(guild.id, admin_user)
|
||||
|
||||
assert len(channels) == 2
|
||||
channel_ids = [c.id for c in channels]
|
||||
assert channel1.id in channel_ids
|
||||
assert channel2.id in channel_ids
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_update_channel_enabled(self, reset: None) -> None:
|
||||
"""PATCH /guilds/{id}/channels/{id} updates enabled status."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a registered guild (has guild_id set)
|
||||
guild = DiscordBotManager.create_registered_guild_in_db(
|
||||
guild_id=333333333,
|
||||
guild_name="Test Guild",
|
||||
)
|
||||
channel = DiscordBotManager.create_test_channel_in_db(
|
||||
guild_config_id=guild.id,
|
||||
channel_id=123456789,
|
||||
channel_name="general",
|
||||
)
|
||||
|
||||
# Default is disabled
|
||||
assert channel.enabled is False
|
||||
|
||||
# Enable the channel
|
||||
updated = DiscordBotManager.update_channel(
|
||||
guild.id,
|
||||
channel.id,
|
||||
admin_user,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
assert updated.enabled is True
|
||||
|
||||
# Verify persistence
|
||||
channels = DiscordBotManager.list_channels(guild.id, admin_user)
|
||||
found = next(c for c in channels if c.id == channel.id)
|
||||
assert found.enabled is True
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_update_channel_thread_only_mode(self, reset: None) -> None:
|
||||
"""PATCH /guilds/{id}/channels/{id} updates thread_only_mode."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a registered guild (has guild_id set)
|
||||
guild = DiscordBotManager.create_registered_guild_in_db(
|
||||
guild_id=444444444,
|
||||
guild_name="Test Guild",
|
||||
)
|
||||
channel = DiscordBotManager.create_test_channel_in_db(
|
||||
guild_config_id=guild.id,
|
||||
channel_id=123456789,
|
||||
channel_name="general",
|
||||
)
|
||||
|
||||
# Default is False
|
||||
assert channel.thread_only_mode is False
|
||||
|
||||
# Enable thread_only_mode
|
||||
updated = DiscordBotManager.update_channel(
|
||||
guild.id,
|
||||
channel.id,
|
||||
admin_user,
|
||||
thread_only_mode=True,
|
||||
)
|
||||
|
||||
assert updated.thread_only_mode is True
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_update_channel_require_bot_invocation(self, reset: None) -> None:
|
||||
"""PATCH /guilds/{id}/channels/{id} updates require_bot_invocation."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a registered guild (has guild_id set)
|
||||
guild = DiscordBotManager.create_registered_guild_in_db(
|
||||
guild_id=555555555,
|
||||
guild_name="Test Guild",
|
||||
)
|
||||
channel = DiscordBotManager.create_test_channel_in_db(
|
||||
guild_config_id=guild.id,
|
||||
channel_id=123456789,
|
||||
channel_name="general",
|
||||
)
|
||||
|
||||
# Default is True
|
||||
assert channel.require_bot_invocation is True
|
||||
|
||||
# Disable require_bot_invocation
|
||||
updated = DiscordBotManager.update_channel(
|
||||
guild.id,
|
||||
channel.id,
|
||||
admin_user,
|
||||
require_bot_invocation=False,
|
||||
)
|
||||
|
||||
assert updated.require_bot_invocation is False
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
def test_update_channel_not_found(self, reset: None) -> None:
|
||||
"""PATCH /guilds/{id}/channels/{id} returns 404 for non-existent channel."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Create a registered guild (has guild_id set)
|
||||
guild = DiscordBotManager.create_registered_guild_in_db(
|
||||
guild_id=666666666,
|
||||
guild_name="Test Guild",
|
||||
)
|
||||
|
||||
with pytest.raises(requests.HTTPError) as exc_info:
|
||||
DiscordBotManager.update_channel(
|
||||
guild.id,
|
||||
999999,
|
||||
admin_user,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 404
|
||||
|
||||
# Cleanup
|
||||
DiscordBotManager.delete_guild_if_exists(guild.id, admin_user)
|
||||
|
||||
|
||||
class TestServiceApiKeyCleanup:
|
||||
"""Tests for service API key cleanup when bot/guild configs are deleted."""
|
||||
|
||||
def test_delete_bot_config_also_deletes_service_api_key(self, reset: None) -> None:
|
||||
"""DELETE /config also deletes the service API key (self-hosted flow)."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
# Setup: create bot config via API
|
||||
DiscordBotManager.delete_bot_config_if_exists(admin_user)
|
||||
DiscordBotManager.create_bot_config(
|
||||
bot_token="test_token",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create service API key directly in DB (simulating bot registration)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
get_or_create_discord_service_api_key(db_session, "public")
|
||||
db_session.commit()
|
||||
|
||||
# Verify it exists
|
||||
assert get_discord_service_api_key(db_session) is not None
|
||||
|
||||
# Delete bot config via API
|
||||
result = DiscordBotManager.delete_bot_config(admin_user)
|
||||
assert result["deleted"] is True
|
||||
|
||||
# Verify service API key was also deleted
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assert get_discord_service_api_key(db_session) is None
|
||||
@@ -0,0 +1,673 @@
|
||||
"""Integration tests for Discord bot database operations.
|
||||
|
||||
These tests verify CRUD operations for Discord bot models.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
from onyx.db.discord_bot import delete_discord_bot_config
|
||||
from onyx.db.discord_bot import delete_discord_service_api_key
|
||||
from onyx.db.discord_bot import delete_guild_config
|
||||
from onyx.db.discord_bot import get_channel_configs
|
||||
from onyx.db.discord_bot import get_discord_bot_config
|
||||
from onyx.db.discord_bot import get_discord_service_api_key
|
||||
from onyx.db.discord_bot import get_guild_config_by_internal_id
|
||||
from onyx.db.discord_bot import get_guild_config_by_registration_key
|
||||
from onyx.db.discord_bot import get_guild_configs
|
||||
from onyx.db.discord_bot import get_or_create_discord_service_api_key
|
||||
from onyx.db.discord_bot import sync_channel_configs
|
||||
from onyx.db.discord_bot import update_discord_channel_config
|
||||
from onyx.db.discord_bot import update_guild_config
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
|
||||
|
||||
|
||||
def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Persona:
|
||||
"""Create a minimal test persona."""
|
||||
persona = Persona(
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
num_chunks=5.0,
|
||||
chunks_above=1,
|
||||
chunks_below=1,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
|
||||
is_visible=True,
|
||||
is_default_persona=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
db_session.add(persona)
|
||||
db_session.flush()
|
||||
return persona
|
||||
|
||||
|
||||
def _delete_test_persona(db_session: Session, persona_id: int) -> None:
|
||||
"""Delete a test persona."""
|
||||
db_session.query(Persona).filter(Persona.id == persona_id).delete()
|
||||
db_session.flush()
|
||||
|
||||
|
||||
class TestBotConfigAPI:
|
||||
"""Tests for bot config API operations."""
|
||||
|
||||
def test_create_bot_config(self, db_session: Session) -> None:
|
||||
"""Create bot config succeeds with valid token."""
|
||||
# Clean up any existing config first
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
config = create_discord_bot_config(db_session, bot_token="test_token_123")
|
||||
db_session.commit()
|
||||
|
||||
assert config is not None
|
||||
assert config.bot_token == "test_token_123"
|
||||
|
||||
# Cleanup
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
def test_create_bot_config_already_exists(self, db_session: Session) -> None:
|
||||
"""Creating config twice raises ValueError."""
|
||||
# Clean up first
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
create_discord_bot_config(db_session, bot_token="token1")
|
||||
db_session.commit()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
create_discord_bot_config(db_session, bot_token="token2")
|
||||
|
||||
# Cleanup
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
def test_get_bot_config(self, db_session: Session) -> None:
|
||||
"""Get bot config returns config with masked token."""
|
||||
# Clean up first
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
create_discord_bot_config(db_session, bot_token="my_secret_token")
|
||||
db_session.commit()
|
||||
|
||||
config = get_discord_bot_config(db_session)
|
||||
|
||||
assert config is not None
|
||||
# Token should be stored (we don't mask in DB, only API response)
|
||||
assert config.bot_token is not None
|
||||
|
||||
# Cleanup
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
def test_delete_bot_config(self, db_session: Session) -> None:
|
||||
"""Delete bot config removes it from DB."""
|
||||
# Clean up first
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
create_discord_bot_config(db_session, bot_token="token")
|
||||
db_session.commit()
|
||||
|
||||
deleted = delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
assert deleted is True
|
||||
assert get_discord_bot_config(db_session) is None
|
||||
|
||||
def test_delete_bot_config_not_found(self, db_session: Session) -> None:
|
||||
"""Delete when no config exists returns False."""
|
||||
# Ensure no config exists
|
||||
delete_discord_bot_config(db_session)
|
||||
db_session.commit()
|
||||
|
||||
deleted = delete_discord_bot_config(db_session)
|
||||
assert deleted is False
|
||||
|
||||
|
||||
class TestRegistrationKeyAPI:
|
||||
"""Tests for registration key API operations."""
|
||||
|
||||
def test_create_registration_key(self, db_session: Session) -> None:
|
||||
"""Create registration key with proper format."""
|
||||
key = generate_discord_registration_key("test_tenant")
|
||||
|
||||
config = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
assert config is not None
|
||||
assert config.registration_key == key
|
||||
assert key.startswith("discord_")
|
||||
assert "test_tenant" in key or "test%5Ftenant" in key
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, config.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_registration_key_is_unique(self, db_session: Session) -> None:
|
||||
"""Each generated key is unique."""
|
||||
keys = [generate_discord_registration_key("tenant") for _ in range(5)]
|
||||
assert len(set(keys)) == 5
|
||||
|
||||
def test_delete_registration_key(self, db_session: Session) -> None:
|
||||
"""Deleted key can no longer be used."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
config = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
config_id = config.id
|
||||
|
||||
# Delete
|
||||
deleted = delete_guild_config(db_session, config_id)
|
||||
db_session.commit()
|
||||
|
||||
assert deleted is True
|
||||
|
||||
# Should not find it anymore
|
||||
found = get_guild_config_by_registration_key(db_session, key)
|
||||
assert found is None
|
||||
|
||||
|
||||
class TestGuildConfigAPI:
|
||||
"""Tests for guild config API operations."""
|
||||
|
||||
def test_list_guilds(self, db_session: Session) -> None:
|
||||
"""List guilds returns all guild configs."""
|
||||
# Create some guild configs
|
||||
key1 = generate_discord_registration_key("t1")
|
||||
key2 = generate_discord_registration_key("t2")
|
||||
|
||||
config1 = create_guild_config(db_session, registration_key=key1)
|
||||
config2 = create_guild_config(db_session, registration_key=key2)
|
||||
db_session.commit()
|
||||
|
||||
configs = get_guild_configs(db_session)
|
||||
|
||||
assert len(configs) >= 2
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, config1.id)
|
||||
delete_guild_config(db_session, config2.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_get_guild_config(self, db_session: Session) -> None:
|
||||
"""Get specific guild config by ID."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
config = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
found = get_guild_config_by_internal_id(db_session, config.id)
|
||||
|
||||
assert found is not None
|
||||
assert found.id == config.id
|
||||
assert found.registration_key == key
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, config.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_update_guild_enabled(self, db_session: Session) -> None:
|
||||
"""Update guild enabled status."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
config = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
# Initially enabled is True by default
|
||||
assert config.enabled is True
|
||||
|
||||
# Disable
|
||||
updated = update_guild_config(
|
||||
db_session, config, enabled=False, default_persona_id=None
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
assert updated.enabled is False
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, config.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_update_guild_persona(self, db_session: Session) -> None:
|
||||
"""Update guild default persona."""
|
||||
# Create test persona first to satisfy foreign key constraint
|
||||
_create_test_persona(db_session, 5, "Test Persona 5")
|
||||
db_session.commit()
|
||||
|
||||
key = generate_discord_registration_key("tenant")
|
||||
config = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
# Set persona
|
||||
updated = update_guild_config(
|
||||
db_session, config, enabled=True, default_persona_id=5
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
assert updated.default_persona_id == 5
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, config.id)
|
||||
_delete_test_persona(db_session, 5)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
class TestChannelConfigAPI:
|
||||
"""Tests for channel config API operations."""
|
||||
|
||||
def test_list_channels_for_guild(self, db_session: Session) -> None:
|
||||
"""List channels returns all channel configs for guild."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
# Create some channels
|
||||
channels = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
DiscordChannelView(
|
||||
channel_id=222,
|
||||
channel_name="help",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
bulk_create_channel_configs(db_session, guild.id, channels)
|
||||
db_session.commit()
|
||||
|
||||
channel_configs = get_channel_configs(db_session, guild.id)
|
||||
|
||||
assert len(channel_configs) == 2
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_update_channel_enabled(self, db_session: Session) -> None:
|
||||
"""Update channel enabled status."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
channels = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
created = bulk_create_channel_configs(db_session, guild.id, channels)
|
||||
db_session.commit()
|
||||
|
||||
# Channels are disabled by default
|
||||
assert created[0].enabled is False
|
||||
|
||||
# Enable
|
||||
updated = update_discord_channel_config(
|
||||
db_session,
|
||||
created[0],
|
||||
channel_name="general",
|
||||
thread_only_mode=False,
|
||||
require_bot_invocation=True,
|
||||
enabled=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
assert updated.enabled is True
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_update_channel_thread_only_mode(self, db_session: Session) -> None:
|
||||
"""Update channel thread_only_mode setting."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
channels = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
created = bulk_create_channel_configs(db_session, guild.id, channels)
|
||||
db_session.commit()
|
||||
|
||||
# Update thread_only_mode
|
||||
updated = update_discord_channel_config(
|
||||
db_session,
|
||||
created[0],
|
||||
channel_name="general",
|
||||
thread_only_mode=True,
|
||||
require_bot_invocation=True,
|
||||
enabled=True,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
assert updated.thread_only_mode is True
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_sync_channels_adds_new(self, db_session: Session) -> None:
|
||||
"""Sync channels adds new channels."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
# Initial channels
|
||||
initial = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
bulk_create_channel_configs(db_session, guild.id, initial)
|
||||
db_session.commit()
|
||||
|
||||
# Sync with new channel
|
||||
current = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
DiscordChannelView(
|
||||
channel_id=222,
|
||||
channel_name="new-channel",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
added, removed, updated = sync_channel_configs(db_session, guild.id, current)
|
||||
db_session.commit()
|
||||
|
||||
assert added == 1
|
||||
assert removed == 0
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_sync_channels_removes_deleted(self, db_session: Session) -> None:
|
||||
"""Sync channels removes deleted channels."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
# Initial channels
|
||||
initial = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
DiscordChannelView(
|
||||
channel_id=222,
|
||||
channel_name="old-channel",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
bulk_create_channel_configs(db_session, guild.id, initial)
|
||||
db_session.commit()
|
||||
|
||||
# Sync with one channel removed
|
||||
current = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
added, removed, updated = sync_channel_configs(db_session, guild.id, current)
|
||||
db_session.commit()
|
||||
|
||||
assert added == 0
|
||||
assert removed == 1
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
def test_sync_channels_updates_renamed(self, db_session: Session) -> None:
|
||||
"""Sync channels updates renamed channels."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
db_session.commit()
|
||||
|
||||
# Initial channels
|
||||
initial = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="old-name",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
bulk_create_channel_configs(db_session, guild.id, initial)
|
||||
db_session.commit()
|
||||
|
||||
# Sync with renamed channel
|
||||
current = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="new-name",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
added, removed, updated = sync_channel_configs(db_session, guild.id, current)
|
||||
db_session.commit()
|
||||
|
||||
assert added == 0
|
||||
assert removed == 0
|
||||
assert updated == 1
|
||||
|
||||
# Verify name was updated
|
||||
configs = get_channel_configs(db_session, guild.id)
|
||||
assert configs[0].channel_name == "new-name"
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
class TestPersonaConfigurationAPI:
|
||||
"""Tests for persona configuration in API."""
|
||||
|
||||
def test_guild_persona_used_in_api_call(self, db_session: Session) -> None:
|
||||
"""Guild default_persona_id is used when no channel override."""
|
||||
# Create test persona first
|
||||
_create_test_persona(db_session, 42, "Test Persona 42")
|
||||
db_session.commit()
|
||||
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
update_guild_config(db_session, guild, enabled=True, default_persona_id=42)
|
||||
db_session.commit()
|
||||
|
||||
# Verify persona is set
|
||||
config = get_guild_config_by_internal_id(db_session, guild.id)
|
||||
assert config is not None
|
||||
assert config.default_persona_id == 42
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
_delete_test_persona(db_session, 42)
|
||||
db_session.commit()
|
||||
|
||||
def test_channel_persona_override_in_api_call(self, db_session: Session) -> None:
|
||||
"""Channel persona_override_id takes precedence over guild default."""
|
||||
# Create test personas first
|
||||
_create_test_persona(db_session, 42, "Test Persona 42")
|
||||
_create_test_persona(db_session, 99, "Test Persona 99")
|
||||
db_session.commit()
|
||||
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
update_guild_config(db_session, guild, enabled=True, default_persona_id=42)
|
||||
db_session.commit()
|
||||
|
||||
channels = [
|
||||
DiscordChannelView(
|
||||
channel_id=111,
|
||||
channel_name="general",
|
||||
channel_type="text",
|
||||
is_private=False,
|
||||
),
|
||||
]
|
||||
created = bulk_create_channel_configs(db_session, guild.id, channels)
|
||||
db_session.commit()
|
||||
|
||||
# Set channel persona override
|
||||
updated = update_discord_channel_config(
|
||||
db_session,
|
||||
created[0],
|
||||
channel_name="general",
|
||||
thread_only_mode=False,
|
||||
require_bot_invocation=True,
|
||||
enabled=True,
|
||||
persona_override_id=99, # Override!
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
assert updated.persona_override_id == 99
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
_delete_test_persona(db_session, 42)
|
||||
_delete_test_persona(db_session, 99)
|
||||
db_session.commit()
|
||||
|
||||
def test_no_persona_uses_default(self, db_session: Session) -> None:
|
||||
"""Neither guild nor channel has persona - uses API default."""
|
||||
key = generate_discord_registration_key("tenant")
|
||||
guild = create_guild_config(db_session, registration_key=key)
|
||||
# No persona set
|
||||
db_session.commit()
|
||||
|
||||
config = get_guild_config_by_internal_id(db_session, guild.id)
|
||||
assert config is not None
|
||||
assert config.default_persona_id is None
|
||||
|
||||
# Cleanup
|
||||
delete_guild_config(db_session, guild.id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
class TestServiceApiKeyAPI:
|
||||
"""Tests for Discord service API key operations."""
|
||||
|
||||
def test_create_service_api_key(self, db_session: Session) -> None:
|
||||
"""Create service API key returns valid key."""
|
||||
# Clean up any existing key first
|
||||
delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
api_key = get_or_create_discord_service_api_key(db_session, "public")
|
||||
db_session.commit()
|
||||
|
||||
assert api_key is not None
|
||||
assert len(api_key) > 0
|
||||
|
||||
# Verify key was stored in database
|
||||
stored_key = get_discord_service_api_key(db_session)
|
||||
assert stored_key is not None
|
||||
|
||||
# Cleanup
|
||||
delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
def test_get_or_create_returns_existing(self, db_session: Session) -> None:
|
||||
"""get_or_create_discord_service_api_key regenerates key if exists."""
|
||||
# Clean up any existing key first
|
||||
delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Create first key
|
||||
key1 = get_or_create_discord_service_api_key(db_session, "public")
|
||||
db_session.commit()
|
||||
|
||||
# Call again - should regenerate (per implementation, it regenerates to update cache)
|
||||
key2 = get_or_create_discord_service_api_key(db_session, "public")
|
||||
db_session.commit()
|
||||
|
||||
# Keys should be different since it regenerates
|
||||
assert key1 != key2
|
||||
|
||||
# But there should still be only one key in the database
|
||||
stored_key = get_discord_service_api_key(db_session)
|
||||
assert stored_key is not None
|
||||
|
||||
# Cleanup
|
||||
delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
def test_delete_service_api_key(self, db_session: Session) -> None:
|
||||
"""Delete service API key removes it from DB."""
|
||||
# Clean up any existing key first
|
||||
delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Create a key
|
||||
get_or_create_discord_service_api_key(db_session, "public")
|
||||
db_session.commit()
|
||||
|
||||
# Delete it
|
||||
deleted = delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
assert deleted is True
|
||||
assert get_discord_service_api_key(db_session) is None
|
||||
|
||||
def test_delete_service_api_key_not_found(self, db_session: Session) -> None:
|
||||
"""Delete when no key exists returns False."""
|
||||
# Ensure no key exists
|
||||
delete_discord_service_api_key(db_session)
|
||||
db_session.commit()
|
||||
|
||||
deleted = delete_discord_service_api_key(db_session)
|
||||
assert deleted is False
|
||||
|
||||
|
||||
# Pytest fixture for db_session
|
||||
@pytest.fixture
|
||||
def db_session() -> Generator[Session, None, None]:
|
||||
"""Create database session for tests."""
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("public")
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
yield session
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.persona import PersonaLabelManager
|
||||
from tests.integration.common_utils.managers.persona import PersonaManager
|
||||
from tests.integration.common_utils.test_models import DATestPersonaLabel
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
reset: None, admin_user: DATestUser
|
||||
) -> None:
|
||||
persona_label = PersonaLabelManager.create(
|
||||
label=DATestPersonaLabel(name=f"Test label {uuid4()}"),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
assert persona_label.id is not None
|
||||
persona = PersonaManager.create(
|
||||
label_ids=[persona_label.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
updated_description = f"{persona.description}-updated"
|
||||
update_request = PersonaUpsertRequest(
|
||||
name=persona.name,
|
||||
description=updated_description,
|
||||
system_prompt=persona.system_prompt or "",
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
users=[],
|
||||
groups=[],
|
||||
label_ids=None,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=update_request.model_dump(mode="json", exclude_none=False),
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
fetched = requests.get(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
headers=admin_user.headers,
|
||||
cookies=admin_user.cookies,
|
||||
)
|
||||
fetched.raise_for_status()
|
||||
fetched_persona = fetched.json()
|
||||
|
||||
assert fetched_persona["description"] == updated_description
|
||||
fetched_label_ids = {label["id"] for label in fetched_persona["labels"]}
|
||||
assert persona_label.id in fetched_label_ids
|
||||
@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
|
||||
provider_id = _activate_exa_provider(admin_user)
|
||||
assert isinstance(provider_id, int)
|
||||
|
||||
search_request = {"queries": ["latest ai research news"], "max_results": 3}
|
||||
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
|
||||
|
||||
lite_response = requests.post(
|
||||
f"{API_SERVER_URL}/web-search/search-lite",
|
||||
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Tests for license enforcement middleware."""
|
||||
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from ee.onyx.server.middleware.license_enforcement import _is_path_allowed
|
||||
|
||||
# Type alias for the middleware harness tuple
|
||||
MiddlewareHarness = tuple[
|
||||
Callable[[Request, Callable[[Request], Awaitable[Response]]], Awaitable[Response]],
|
||||
Callable[[Request], Awaitable[Response]],
|
||||
]
|
||||
|
||||
|
||||
class TestPathAllowlist:
|
||||
"""Tests for the path allowlist logic."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path,expected",
|
||||
[
|
||||
# Each allowlisted prefix (one example each)
|
||||
("/auth", True),
|
||||
("/license", True),
|
||||
("/health", True),
|
||||
("/me", True),
|
||||
("/settings", True),
|
||||
("/enterprise-settings", True),
|
||||
("/tenants/billing-information", True),
|
||||
("/tenants/create-customer-portal-session", True),
|
||||
# Verify prefix matching works (subpath of allowlisted)
|
||||
("/auth/callback/google", True),
|
||||
# Blocked paths (core functionality that requires license)
|
||||
("/chat", False),
|
||||
("/search", False),
|
||||
("/admin", False),
|
||||
("/connector", False),
|
||||
("/persona", False),
|
||||
],
|
||||
)
|
||||
def test_path_allowlist(self, path: str, expected: bool) -> None:
|
||||
"""Verify correct paths are allowed/blocked when license is gated."""
|
||||
assert _is_path_allowed(path) is expected
|
||||
|
||||
|
||||
class TestLicenseEnforcementMiddleware:
|
||||
"""Tests for middleware behavior under different conditions."""
|
||||
|
||||
@pytest.fixture
|
||||
def middleware_harness(self) -> MiddlewareHarness:
|
||||
"""Create a test harness for the middleware."""
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
add_license_enforcement_middleware,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
logger = MagicMock()
|
||||
captured_middleware: Any = None
|
||||
|
||||
def capture_middleware(middleware_type: str) -> Callable[[Any], Any]:
|
||||
def decorator(func: Any) -> Any:
|
||||
nonlocal captured_middleware
|
||||
captured_middleware = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
app.middleware = capture_middleware
|
||||
add_license_enforcement_middleware(app, logger)
|
||||
|
||||
async def call_next(req: Request) -> Response:
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
return response
|
||||
|
||||
return captured_middleware, call_next
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"ee.onyx.server.middleware.license_enforcement.LICENSE_ENFORCEMENT_ENABLED",
|
||||
True,
|
||||
)
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.MULTI_TENANT", True)
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.is_tenant_gated")
|
||||
async def test_gated_tenant_gets_402(
|
||||
self,
|
||||
mock_is_gated: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
middleware_harness: MiddlewareHarness,
|
||||
) -> None:
|
||||
"""Gated tenants receive 402 Payment Required on non-allowlisted paths."""
|
||||
mock_get_tenant.return_value = "gated_tenant"
|
||||
mock_is_gated.return_value = True
|
||||
|
||||
middleware, call_next = middleware_harness
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/api/chat"
|
||||
|
||||
response = await middleware(mock_request, call_next)
|
||||
assert response.status_code == 402
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"ee.onyx.server.middleware.license_enforcement.LICENSE_ENFORCEMENT_ENABLED",
|
||||
True,
|
||||
)
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.get_cached_license_metadata")
|
||||
async def test_no_license_self_hosted_gets_402(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
middleware_harness: MiddlewareHarness,
|
||||
) -> None:
|
||||
"""Self-hosted with no license receives 402 on non-allowlisted paths."""
|
||||
mock_get_tenant.return_value = "default"
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
middleware, call_next = middleware_harness
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/api/chat"
|
||||
|
||||
response = await middleware(mock_request, call_next)
|
||||
assert response.status_code == 402
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"ee.onyx.server.middleware.license_enforcement.LICENSE_ENFORCEMENT_ENABLED",
|
||||
True,
|
||||
)
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.MULTI_TENANT", True)
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.middleware.license_enforcement.is_tenant_gated")
|
||||
async def test_redis_error_fails_open(
|
||||
self,
|
||||
mock_is_gated: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
middleware_harness: MiddlewareHarness,
|
||||
) -> None:
|
||||
"""Redis errors should not block users - fail open to allow access."""
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_is_gated.side_effect = RedisError("Connection failed")
|
||||
|
||||
middleware, call_next = middleware_harness
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/api/chat"
|
||||
|
||||
response = await middleware(mock_request, call_next)
|
||||
assert response.status_code == 200 # Fail open
|
||||
@@ -0,0 +1,93 @@
|
||||
"""Tests for license enforcement in settings API."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_settings() -> Settings:
|
||||
"""Create base settings for testing."""
|
||||
return Settings(
|
||||
maximum_chat_retention_days=None,
|
||||
gpu_enabled=False,
|
||||
application_status=ApplicationStatus.ACTIVE,
|
||||
)
|
||||
|
||||
|
||||
class TestApplyLicenseStatusToSettings:
|
||||
"""Tests for apply_license_status_to_settings function."""
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", False)
|
||||
def test_enforcement_disabled_returns_unchanged(
|
||||
self, base_settings: Settings
|
||||
) -> None:
|
||||
"""Critical: When LICENSE_ENFORCEMENT_ENABLED=False, settings remain unchanged.
|
||||
|
||||
This is the key behavior that allows disabling enforcement for rollback.
|
||||
"""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"license_status,expected_status",
|
||||
[
|
||||
(None, ApplicationStatus.GATED_ACCESS), # No license = gated
|
||||
(
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
), # Gated status propagated
|
||||
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE), # Active stays active
|
||||
],
|
||||
)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_self_hosted_license_status_propagation(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
license_status: ApplicationStatus | None,
|
||||
expected_status: ApplicationStatus,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Self-hosted: license status is propagated to settings correctly."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
if license_status is None:
|
||||
mock_get_metadata.return_value = None
|
||||
else:
|
||||
mock_metadata = MagicMock()
|
||||
mock_metadata.status = license_status
|
||||
mock_get_metadata.return_value = mock_metadata
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == expected_status
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_redis_error_fails_open(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""Redis errors should not block users - fail open."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_get_metadata.side_effect = RedisError("Connection failed")
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Tests for product gating functions."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestIsTenantGated:
|
||||
"""Tests for is_tenant_gated - the O(1) Redis check used by middleware."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"redis_result,expected",
|
||||
[
|
||||
(True, True),
|
||||
(False, False),
|
||||
(1, True), # Redis sismember can return int
|
||||
(0, False),
|
||||
],
|
||||
)
|
||||
@patch("ee.onyx.server.tenants.product_gating.get_redis_replica_client")
|
||||
def test_tenant_gated_status(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
redis_result: bool | int,
|
||||
expected: bool,
|
||||
) -> None:
|
||||
"""is_tenant_gated correctly interprets Redis sismember result."""
|
||||
from ee.onyx.server.tenants.product_gating import is_tenant_gated
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_redis.sismember.return_value = redis_result
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
assert is_tenant_gated("test_tenant") is expected
|
||||
|
||||
|
||||
class TestUpdateTenantGating:
|
||||
"""Tests for update_tenant_gating - modifies Redis gated set."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"status,should_add_to_set",
|
||||
[
|
||||
("gated_access", True), # Only GATED_ACCESS adds to set
|
||||
("active", False), # All other statuses remove from set
|
||||
],
|
||||
)
|
||||
@patch("ee.onyx.server.tenants.product_gating.get_redis_client")
|
||||
def test_gating_set_modification(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
status: str,
|
||||
should_add_to_set: bool,
|
||||
) -> None:
|
||||
"""update_tenant_gating adds tenant to set only for GATED_ACCESS status."""
|
||||
from ee.onyx.server.tenants.product_gating import update_tenant_gating
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
mock_redis = MagicMock()
|
||||
mock_get_redis.return_value = mock_redis
|
||||
|
||||
update_tenant_gating("test_tenant", ApplicationStatus(status))
|
||||
|
||||
if should_add_to_set:
|
||||
mock_redis.sadd.assert_called_once()
|
||||
mock_redis.srem.assert_not_called()
|
||||
else:
|
||||
mock_redis.srem.assert_called_once()
|
||||
mock_redis.sadd.assert_not_called()
|
||||
@@ -0,0 +1,50 @@
|
||||
"""Tests for Asana connector configuration parsing."""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"project_ids,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 123 ", ["123"]),
|
||||
(" 123 , , 456 , ", ["123", "456"]),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_project_ids_normalization(
|
||||
project_ids: str | None, expected: list[str] | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id=" 1153293530468850 ",
|
||||
asana_project_ids=project_ids,
|
||||
asana_team_id=" 1210918501948021 ",
|
||||
)
|
||||
|
||||
assert connector.workspace_id == "1153293530468850"
|
||||
assert connector.project_ids_to_index == expected
|
||||
assert connector.asana_team_id == "1210918501948021"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"team_id,expected",
|
||||
[
|
||||
(None, None),
|
||||
("", None),
|
||||
(" ", None),
|
||||
(" 1210918501948021 ", "1210918501948021"),
|
||||
],
|
||||
)
|
||||
def test_asana_connector_team_id_normalization(
|
||||
team_id: str | None, expected: str | None
|
||||
) -> None:
|
||||
connector = AsanaConnector(
|
||||
asana_workspace_id="1153293530468850",
|
||||
asana_project_ids=None,
|
||||
asana_team_id=team_id,
|
||||
)
|
||||
|
||||
assert connector.asana_team_id == expected
|
||||
@@ -409,6 +409,53 @@ def test_multiple_tool_calls_streaming(default_multi_llm: LitellmLLM) -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_vertex_stream_omits_stream_options() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.VERTEX_AI,
|
||||
model_name="claude-opus-4-5@20251101",
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=LlmProviderNames.VERTEX_AI,
|
||||
model_name="claude-opus-4-5@20251101",
|
||||
),
|
||||
)
|
||||
|
||||
with patch("litellm.completion") as mock_completion:
|
||||
mock_completion.return_value = []
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
list(llm.stream(messages))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "stream_options" not in kwargs
|
||||
|
||||
|
||||
def test_vertex_opus_4_5_omits_reasoning_effort() -> None:
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
timeout=30,
|
||||
model_provider=LlmProviderNames.VERTEX_AI,
|
||||
model_name="claude-opus-4-5@20251101",
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=LlmProviderNames.VERTEX_AI,
|
||||
model_name="claude-opus-4-5@20251101",
|
||||
),
|
||||
)
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.multi_llm.model_is_reasoning_model", return_value=True),
|
||||
):
|
||||
mock_completion.return_value = []
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
list(llm.stream(messages))
|
||||
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert "reasoning_effort" not in kwargs
|
||||
|
||||
|
||||
def test_user_identity_metadata_enabled(default_multi_llm: LitellmLLM) -> None:
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
|
||||
281
backend/tests/unit/onyx/onyxbot/discord/conftest.py
Normal file
281
backend/tests/unit/onyx/onyxbot/discord/conftest.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Fixtures for Discord bot unit tests."""
|
||||
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
|
||||
class AsyncIteratorMock:
|
||||
"""Helper class to mock async iterators like channel.history()."""
|
||||
|
||||
def __init__(self, items: list[Any]) -> None:
|
||||
self.items = items
|
||||
self.index = 0
|
||||
|
||||
def __aiter__(self) -> "AsyncIteratorMock":
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> Any:
|
||||
if self.index >= len(self.items):
|
||||
raise StopAsyncIteration
|
||||
item = self.items[self.index]
|
||||
self.index += 1
|
||||
return item
|
||||
|
||||
|
||||
def mock_message(
|
||||
content: str = "Test message",
|
||||
author_bot: bool = False,
|
||||
message_type: discord.MessageType = discord.MessageType.default,
|
||||
reference: MagicMock | None = None,
|
||||
message_id: int | None = None,
|
||||
author_id: int | None = None,
|
||||
author_display_name: str | None = None,
|
||||
) -> MagicMock:
|
||||
"""Helper to create mock Discord messages."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = message_id or random.randint(100000, 999999)
|
||||
msg.content = content
|
||||
msg.author = MagicMock()
|
||||
msg.author.id = author_id or random.randint(100000, 999999)
|
||||
msg.author.bot = author_bot
|
||||
msg.author.display_name = author_display_name or ("Bot" if author_bot else "User")
|
||||
msg.type = message_type
|
||||
msg.reference = reference
|
||||
msg.mentions = []
|
||||
msg.role_mentions = []
|
||||
msg.channel_mentions = []
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_bot_user() -> MagicMock:
|
||||
"""Mock Discord bot user."""
|
||||
user = MagicMock(spec=discord.ClientUser)
|
||||
user.id = 123456789
|
||||
user.display_name = "OnyxBot"
|
||||
user.bot = True
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_guild() -> MagicMock:
|
||||
"""Mock Discord guild with channels."""
|
||||
guild = MagicMock(spec=discord.Guild)
|
||||
guild.id = 987654321
|
||||
guild.name = "Test Server"
|
||||
guild.default_role = MagicMock()
|
||||
|
||||
# Create some mock channels
|
||||
text_channel = MagicMock(spec=discord.TextChannel)
|
||||
text_channel.id = 111111111
|
||||
text_channel.name = "general"
|
||||
text_channel.type = discord.ChannelType.text
|
||||
perms = MagicMock()
|
||||
perms.view_channel = True
|
||||
text_channel.permissions_for.return_value = perms
|
||||
|
||||
forum_channel = MagicMock(spec=discord.ForumChannel)
|
||||
forum_channel.id = 222222222
|
||||
forum_channel.name = "forum"
|
||||
forum_channel.type = discord.ChannelType.forum
|
||||
forum_channel.permissions_for.return_value = perms
|
||||
|
||||
guild.channels = [text_channel, forum_channel]
|
||||
guild.text_channels = [text_channel]
|
||||
guild.forum_channels = [forum_channel]
|
||||
|
||||
return guild
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_discord_message(mock_bot_user: MagicMock) -> MagicMock:
|
||||
"""Mock Discord message for testing."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 555555555
|
||||
msg.author = MagicMock()
|
||||
msg.author.id = 444444444
|
||||
msg.author.bot = False
|
||||
msg.author.display_name = "TestUser"
|
||||
msg.content = "Hello bot"
|
||||
msg.guild = MagicMock()
|
||||
msg.guild.id = 987654321
|
||||
msg.guild.name = "Test Server"
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 111111111
|
||||
msg.channel.name = "general"
|
||||
msg.type = discord.MessageType.default
|
||||
msg.mentions = []
|
||||
msg.role_mentions = []
|
||||
msg.channel_mentions = []
|
||||
msg.reference = None
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_thread_with_messages(mock_bot_user: MagicMock) -> MagicMock:
|
||||
"""Mock Discord thread with message history."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666666
|
||||
thread.name = "Test Thread"
|
||||
thread.owner_id = mock_bot_user.id
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.id = 111111111
|
||||
|
||||
# Mock starter message
|
||||
starter = mock_message(
|
||||
content="Thread starter message",
|
||||
author_bot=False,
|
||||
message_id=thread.id,
|
||||
)
|
||||
|
||||
messages = [
|
||||
mock_message(author_bot=False, content="User msg 1", message_id=100),
|
||||
mock_message(author_bot=True, content="Bot response", message_id=101),
|
||||
mock_message(author_bot=False, content="User msg 2", message_id=102),
|
||||
]
|
||||
|
||||
# Setup async iterator for history
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
# Mock parent.fetch_message
|
||||
async def fetch_starter(msg_id: int) -> MagicMock:
|
||||
if msg_id == thread.id:
|
||||
return starter
|
||||
raise discord.NotFound(MagicMock(), "Not found")
|
||||
|
||||
thread.parent.fetch_message = AsyncMock(side_effect=fetch_starter)
|
||||
|
||||
return thread
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_thread_forum_parent() -> MagicMock:
|
||||
"""Mock thread with ForumChannel parent (special case)."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 777777777
|
||||
thread.name = "Forum Post"
|
||||
thread.parent = MagicMock(spec=discord.ForumChannel)
|
||||
thread.parent.id = 222222222
|
||||
return thread
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_reply_chain() -> MagicMock:
|
||||
"""Mock message with reply chain."""
|
||||
# Build chain backwards: msg3 -> msg2 -> msg1
|
||||
ref3 = MagicMock()
|
||||
ref3.message_id = 1003
|
||||
|
||||
ref2 = MagicMock()
|
||||
ref2.message_id = 1002
|
||||
|
||||
msg3 = mock_message(content="Third message", reference=None, message_id=1003)
|
||||
msg2 = mock_message(content="Second message", reference=ref3, message_id=1002)
|
||||
msg1 = mock_message(content="First message", reference=ref2, message_id=1001)
|
||||
|
||||
# Store messages for lookup
|
||||
msg1._chain = {1002: msg2, 1003: msg3}
|
||||
msg2._chain = {1003: msg3}
|
||||
|
||||
return msg1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_guild_config_enabled() -> MagicMock:
|
||||
"""Guild config that is enabled."""
|
||||
config = MagicMock()
|
||||
config.id = 1
|
||||
config.guild_id = 987654321
|
||||
config.enabled = True
|
||||
config.default_persona_id = 1
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_guild_config_disabled() -> MagicMock:
|
||||
"""Guild config that is disabled."""
|
||||
config = MagicMock()
|
||||
config.id = 2
|
||||
config.guild_id = 987654321
|
||||
config.enabled = False
|
||||
config.default_persona_id = None
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_channel_config_factory() -> Callable[..., MagicMock]:
|
||||
"""Factory fixture for creating channel configs with various settings."""
|
||||
|
||||
def _make_config(
|
||||
enabled: bool = True,
|
||||
require_bot_invocation: bool = True,
|
||||
thread_only_mode: bool = False,
|
||||
persona_override_id: int | None = None,
|
||||
) -> MagicMock:
|
||||
config = MagicMock()
|
||||
config.id = random.randint(1, 1000)
|
||||
config.channel_id = 111111111
|
||||
config.enabled = enabled
|
||||
config.require_bot_invocation = require_bot_invocation
|
||||
config.thread_only_mode = thread_only_mode
|
||||
config.persona_override_id = persona_override_id
|
||||
return config
|
||||
|
||||
return _make_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_with_bot_mention(mock_bot_user: MagicMock) -> MagicMock:
|
||||
"""Message that mentions the bot."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 888888888
|
||||
msg.mentions = [mock_bot_user]
|
||||
msg.author = MagicMock()
|
||||
msg.author.id = 444444444
|
||||
msg.author.bot = False
|
||||
msg.author.display_name = "TestUser"
|
||||
msg.type = discord.MessageType.default
|
||||
msg.content = f"<@{mock_bot_user.id}> hello"
|
||||
msg.reference = None
|
||||
msg.guild = MagicMock()
|
||||
msg.guild.id = 987654321
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 111111111
|
||||
msg.role_mentions = []
|
||||
msg.channel_mentions = []
|
||||
return msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_guild_with_members() -> MagicMock:
|
||||
"""Mock guild for mention resolution."""
|
||||
guild = MagicMock(spec=discord.Guild)
|
||||
|
||||
def get_member(member_id: int) -> MagicMock:
|
||||
member = MagicMock()
|
||||
member.display_name = f"User{member_id}"
|
||||
return member
|
||||
|
||||
def get_role(role_id: int) -> MagicMock:
|
||||
role = MagicMock()
|
||||
role.name = f"Role{role_id}"
|
||||
return role
|
||||
|
||||
def get_channel(channel_id: int) -> MagicMock:
|
||||
channel = MagicMock()
|
||||
channel.name = f"channel{channel_id}"
|
||||
return channel
|
||||
|
||||
guild.get_member = get_member
|
||||
guild.get_role = get_role
|
||||
guild.get_channel = get_channel
|
||||
return guild
|
||||
441
backend/tests/unit/onyx/onyxbot/discord/test_api_client.py
Normal file
441
backend/tests/unit/onyx/onyxbot/discord/test_api_client.py
Normal file
@@ -0,0 +1,441 @@
|
||||
"""Unit tests for Discord bot API client.
|
||||
|
||||
Tests for OnyxAPIClient class functionality.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.discord.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.discord.exceptions import APIResponseError
|
||||
from onyx.onyxbot.discord.exceptions import APITimeoutError
|
||||
|
||||
|
||||
class MockAsyncContextManager:
|
||||
"""Helper class to create proper async context managers for testing."""
|
||||
|
||||
def __init__(
|
||||
self, return_value: Any = None, enter_side_effect: Exception | None = None
|
||||
) -> None:
|
||||
self.return_value = return_value
|
||||
self.enter_side_effect = enter_side_effect
|
||||
|
||||
async def __aenter__(self) -> Any:
|
||||
if self.enter_side_effect:
|
||||
raise self.enter_side_effect
|
||||
return self.return_value
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class TestClientLifecycle:
|
||||
"""Tests for API client lifecycle management."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_creates_session(self) -> None:
|
||||
"""initialize() creates aiohttp session."""
|
||||
client = OnyxAPIClient()
|
||||
assert client._session is None
|
||||
|
||||
with patch("aiohttp.ClientSession") as mock_session_class:
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
await client.initialize()
|
||||
|
||||
assert client._session is not None
|
||||
mock_session_class.assert_called_once()
|
||||
|
||||
def test_is_initialized_before_init(self) -> None:
|
||||
"""is_initialized returns False before initialize()."""
|
||||
client = OnyxAPIClient()
|
||||
assert client.is_initialized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_initialized_after_init(self) -> None:
|
||||
"""is_initialized returns True after initialize()."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
with patch("aiohttp.ClientSession"):
|
||||
await client.initialize()
|
||||
|
||||
assert client.is_initialized is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_closes_session(self) -> None:
|
||||
"""close() closes session and resets is_initialized."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await client.initialize()
|
||||
assert client.is_initialized is True
|
||||
|
||||
await client.close()
|
||||
|
||||
assert client.is_initialized is False
|
||||
mock_session.close.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_not_initialized(self) -> None:
|
||||
"""send_chat_message() before initialize() raises APIConnectionError."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
with pytest.raises(APIConnectionError) as exc_info:
|
||||
await client.send_chat_message("test", "api_key")
|
||||
|
||||
assert "not initialized" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestSendChatMessage:
|
||||
"""Tests for send_chat_message functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_success(self) -> None:
|
||||
"""Valid request returns ChatFullResponse."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
response_data = {
|
||||
"answer": "Test response",
|
||||
"citations": [],
|
||||
"error_msg": None,
|
||||
}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=response_data)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with patch.object(
|
||||
ChatFullResponse,
|
||||
"model_validate",
|
||||
return_value=MagicMock(answer="Test response", error_msg=None),
|
||||
):
|
||||
result = await client.send_chat_message("Hello", "api_key_123")
|
||||
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_with_persona(self) -> None:
|
||||
"""persona_id is passed to API."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
response_data = {"answer": "Response", "citations": [], "error_msg": None}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=response_data)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
mock_session.post = mock_post
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with patch.object(
|
||||
ChatFullResponse,
|
||||
"model_validate",
|
||||
return_value=MagicMock(answer="Response", error_msg=None),
|
||||
):
|
||||
await client.send_chat_message("Hello", "api_key", persona_id=5)
|
||||
|
||||
# Verify persona was included in request
|
||||
call_args = mock_post.call_args
|
||||
json_data = call_args.kwargs.get("json") or call_args[1].get("json")
|
||||
assert json_data is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_401_error(self) -> None:
|
||||
"""Invalid API key returns APIResponseError with 401."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 401
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(APIResponseError) as exc_info:
|
||||
await client.send_chat_message("Hello", "bad_key")
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_403_error(self) -> None:
|
||||
"""Persona not accessible returns APIResponseError with 403."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 403
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(APIResponseError) as exc_info:
|
||||
await client.send_chat_message("Hello", "api_key", persona_id=999)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_timeout(self) -> None:
|
||||
"""Request timeout raises APITimeoutError."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(
|
||||
enter_side_effect=TimeoutError("Timeout")
|
||||
)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(APITimeoutError):
|
||||
await client.send_chat_message("Hello", "api_key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_connection_error(self) -> None:
|
||||
"""Network failure raises APIConnectionError."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(
|
||||
enter_side_effect=aiohttp.ClientConnectorError(
|
||||
MagicMock(), OSError("Connection refused")
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(APIConnectionError):
|
||||
await client.send_chat_message("Hello", "api_key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_server_error(self) -> None:
|
||||
"""500 response raises APIResponseError with 500."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 500
|
||||
mock_response.text = AsyncMock(return_value="Internal Server Error")
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(APIResponseError) as exc_info:
|
||||
await client.send_chat_message("Hello", "api_key")
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
|
||||
|
||||
class TestHealthCheck:
|
||||
"""Tests for health_check functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_success(self) -> None:
|
||||
"""Server healthy returns True."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
result = await client.health_check()
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_failure(self) -> None:
|
||||
"""Server unhealthy returns False."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 503
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
result = await client.health_check()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_timeout(self) -> None:
|
||||
"""Request times out returns False."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get = MagicMock(
|
||||
return_value=MockAsyncContextManager(
|
||||
enter_side_effect=TimeoutError("Timeout")
|
||||
)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
result = await client.health_check()
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_not_initialized(self) -> None:
|
||||
"""Health check before initialize returns False."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
result = await client.health_check()
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestResponseParsing:
|
||||
"""Tests for API response parsing."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_malformed_json(self) -> None:
|
||||
"""API returns invalid JSON raises exception."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(side_effect=ValueError("Invalid JSON"))
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await client.send_chat_message("Hello", "api_key")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_with_error_msg(self) -> None:
|
||||
"""200 status but error_msg present - warning logged, response returned."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
response_data = {
|
||||
"answer": "Partial response",
|
||||
"citations": [],
|
||||
"error_msg": "Some warning",
|
||||
}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=response_data)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.answer = "Partial response"
|
||||
mock_result.error_msg = "Some warning"
|
||||
|
||||
with patch.object(ChatFullResponse, "model_validate", return_value=mock_result):
|
||||
result = await client.send_chat_message("Hello", "api_key")
|
||||
|
||||
# Should still return response
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_empty_answer(self) -> None:
|
||||
"""answer field is empty string - handled gracefully."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
response_data = {
|
||||
"answer": "",
|
||||
"citations": [],
|
||||
"error_msg": None,
|
||||
}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value=response_data)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.post = MagicMock(
|
||||
return_value=MockAsyncContextManager(return_value=mock_response)
|
||||
)
|
||||
|
||||
client._session = mock_session
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.answer = ""
|
||||
mock_result.error_msg = None
|
||||
|
||||
with patch.object(ChatFullResponse, "model_validate", return_value=mock_result):
|
||||
result = await client.send_chat_message("Hello", "api_key")
|
||||
|
||||
# Should return response even with empty answer
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestClientConfiguration:
|
||||
"""Tests for client configuration."""
|
||||
|
||||
def test_default_timeout(self) -> None:
|
||||
"""Client uses API_REQUEST_TIMEOUT by default."""
|
||||
client = OnyxAPIClient()
|
||||
assert client._timeout == API_REQUEST_TIMEOUT
|
||||
|
||||
def test_custom_timeout(self) -> None:
|
||||
"""Client accepts custom timeout."""
|
||||
client = OnyxAPIClient(timeout=60)
|
||||
assert client._timeout == 60
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_initialize_warning(self) -> None:
|
||||
"""Calling initialize() twice logs warning but doesn't error."""
|
||||
client = OnyxAPIClient()
|
||||
|
||||
with patch("aiohttp.ClientSession") as mock_session_class:
|
||||
mock_session = MagicMock()
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
await client.initialize()
|
||||
# Second call should be safe
|
||||
await client.initialize()
|
||||
|
||||
# Should only create one session
|
||||
assert mock_session_class.call_count == 1
|
||||
520
backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py
Normal file
520
backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py
Normal file
@@ -0,0 +1,520 @@
|
||||
"""Unit tests for Discord bot cache manager.
|
||||
|
||||
Tests for DiscordCacheManager class functionality.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
|
||||
|
||||
class TestCacheInitialization:
|
||||
"""Tests for cache initialization."""
|
||||
|
||||
def test_cache_starts_empty(self) -> None:
|
||||
"""New cache manager has empty caches."""
|
||||
cache = DiscordCacheManager()
|
||||
assert cache._guild_tenants == {}
|
||||
assert cache._api_keys == {}
|
||||
assert cache.is_initialized is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refresh_all_loads_guilds(self) -> None:
|
||||
"""refresh_all() loads all active guilds."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config1 = MagicMock()
|
||||
mock_config1.guild_id = 111111
|
||||
mock_config1.enabled = True
|
||||
|
||||
mock_config2 = MagicMock()
|
||||
mock_config2.guild_id = 222222
|
||||
mock_config2.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config1, mock_config2],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
return_value="test_api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
assert cache.is_initialized is True
|
||||
assert 111111 in cache._guild_tenants
|
||||
assert 222222 in cache._guild_tenants
|
||||
assert cache._guild_tenants[111111] == "tenant1"
|
||||
assert cache._guild_tenants[222222] == "tenant1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refresh_provisions_api_key(self) -> None:
|
||||
"""Refresh for tenant without key creates API key."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
return_value="new_api_key",
|
||||
) as mock_provision,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
assert cache._api_keys.get("tenant1") == "new_api_key"
|
||||
mock_provision.assert_called()
|
||||
|
||||
|
||||
class TestCacheLookups:
|
||||
"""Tests for cache lookup operations."""
|
||||
|
||||
def test_get_tenant_returns_correct(self) -> None:
|
||||
"""Lookup registered guild returns correct tenant ID."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[123456] = "tenant1"
|
||||
|
||||
result = cache.get_tenant(123456)
|
||||
assert result == "tenant1"
|
||||
|
||||
def test_get_tenant_returns_none_unknown(self) -> None:
|
||||
"""Lookup unregistered guild returns None."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
result = cache.get_tenant(999999)
|
||||
assert result is None
|
||||
|
||||
def test_get_api_key_returns_correct(self) -> None:
|
||||
"""Lookup tenant's API key returns valid key."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._api_keys["tenant1"] = "api_key_123"
|
||||
|
||||
result = cache.get_api_key("tenant1")
|
||||
assert result == "api_key_123"
|
||||
|
||||
def test_get_api_key_returns_none_unknown(self) -> None:
|
||||
"""Lookup unknown tenant returns None."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
result = cache.get_api_key("unknown_tenant")
|
||||
assert result is None
|
||||
|
||||
def test_get_all_guild_ids(self) -> None:
|
||||
"""After loading returns all cached guild IDs."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants = {111: "t1", 222: "t2", 333: "t1"}
|
||||
|
||||
result = cache.get_all_guild_ids()
|
||||
assert set(result) == {111, 222, 333}
|
||||
|
||||
|
||||
class TestCacheUpdates:
|
||||
"""Tests for cache update operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_guild_adds_new(self) -> None:
|
||||
"""refresh_guild() for new guild adds it to cache."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
return_value="api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_guild(111111, "tenant1")
|
||||
|
||||
assert cache.get_tenant(111111) == "tenant1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_guild_verifies_active(self) -> None:
|
||||
"""refresh_guild() for disabled guild doesn't add it."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = False # Disabled!
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_guild(111111, "tenant1")
|
||||
|
||||
# Should not be added because it's disabled
|
||||
assert cache.get_tenant(111111) is None
|
||||
|
||||
def test_remove_guild(self) -> None:
|
||||
"""remove_guild() removes guild from cache."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[111111] = "tenant1"
|
||||
|
||||
cache.remove_guild(111111)
|
||||
|
||||
assert cache.get_tenant(111111) is None
|
||||
|
||||
def test_clear_removes_all(self) -> None:
|
||||
"""clear() empties all caches."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants = {111: "t1", 222: "t2"}
|
||||
cache._api_keys = {"t1": "key1", "t2": "key2"}
|
||||
cache._initialized = True
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert cache._guild_tenants == {}
|
||||
assert cache._api_keys == {}
|
||||
assert cache.is_initialized is False
|
||||
|
||||
|
||||
class TestThreadSafety:
|
||||
"""Tests for thread/async safety."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_refresh_no_race(self) -> None:
|
||||
"""Multiple concurrent refresh_all() calls don't corrupt data."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = True
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def slow_refresh() -> tuple[list[int], str]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# Simulate slow operation
|
||||
await asyncio.sleep(0.01)
|
||||
return ([111111], "api_key")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch.object(cache, "_load_tenant_data", side_effect=slow_refresh),
|
||||
):
|
||||
# Run multiple concurrent refreshes
|
||||
await asyncio.gather(
|
||||
cache.refresh_all(),
|
||||
cache.refresh_all(),
|
||||
cache.refresh_all(),
|
||||
)
|
||||
|
||||
# Each refresh should complete without error
|
||||
assert cache.is_initialized is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_read_write(self) -> None:
|
||||
"""Read during refresh doesn't cause exceptions."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._guild_tenants[111111] = "tenant1"
|
||||
|
||||
async def read_loop() -> None:
|
||||
for _ in range(10):
|
||||
cache.get_tenant(111111)
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
async def write_loop() -> None:
|
||||
for i in range(10):
|
||||
cache._guild_tenants[200000 + i] = f"tenant{i}"
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
# Should not raise any exceptions
|
||||
await asyncio.gather(read_loop(), write_loop())
|
||||
|
||||
|
||||
class TestAPIKeyProvisioning:
|
||||
"""Tests for API key provisioning via cache refresh."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_created_on_first_refresh(self) -> None:
|
||||
"""Cache refresh with no existing key creates new API key."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
return_value="new_api_key_123",
|
||||
) as mock_create,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
mock_create.assert_called_once()
|
||||
assert cache.get_api_key("tenant1") == "new_api_key_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_cached_after_creation(self) -> None:
|
||||
"""Subsequent lookups after creation use cached key."""
|
||||
cache = DiscordCacheManager()
|
||||
cache._api_keys["tenant1"] = "cached_key"
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
) as mock_create,
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
# Should NOT call create because key is already cached
|
||||
mock_create.assert_not_called()
|
||||
# Cached key should be preserved after refresh
|
||||
assert cache.get_api_key("tenant1") == "cached_key"
|
||||
|
||||
|
||||
class TestGatedTenantHandling:
|
||||
"""Tests for gated tenant filtering."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_skips_gated_tenants(self) -> None:
|
||||
"""Gated tenant's guilds are not loaded."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
# tenant2 is gated
|
||||
gated_tenants = {"tenant2"}
|
||||
|
||||
mock_config_t1 = MagicMock()
|
||||
mock_config_t1.guild_id = 111111
|
||||
mock_config_t1.enabled = True
|
||||
|
||||
mock_config_t2 = MagicMock()
|
||||
mock_config_t2.guild_id = 222222
|
||||
mock_config_t2.enabled = True
|
||||
|
||||
def mock_get_configs(db: MagicMock) -> list[MagicMock]:
|
||||
# Track which tenant this was called for
|
||||
return [mock_config_t1] # Always return same for simplicity
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1", "tenant2"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: gated_tenants,
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
side_effect=mock_get_configs,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
return_value="api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
# Only tenant1 should be loaded (tenant2 is gated)
|
||||
assert "tenant1" in cache._api_keys and 111111 in cache._guild_tenants
|
||||
# tenant2's guilds should NOT be in cache
|
||||
assert "tenant2" not in cache._api_keys and 222222 not in cache._guild_tenants
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gated_check_calls_ee_function(self) -> None:
|
||||
"""Refresh all tenants calls fetch_ee_implementation_or_noop."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
) as mock_ee,
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
mock_ee.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ungated_tenant_included(self) -> None:
|
||||
"""Regular (ungated) tenant has guilds loaded normally."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.guild_id = 111111
|
||||
mock_config.enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(), # No gated tenants
|
||||
),
|
||||
patch("onyx.onyxbot.discord.cache.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_guild_configs",
|
||||
return_value=[mock_config],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_or_create_discord_service_api_key",
|
||||
return_value="api_key",
|
||||
),
|
||||
):
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
await cache.refresh_all()
|
||||
|
||||
assert cache.get_tenant(111111) == "tenant1"
|
||||
|
||||
|
||||
class TestCacheErrorHandling:
|
||||
"""Tests for error handling in cache operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_all_handles_tenant_error(self) -> None:
|
||||
"""Error loading one tenant doesn't stop others."""
|
||||
cache = DiscordCacheManager()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_load(tenant_id: str) -> tuple[list[int], str]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if tenant_id == "tenant1":
|
||||
raise Exception("Tenant 1 error")
|
||||
return ([222222], "api_key")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.get_all_tenant_ids",
|
||||
return_value=["tenant1", "tenant2"],
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.cache.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda: set(),
|
||||
),
|
||||
patch.object(cache, "_load_tenant_data", side_effect=mock_load),
|
||||
):
|
||||
await cache.refresh_all()
|
||||
|
||||
# Should still complete and load tenant2
|
||||
assert call_count == 2 # Both tenants attempted
|
||||
assert cache.get_tenant(222222) == "tenant2"
|
||||
645
backend/tests/unit/onyx/onyxbot/discord/test_context_builders.py
Normal file
645
backend/tests/unit/onyx/onyxbot/discord/test_context_builders.py
Normal file
@@ -0,0 +1,645 @@
|
||||
"""Unit tests for Discord bot context builders.
|
||||
|
||||
Tests the thread and reply context building logic with mocked Discord API.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
|
||||
from onyx.onyxbot.discord.handle_message import _build_conversation_context
|
||||
from onyx.onyxbot.discord.handle_message import _build_reply_chain_context
|
||||
from onyx.onyxbot.discord.handle_message import _build_thread_context
|
||||
from onyx.onyxbot.discord.handle_message import _format_messages_as_context
|
||||
from onyx.onyxbot.discord.handle_message import format_message_content
|
||||
from tests.unit.onyx.onyxbot.discord.conftest import AsyncIteratorMock
|
||||
from tests.unit.onyx.onyxbot.discord.conftest import mock_message
|
||||
|
||||
|
||||
class TestThreadContextBuilder:
|
||||
"""Tests for _build_thread_context function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_basic(
|
||||
self, mock_thread_with_messages: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread with messages returns context in order."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999 # Current message ID
|
||||
msg.channel = mock_thread_with_messages
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "Conversation history" in result
|
||||
# Should contain message content
|
||||
assert "User msg" in result or "Bot response" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_max_limit(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread with 20 messages returns only MAX_CONTEXT_MESSAGES."""
|
||||
# Create 20 messages
|
||||
messages = [
|
||||
mock_message(content=f"Message {i}", message_id=i) for i in range(20)
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
limit = kwargs.get("limit", MAX_CONTEXT_MESSAGES)
|
||||
return AsyncIteratorMock(messages[:limit])
|
||||
|
||||
thread.history = history
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "")
|
||||
)
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
# Should only have MAX_CONTEXT_MESSAGES worth of content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_includes_starter(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread with starter message includes it at beginning."""
|
||||
starter = mock_message(
|
||||
content="This is the thread starter",
|
||||
message_id=666666,
|
||||
)
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(return_value=starter)
|
||||
|
||||
messages = [
|
||||
mock_message(content="Reply 1", message_id=1),
|
||||
mock_message(content="Reply 2", message_id=2),
|
||||
]
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "thread starter" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_filters_system_messages(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread with system messages only includes content messages."""
|
||||
messages = [
|
||||
mock_message(
|
||||
content="Normal message", message_type=discord.MessageType.default
|
||||
),
|
||||
mock_message(
|
||||
content="", message_type=discord.MessageType.pins_add
|
||||
), # System
|
||||
mock_message(
|
||||
content="Another normal", message_type=discord.MessageType.reply
|
||||
),
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "")
|
||||
)
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
# Should not include system message type
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_includes_bot_messages(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Bot messages in thread are included for context."""
|
||||
messages = [
|
||||
mock_message(content="User question", author_bot=False),
|
||||
mock_message(
|
||||
content="Bot response",
|
||||
author_bot=True,
|
||||
author_id=mock_bot_user.id,
|
||||
author_display_name="OnyxBot",
|
||||
),
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "")
|
||||
)
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "Bot response" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_empty_thread(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread with only system messages returns None."""
|
||||
messages = [
|
||||
mock_message(content="", message_type=discord.MessageType.pins_add),
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "")
|
||||
)
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
await _build_thread_context(msg, mock_bot_user)
|
||||
# Should return None for empty context
|
||||
# (depends on implementation - may return None or empty string)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_forum_channel(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread parent is ForumChannel - does NOT fetch starter message."""
|
||||
messages = [
|
||||
mock_message(content="Forum reply", message_id=1),
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.ForumChannel) # Forum!
|
||||
# Set up mock before calling function so we can verify it wasn't called
|
||||
thread.parent.fetch_message = AsyncMock()
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
# Should not try to fetch starter message for forum channels
|
||||
thread.parent.fetch_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_starter_fetch_fails(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Starter message fetch raises NotFound - continues without starter."""
|
||||
messages = [
|
||||
mock_message(content="Reply message", message_id=1),
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "Not found")
|
||||
)
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
# Should still return context without starter
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_thread_context_deduplicates_starter(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Starter also in recent history is not duplicated."""
|
||||
starter = mock_message(content="Thread starter", message_id=666666)
|
||||
|
||||
messages = [
|
||||
starter, # Starter in history
|
||||
mock_message(content="Reply", message_id=1),
|
||||
]
|
||||
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(return_value=starter)
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
|
||||
result = await _build_thread_context(msg, mock_bot_user)
|
||||
|
||||
# Should only have starter once
|
||||
if result:
|
||||
assert (
|
||||
result.count("Thread starter") <= 2
|
||||
) # At most once in formatted output
|
||||
|
||||
|
||||
class TestReplyChainContextBuilder:
|
||||
"""Tests for _build_reply_chain_context function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_single_reply(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Message replies to one message returns 1 message in chain."""
|
||||
parent = mock_message(content="Parent message", message_id=100)
|
||||
parent.reference = None
|
||||
|
||||
child = MagicMock(spec=discord.Message)
|
||||
child.id = 200
|
||||
child.reference = MagicMock()
|
||||
child.reference.message_id = 100
|
||||
child.channel = MagicMock()
|
||||
child.channel.fetch_message = AsyncMock(return_value=parent)
|
||||
child.channel.name = "general"
|
||||
|
||||
result = await _build_reply_chain_context(child, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "Parent message" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_deep_chain(self, mock_bot_user: MagicMock) -> None:
|
||||
"""A → B → C → D reply chain returns full chain in chronological order."""
|
||||
msg_d = mock_message(content="Message D", message_id=4)
|
||||
msg_d.reference = None
|
||||
|
||||
msg_c = mock_message(content="Message C", message_id=3)
|
||||
ref_c = MagicMock()
|
||||
ref_c.message_id = 4
|
||||
msg_c.reference = ref_c
|
||||
|
||||
msg_b = mock_message(content="Message B", message_id=2)
|
||||
ref_b = MagicMock()
|
||||
ref_b.message_id = 3
|
||||
msg_b.reference = ref_b
|
||||
|
||||
# Current message replying to B
|
||||
ref_a = MagicMock()
|
||||
ref_a.message_id = 2
|
||||
|
||||
msg_a = MagicMock(spec=discord.Message)
|
||||
msg_a.id = 1
|
||||
msg_a.reference = ref_a
|
||||
msg_a.channel = MagicMock()
|
||||
msg_a.channel.name = "general"
|
||||
|
||||
# Mock fetch to return the chain
|
||||
message_map = {2: msg_b, 3: msg_c, 4: msg_d}
|
||||
|
||||
async def fetch_message(msg_id: int) -> MagicMock:
|
||||
if msg_id in message_map:
|
||||
return message_map[msg_id]
|
||||
raise discord.NotFound(MagicMock(), "Not found")
|
||||
|
||||
msg_a.channel.fetch_message = AsyncMock(side_effect=fetch_message)
|
||||
|
||||
result = await _build_reply_chain_context(msg_a, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
# Should have all messages from the chain
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_max_depth(self, mock_bot_user: MagicMock) -> None:
|
||||
"""Chain depth > MAX_CONTEXT_MESSAGES stops at limit."""
|
||||
# Create a chain longer than MAX_CONTEXT_MESSAGES
|
||||
messages = {}
|
||||
for i in range(MAX_CONTEXT_MESSAGES + 5, 0, -1):
|
||||
msg = mock_message(content=f"Message {i}", message_id=i)
|
||||
if i < MAX_CONTEXT_MESSAGES + 5:
|
||||
ref = MagicMock()
|
||||
ref.message_id = i + 1
|
||||
msg.reference = ref
|
||||
else:
|
||||
msg.reference = None
|
||||
messages[i] = msg
|
||||
|
||||
# Start from message 1
|
||||
start = MagicMock(spec=discord.Message)
|
||||
start.id = 0
|
||||
start.reference = MagicMock()
|
||||
start.reference.message_id = 1
|
||||
start.channel = MagicMock()
|
||||
start.channel.name = "general"
|
||||
|
||||
async def fetch_message(msg_id: int) -> MagicMock:
|
||||
if msg_id in messages:
|
||||
return messages[msg_id]
|
||||
raise discord.NotFound(MagicMock(), "Not found")
|
||||
|
||||
start.channel.fetch_message = AsyncMock(side_effect=fetch_message)
|
||||
|
||||
result = await _build_reply_chain_context(start, mock_bot_user)
|
||||
|
||||
# Should have at most MAX_CONTEXT_MESSAGES
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_no_reply(self, mock_bot_user: MagicMock) -> None:
|
||||
"""Message is not a reply returns None."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = None
|
||||
|
||||
result = await _build_reply_chain_context(msg, mock_bot_user)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_deleted_message(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Reply to deleted message handles gracefully with partial chain."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 200
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 100
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "Not found")
|
||||
)
|
||||
msg.channel.name = "general"
|
||||
|
||||
await _build_reply_chain_context(msg, mock_bot_user)
|
||||
# Should handle gracefully - may return None or partial context
|
||||
# Either is acceptable
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_missing_reference_data(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""message.reference.message_id is None returns None."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = None
|
||||
|
||||
result = await _build_reply_chain_context(msg, mock_bot_user)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_reply_chain_http_exception(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""discord.HTTPException on fetch stops chain."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 200
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 100
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.fetch_message = AsyncMock(
|
||||
side_effect=discord.HTTPException(MagicMock(), "HTTP error")
|
||||
)
|
||||
msg.channel.name = "general"
|
||||
|
||||
await _build_reply_chain_context(msg, mock_bot_user)
|
||||
# Should handle gracefully
|
||||
|
||||
|
||||
class TestCombinedContext:
|
||||
"""Tests for combined thread + reply context."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_context_thread_with_reply(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Reply inside thread includes both contexts."""
|
||||
# Create a thread with messages
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "")
|
||||
)
|
||||
|
||||
# Thread history
|
||||
thread_messages = [
|
||||
mock_message(content="Thread msg 1", message_id=1),
|
||||
mock_message(content="Thread msg 2", message_id=2),
|
||||
]
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(thread_messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
# Message is a reply to another message in the thread
|
||||
parent_msg = mock_message(content="Parent message", message_id=2)
|
||||
parent_msg.reference = None
|
||||
|
||||
ref = MagicMock()
|
||||
ref.message_id = 2
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
msg.reference = ref
|
||||
msg.channel.fetch_message = AsyncMock(return_value=parent_msg)
|
||||
msg.channel.name = "test-thread"
|
||||
|
||||
result = await _build_conversation_context(msg, mock_bot_user)
|
||||
|
||||
# Should have context from the thread
|
||||
assert result is not None
|
||||
assert "Conversation history" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_conversation_context_routes_to_thread(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Message in thread routes to _build_thread_context."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "")
|
||||
)
|
||||
|
||||
messages = [mock_message(content="Thread msg")]
|
||||
|
||||
def history(**kwargs: Any) -> AsyncIteratorMock:
|
||||
return AsyncIteratorMock(messages)
|
||||
|
||||
thread.history = history
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 999
|
||||
msg.channel = thread
|
||||
msg.reference = None
|
||||
|
||||
result = await _build_conversation_context(msg, mock_bot_user)
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_conversation_context_routes_to_reply(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Message with reference routes to _build_reply_chain_context."""
|
||||
parent = mock_message(content="Parent", message_id=100)
|
||||
parent.reference = None
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.id = 200
|
||||
msg.channel = MagicMock(spec=discord.TextChannel) # Not a thread
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 100
|
||||
msg.channel.fetch_message = AsyncMock(return_value=parent)
|
||||
msg.channel.name = "general"
|
||||
|
||||
result = await _build_conversation_context(msg, mock_bot_user)
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestContextFormatting:
|
||||
"""Tests for context formatting."""
|
||||
|
||||
def test_format_message_content_mentions(self) -> None:
|
||||
"""Messages with <@123> mentions are converted to @username."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.content = "Hello <@123456789> how are you?"
|
||||
|
||||
user = MagicMock()
|
||||
user.id = 123456789
|
||||
user.display_name = "TestUser"
|
||||
msg.mentions = [user]
|
||||
msg.role_mentions = []
|
||||
msg.channel_mentions = []
|
||||
|
||||
result = format_message_content(msg)
|
||||
assert "@TestUser" in result
|
||||
assert "<@123456789>" not in result
|
||||
|
||||
def test_format_message_content_roles(self) -> None:
|
||||
"""Messages with <@&456> roles are converted to @rolename."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.content = "Attention <@&456789> members"
|
||||
|
||||
role = MagicMock()
|
||||
role.id = 456789
|
||||
role.name = "Moderators"
|
||||
msg.mentions = []
|
||||
msg.role_mentions = [role]
|
||||
msg.channel_mentions = []
|
||||
|
||||
result = format_message_content(msg)
|
||||
assert "@Moderators" in result
|
||||
assert "<@&456789>" not in result
|
||||
|
||||
def test_format_message_content_channels(self) -> None:
|
||||
"""Messages with <#789> channels are converted to #channelname."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.content = "Check out <#789012>"
|
||||
|
||||
channel = MagicMock()
|
||||
channel.id = 789012
|
||||
channel.name = "announcements"
|
||||
msg.mentions = []
|
||||
msg.role_mentions = []
|
||||
msg.channel_mentions = [channel]
|
||||
|
||||
result = format_message_content(msg)
|
||||
assert "#announcements" in result
|
||||
assert "<#789012>" not in result
|
||||
|
||||
def test_context_format_output(self, mock_bot_user: MagicMock) -> None:
|
||||
"""Build full context has expected format."""
|
||||
messages: list[Any] = [
|
||||
mock_message(content="Hello bot", author_bot=False),
|
||||
]
|
||||
messages[0].type = discord.MessageType.default
|
||||
|
||||
result = _format_messages_as_context(messages, mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "Conversation history" in result
|
||||
assert "---" in result
|
||||
|
||||
def test_context_format_with_username(self, mock_bot_user: MagicMock) -> None:
|
||||
"""Messages from users include @username: prefix."""
|
||||
msg = mock_message(content="User message", author_bot=False)
|
||||
msg.author.display_name = "TestUser"
|
||||
msg.type = discord.MessageType.default
|
||||
|
||||
result = _format_messages_as_context([msg], mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "@TestUser:" in result
|
||||
|
||||
def test_context_format_bot_marker(self, mock_bot_user: MagicMock) -> None:
|
||||
"""Bot messages in context are marked as OnyxBot:."""
|
||||
msg = mock_message(
|
||||
content="Bot response",
|
||||
author_bot=True,
|
||||
author_id=mock_bot_user.id,
|
||||
)
|
||||
msg.type = discord.MessageType.default
|
||||
|
||||
result = _format_messages_as_context([msg], mock_bot_user)
|
||||
|
||||
assert result is not None
|
||||
assert "OnyxBot:" in result
|
||||
157
backend/tests/unit/onyx/onyxbot/discord/test_discord_utils.py
Normal file
157
backend/tests/unit/onyx/onyxbot/discord/test_discord_utils.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Unit tests for Discord bot utilities.
|
||||
|
||||
Tests for:
|
||||
- Token management (get_bot_token)
|
||||
- Registration key parsing (parse_discord_registration_key, generate_discord_registration_key)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.onyxbot.discord.utils import get_bot_token
|
||||
from onyx.server.manage.discord_bot.utils import generate_discord_registration_key
|
||||
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
|
||||
from onyx.server.manage.discord_bot.utils import REGISTRATION_KEY_PREFIX
|
||||
|
||||
|
||||
class TestGetBotToken:
|
||||
"""Tests for get_bot_token function."""
|
||||
|
||||
def test_get_token_from_env(self) -> None:
|
||||
"""When env var is set, returns env var."""
|
||||
with patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", "env_token_123"):
|
||||
result = get_bot_token()
|
||||
assert result == "env_token_123"
|
||||
|
||||
def test_get_token_from_db(self) -> None:
|
||||
"""When no env var and DB config exists, returns DB token."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.bot_token = "db_token_456"
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", None),
|
||||
patch("onyx.onyxbot.discord.utils.AUTH_TYPE", "basic"), # Not CLOUD
|
||||
patch("onyx.onyxbot.discord.utils.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.utils.get_discord_bot_config",
|
||||
return_value=mock_config,
|
||||
),
|
||||
):
|
||||
mock_session.return_value.__enter__ = MagicMock()
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
result = get_bot_token()
|
||||
assert result == "db_token_456"
|
||||
|
||||
def test_get_token_none(self) -> None:
|
||||
"""When no env var and no DB config, returns None."""
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", None),
|
||||
patch("onyx.onyxbot.discord.utils.AUTH_TYPE", "basic"), # Not CLOUD
|
||||
patch("onyx.onyxbot.discord.utils.get_session_with_tenant") as mock_session,
|
||||
patch(
|
||||
"onyx.onyxbot.discord.utils.get_discord_bot_config",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
mock_session.return_value.__enter__ = MagicMock()
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
result = get_bot_token()
|
||||
assert result is None
|
||||
|
||||
def test_get_token_env_priority(self) -> None:
|
||||
"""When both env var and DB exist, env var takes priority."""
|
||||
mock_config = MagicMock()
|
||||
mock_config.bot_token = "db_token_456"
|
||||
|
||||
with (
|
||||
patch("onyx.onyxbot.discord.utils.DISCORD_BOT_TOKEN", "env_token_123"),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.utils.get_discord_bot_config",
|
||||
return_value=mock_config,
|
||||
),
|
||||
):
|
||||
result = get_bot_token()
|
||||
# Should return env var, not DB token
|
||||
assert result == "env_token_123"
|
||||
|
||||
|
||||
class TestParseRegistrationKey:
|
||||
"""Tests for parse_discord_registration_key function."""
|
||||
|
||||
def test_parse_registration_key_valid(self) -> None:
|
||||
"""Valid key format returns tenant_id."""
|
||||
key = "discord_tenant123.randomtoken"
|
||||
result = parse_discord_registration_key(key)
|
||||
assert result == "tenant123"
|
||||
|
||||
def test_parse_registration_key_invalid(self) -> None:
|
||||
"""Malformed key returns None."""
|
||||
result = parse_discord_registration_key("malformed_key")
|
||||
assert result is None
|
||||
|
||||
def test_parse_registration_key_missing_prefix(self) -> None:
|
||||
"""Key without 'discord_' prefix returns None."""
|
||||
key = "tenant123.randomtoken"
|
||||
result = parse_discord_registration_key(key)
|
||||
assert result is None
|
||||
|
||||
def test_parse_registration_key_missing_dot(self) -> None:
|
||||
"""Key without separator '.' returns None."""
|
||||
key = "discord_tenant123randomtoken"
|
||||
result = parse_discord_registration_key(key)
|
||||
assert result is None
|
||||
|
||||
def test_parse_registration_key_empty_token(self) -> None:
|
||||
"""Key with empty token part returns None."""
|
||||
# This test verifies behavior with empty token after dot
|
||||
key = "discord_tenant123."
|
||||
result = parse_discord_registration_key(key)
|
||||
# Current implementation allows empty token, but returns tenant
|
||||
# If this should be invalid, update the implementation
|
||||
assert result == "tenant123" or result is None
|
||||
|
||||
def test_parse_registration_key_url_encoded_tenant(self) -> None:
|
||||
"""Tenant ID with URL encoding is decoded correctly."""
|
||||
# URL encoded "my tenant" -> "my%20tenant"
|
||||
key = "discord_my%20tenant.randomtoken"
|
||||
result = parse_discord_registration_key(key)
|
||||
assert result == "my tenant"
|
||||
|
||||
def test_parse_registration_key_special_chars(self) -> None:
|
||||
"""Key with special characters in tenant ID."""
|
||||
# Tenant with slashes (URL encoded)
|
||||
key = "discord_tenant%2Fwith%2Fslashes.randomtoken"
|
||||
result = parse_discord_registration_key(key)
|
||||
assert result == "tenant/with/slashes"
|
||||
|
||||
|
||||
class TestGenerateRegistrationKey:
|
||||
"""Tests for generate_discord_registration_key function."""
|
||||
|
||||
def test_generate_registration_key(self) -> None:
|
||||
"""Generated key has correct format."""
|
||||
key = generate_discord_registration_key("tenant123")
|
||||
|
||||
assert key.startswith(REGISTRATION_KEY_PREFIX)
|
||||
assert "tenant123" in key
|
||||
assert "." in key
|
||||
|
||||
# Parse it back to verify round-trip
|
||||
parsed = parse_discord_registration_key(key)
|
||||
assert parsed == "tenant123"
|
||||
|
||||
def test_generate_registration_key_unique(self) -> None:
|
||||
"""Each generated key is unique."""
|
||||
keys = [generate_discord_registration_key("tenant123") for _ in range(10)]
|
||||
assert len(set(keys)) == 10 # All unique
|
||||
|
||||
def test_generate_registration_key_special_tenant(self) -> None:
|
||||
"""Key generation handles special characters in tenant ID."""
|
||||
key = generate_discord_registration_key("my tenant/id")
|
||||
|
||||
# Should be URL encoded
|
||||
assert "%20" in key or "%2F" in key
|
||||
|
||||
# Parse it back
|
||||
parsed = parse_discord_registration_key(key)
|
||||
assert parsed == "my tenant/id"
|
||||
316
backend/tests/unit/onyx/onyxbot/discord/test_message_utils.py
Normal file
316
backend/tests/unit/onyx/onyxbot/discord/test_message_utils.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""Unit tests for Discord bot message utilities.
|
||||
|
||||
Tests for:
|
||||
- Message splitting (_split_message)
|
||||
- Citation formatting (_append_citations)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
|
||||
from onyx.onyxbot.discord.handle_message import _append_citations
|
||||
from onyx.onyxbot.discord.handle_message import _split_message
|
||||
|
||||
|
||||
class TestSplitMessage:
|
||||
"""Tests for _split_message function."""
|
||||
|
||||
def test_split_message_under_limit(self) -> None:
|
||||
"""Message under 2000 chars returns single chunk."""
|
||||
content = "x" * 1999
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == content
|
||||
|
||||
def test_split_message_at_limit(self) -> None:
|
||||
"""Message exactly at 2000 chars returns single chunk."""
|
||||
content = "x" * MAX_MESSAGE_LENGTH
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == content
|
||||
|
||||
def test_split_message_over_limit(self) -> None:
|
||||
"""Message over 2000 chars splits into multiple chunks."""
|
||||
content = "x" * 2001
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 2
|
||||
# All chunks should be <= MAX_MESSAGE_LENGTH
|
||||
for chunk in chunks:
|
||||
assert len(chunk) <= MAX_MESSAGE_LENGTH
|
||||
|
||||
def test_split_at_double_newline(self) -> None:
|
||||
"""Prefers splitting at double newline."""
|
||||
# Create content with double newline near the end but before limit
|
||||
first_part = "x" * 1500
|
||||
second_part = "y" * 1000
|
||||
content = f"{first_part}\n\n{second_part}"
|
||||
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 2
|
||||
# First chunk should end with or right after the double newline
|
||||
assert chunks[0].endswith("\n\n") or first_part in chunks[0]
|
||||
|
||||
def test_split_at_single_newline(self) -> None:
|
||||
"""When no double newline, splits at single newline."""
|
||||
first_part = "x" * 1500
|
||||
second_part = "y" * 1000
|
||||
content = f"{first_part}\n{second_part}"
|
||||
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 2
|
||||
|
||||
def test_split_at_period_space(self) -> None:
|
||||
"""When no newlines, splits at '. ' (period + space)."""
|
||||
first_part = "x" * 1500
|
||||
second_part = "y" * 1000
|
||||
content = f"{first_part}. {second_part}"
|
||||
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 2
|
||||
# First chunk should include the period
|
||||
assert chunks[0].endswith(". ") or chunks[0].endswith(".")
|
||||
|
||||
def test_split_at_space(self) -> None:
|
||||
"""When no better breakpoints, splits at space."""
|
||||
first_part = "x" * 1500
|
||||
second_part = "y" * 1000
|
||||
content = f"{first_part} {second_part}"
|
||||
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 2
|
||||
|
||||
def test_split_no_breakpoint(self) -> None:
|
||||
"""Handles gracefully when no breakpoints available (hard split)."""
|
||||
# 2001 chars with no spaces or newlines
|
||||
content = "x" * 2001
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 2
|
||||
# Content should be preserved
|
||||
assert "".join(chunks) == content
|
||||
|
||||
def test_split_threshold_50_percent(self) -> None:
|
||||
"""Breakpoint at less than 50% of limit is skipped."""
|
||||
# Put a breakpoint early (at 40% = 800 chars)
|
||||
# and another late (at 80% = 1600 chars)
|
||||
early_part = "x" * 800
|
||||
middle_part = "m" * 800 # Total: 1600
|
||||
late_part = "y" * 600 # Total: 2200
|
||||
content = f"{early_part}\n\n{middle_part}\n\n{late_part}"
|
||||
|
||||
chunks = _split_message(content)
|
||||
# Should prefer the later breakpoint over the 40% one
|
||||
assert len(chunks) == 2
|
||||
# First chunk should be longer than 800 chars
|
||||
assert len(chunks[0]) > 800
|
||||
|
||||
def test_split_multiple_chunks(self) -> None:
|
||||
"""5000 char message splits into 3 chunks."""
|
||||
content = "x" * 5000
|
||||
chunks = _split_message(content)
|
||||
assert len(chunks) == 3
|
||||
# Each chunk should be <= MAX_MESSAGE_LENGTH
|
||||
for chunk in chunks:
|
||||
assert len(chunk) <= MAX_MESSAGE_LENGTH
|
||||
|
||||
def test_split_preserves_content(self) -> None:
|
||||
"""Concatenated chunks equal original content."""
|
||||
content = "Hello world! " * 200 # About 2600 chars
|
||||
chunks = _split_message(content)
|
||||
assert "".join(chunks) == content
|
||||
|
||||
def test_split_with_unicode(self) -> None:
|
||||
"""Handles unicode characters correctly."""
|
||||
# Mix of ASCII and unicode
|
||||
content = "Hello " + "🎉" * 500 + " World " + "x" * 1500
|
||||
chunks = _split_message(content)
|
||||
# Should not break in the middle of emoji
|
||||
assert "".join(chunks) == content
|
||||
|
||||
|
||||
class TestAppendCitations:
|
||||
"""Tests for _append_citations function."""
|
||||
|
||||
def _make_response(
|
||||
self,
|
||||
answer: str,
|
||||
citations: list[dict] | None = None,
|
||||
documents: list[dict] | None = None,
|
||||
) -> ChatFullResponse:
|
||||
"""Helper to create ChatFullResponse with citations."""
|
||||
response = MagicMock(spec=ChatFullResponse)
|
||||
response.answer = answer
|
||||
|
||||
if citations:
|
||||
citation_mocks = []
|
||||
for c in citations:
|
||||
cm = MagicMock()
|
||||
cm.citation_number = c.get("num", 1)
|
||||
cm.document_id = c.get("doc_id", "doc1")
|
||||
citation_mocks.append(cm)
|
||||
response.citation_info = citation_mocks
|
||||
else:
|
||||
response.citation_info = None
|
||||
|
||||
if documents:
|
||||
doc_mocks = []
|
||||
for d in documents:
|
||||
dm = MagicMock()
|
||||
dm.document_id = d.get("doc_id", "doc1")
|
||||
dm.semantic_identifier = d.get("name", "Source")
|
||||
dm.link = d.get("link")
|
||||
doc_mocks.append(dm)
|
||||
response.top_documents = doc_mocks
|
||||
else:
|
||||
response.top_documents = None
|
||||
|
||||
return response
|
||||
|
||||
def test_format_citations_empty_list(self) -> None:
|
||||
"""No citations returns answer unchanged."""
|
||||
response = self._make_response("Test answer")
|
||||
result = _append_citations("Test answer", response)
|
||||
assert result == "Test answer"
|
||||
assert "Sources:" not in result
|
||||
|
||||
def test_format_citations_single(self) -> None:
|
||||
"""Single citation is formatted correctly."""
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[{"num": 1, "doc_id": "doc1"}],
|
||||
documents=[
|
||||
{
|
||||
"doc_id": "doc1",
|
||||
"name": "Document One",
|
||||
"link": "https://example.com",
|
||||
}
|
||||
],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
assert "**Sources:**" in result
|
||||
assert "[Document One](<https://example.com>)" in result
|
||||
|
||||
def test_format_citations_multiple(self) -> None:
|
||||
"""Multiple citations are all formatted and numbered."""
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[
|
||||
{"num": 1, "doc_id": "doc1"},
|
||||
{"num": 2, "doc_id": "doc2"},
|
||||
{"num": 3, "doc_id": "doc3"},
|
||||
],
|
||||
documents=[
|
||||
{"doc_id": "doc1", "name": "Doc 1", "link": "https://example.com/1"},
|
||||
{"doc_id": "doc2", "name": "Doc 2", "link": "https://example.com/2"},
|
||||
{"doc_id": "doc3", "name": "Doc 3", "link": "https://example.com/3"},
|
||||
],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
assert "1. [Doc 1]" in result
|
||||
assert "2. [Doc 2]" in result
|
||||
assert "3. [Doc 3]" in result
|
||||
|
||||
def test_format_citations_max_five(self) -> None:
|
||||
"""Only first 5 citations are included."""
|
||||
citations = [{"num": i, "doc_id": f"doc{i}"} for i in range(1, 11)]
|
||||
documents = [
|
||||
{
|
||||
"doc_id": f"doc{i}",
|
||||
"name": f"Doc {i}",
|
||||
"link": f"https://example.com/{i}",
|
||||
}
|
||||
for i in range(1, 11)
|
||||
]
|
||||
response = self._make_response(
|
||||
"Test answer", citations=citations, documents=documents
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
|
||||
# Should have 5 citations
|
||||
assert "1. [Doc 1]" in result
|
||||
assert "5. [Doc 5]" in result
|
||||
# Should NOT have 6th citation
|
||||
assert "6. [Doc 6]" not in result
|
||||
|
||||
def test_format_citation_no_link(self) -> None:
|
||||
"""Citation without link formats as plain text (no markdown)."""
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[{"num": 1, "doc_id": "doc1"}],
|
||||
documents=[{"doc_id": "doc1", "name": "No Link Doc", "link": None}],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
assert "1. No Link Doc" in result
|
||||
# Should not have markdown link syntax
|
||||
assert "[No Link Doc](<" not in result
|
||||
|
||||
def test_format_citation_empty_name(self) -> None:
|
||||
"""Empty semantic_identifier defaults to 'Source'."""
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[{"num": 1, "doc_id": "doc1"}],
|
||||
documents=[{"doc_id": "doc1", "name": "", "link": "https://example.com"}],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
# Should use fallback "Source" name
|
||||
assert "[Source]" in result or "Source" in result
|
||||
|
||||
def test_format_citation_link_with_brackets(self) -> None:
|
||||
"""Link with special characters is wrapped with angle brackets."""
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[{"num": 1, "doc_id": "doc1"}],
|
||||
documents=[
|
||||
{
|
||||
"doc_id": "doc1",
|
||||
"name": "Special Doc",
|
||||
"link": "https://example.com/path?query=value&other=123",
|
||||
}
|
||||
],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
# Discord markdown uses <link> to prevent embed
|
||||
assert "(<https://example.com" in result
|
||||
|
||||
def test_format_citations_sorted_by_number(self) -> None:
|
||||
"""Citations are sorted by citation number."""
|
||||
# Add in reverse order
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[
|
||||
{"num": 3, "doc_id": "doc3"},
|
||||
{"num": 1, "doc_id": "doc1"},
|
||||
{"num": 2, "doc_id": "doc2"},
|
||||
],
|
||||
documents=[
|
||||
{"doc_id": "doc1", "name": "Doc 1", "link": "https://example.com/1"},
|
||||
{"doc_id": "doc2", "name": "Doc 2", "link": "https://example.com/2"},
|
||||
{"doc_id": "doc3", "name": "Doc 3", "link": "https://example.com/3"},
|
||||
],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
|
||||
# Find positions
|
||||
pos1 = result.find("1. [Doc 1]")
|
||||
pos2 = result.find("2. [Doc 2]")
|
||||
pos3 = result.find("3. [Doc 3]")
|
||||
|
||||
# Should be in order
|
||||
assert pos1 < pos2 < pos3
|
||||
|
||||
def test_format_citations_with_missing_document(self) -> None:
|
||||
"""Citation referencing non-existent document is skipped."""
|
||||
response = self._make_response(
|
||||
"Test answer",
|
||||
citations=[
|
||||
{"num": 1, "doc_id": "doc1"},
|
||||
{"num": 2, "doc_id": "doc_missing"}, # No matching document
|
||||
],
|
||||
documents=[
|
||||
{"doc_id": "doc1", "name": "Doc 1", "link": "https://example.com/1"},
|
||||
],
|
||||
)
|
||||
result = _append_citations("Test answer", response)
|
||||
assert "Doc 1" in result
|
||||
# Missing doc should not appear
|
||||
assert "doc_missing" not in result.lower()
|
||||
645
backend/tests/unit/onyx/onyxbot/discord/test_should_respond.py
Normal file
645
backend/tests/unit/onyx/onyxbot/discord/test_should_respond.py
Normal file
@@ -0,0 +1,645 @@
|
||||
"""Unit tests for Discord bot should_respond logic.
|
||||
|
||||
Tests the decision tree for when the bot should respond to messages.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import discord
|
||||
import pytest
|
||||
|
||||
from onyx.onyxbot.discord.handle_message import check_implicit_invocation
|
||||
from onyx.onyxbot.discord.handle_message import should_respond
|
||||
|
||||
|
||||
class TestBasicShouldRespond:
|
||||
"""Tests for basic should_respond decision logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_guild_disabled(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Guild config enabled=false returns False."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = False
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_guild_enabled(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Guild config enabled=true proceeds to channel check."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = False
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_channel_disabled(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Channel config enabled=false returns False."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = False
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_channel_enabled(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Channel config enabled=true proceeds to mention check."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 2
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = False
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
assert result.persona_id == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_channel_not_found(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""No channel config returns False (not whitelisted)."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=None, # No config
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_require_mention_true_no_mention(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""require_bot_invocation=true with no @mention returns False."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = True
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
# No bot mention
|
||||
mock_discord_message.mentions = []
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.check_implicit_invocation",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_require_mention_true_with_mention(
|
||||
self, mock_message_with_bot_mention: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""require_bot_invocation=true with @mention returns True."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = True
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_message_with_bot_mention, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_require_mention_false_no_mention(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""require_bot_invocation=false with no @mention returns True."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = False
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
|
||||
|
||||
class TestImplicitShouldRespond:
|
||||
"""Tests for implicit invocation (no @mention required in certain contexts)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_reply_to_bot_message(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""User replies to a bot message returns True."""
|
||||
# Create a message that replies to the bot
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 12345
|
||||
|
||||
# Mock the referenced message as a bot message
|
||||
referenced_msg = MagicMock()
|
||||
referenced_msg.author.id = mock_bot_user.id
|
||||
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.fetch_message = AsyncMock(return_value=referenced_msg)
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_reply_to_user_message(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""User replies to another user's message returns False."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 12345
|
||||
|
||||
# Mock the referenced message as a user message
|
||||
referenced_msg = MagicMock()
|
||||
referenced_msg.author.id = 999999 # Different from bot
|
||||
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.fetch_message = AsyncMock(return_value=referenced_msg)
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_in_bot_owned_thread(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Message in thread owned by bot returns True."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.owner_id = mock_bot_user.id # Bot owns the thread
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = None
|
||||
msg.channel = thread
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_in_user_owned_thread(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Message in thread owned by user returns False."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.owner_id = 999999 # User owns the thread
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = None
|
||||
msg.channel = thread
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_reply_in_bot_thread(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Reply to user in bot-owned thread returns True (thread context)."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.owner_id = mock_bot_user.id
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
# User replying to another user in bot's thread
|
||||
referenced_msg = MagicMock()
|
||||
referenced_msg.author.id = 888888 # Another user
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 12345
|
||||
msg.channel = thread
|
||||
msg.channel.fetch_message = AsyncMock(return_value=referenced_msg)
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
# Should return True because it's in bot's thread
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_thread_from_bot_message(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread created from bot message (non-forum) returns True."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 777777
|
||||
thread.owner_id = 999999 # User owns thread but...
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
# The starter message is from the bot
|
||||
starter_msg = MagicMock()
|
||||
starter_msg.author.id = mock_bot_user.id
|
||||
thread.parent.fetch_message = AsyncMock(return_value=starter_msg)
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = None
|
||||
msg.channel = thread
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_forum_channel_excluded(
|
||||
self, mock_bot_user: MagicMock, mock_thread_forum_parent: MagicMock
|
||||
) -> None:
|
||||
"""Thread parent is ForumChannel - does NOT check starter message."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = None
|
||||
msg.channel = mock_thread_forum_parent
|
||||
mock_thread_forum_parent.owner_id = 999999 # Not bot
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
# Should be False - forum threads don't use starter message check
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_combined_with_mention(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Has @mention AND is implicit - should return True (either works)."""
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.owner_id = mock_bot_user.id
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = None
|
||||
msg.channel = thread
|
||||
msg.mentions = [mock_bot_user]
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_reference_fetch_fails(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""discord.NotFound when fetching reply reference returns False."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 12345
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.fetch_message = AsyncMock(
|
||||
side_effect=discord.NotFound(MagicMock(), "Not found")
|
||||
)
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implicit_respond_http_exception(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""discord.HTTPException during check returns False."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.reference = MagicMock()
|
||||
msg.reference.message_id = 12345
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.fetch_message = AsyncMock(
|
||||
side_effect=discord.HTTPException(MagicMock(), "HTTP error")
|
||||
)
|
||||
|
||||
result = await check_implicit_invocation(msg, mock_bot_user)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestThreadOnlyMode:
|
||||
"""Tests for thread_only_mode behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_only_mode_message_in_thread(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""thread_only_mode=true, message in thread returns True."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = False
|
||||
mock_channel_config.thread_only_mode = True
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
# Create thread message
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.id = 111111111
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.guild = MagicMock()
|
||||
msg.guild.id = 987654321
|
||||
msg.channel = thread
|
||||
msg.mentions = []
|
||||
msg.reference = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(msg, "tenant1", mock_bot_user)
|
||||
|
||||
assert result.should_respond is True
|
||||
assert result.thread_only_mode is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_only_mode_false_message_in_channel(
|
||||
self, mock_discord_message: MagicMock, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""thread_only_mode=false, message in channel returns True."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = False
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(
|
||||
mock_discord_message, "tenant1", mock_bot_user
|
||||
)
|
||||
|
||||
assert result.should_respond is True
|
||||
assert result.thread_only_mode is False
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Edge case tests for should_respond."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_no_guild(self, mock_bot_user: MagicMock) -> None:
|
||||
"""Message without guild (DM) returns False."""
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.guild = None
|
||||
|
||||
result = await should_respond(msg, "tenant1", mock_bot_user)
|
||||
assert result.should_respond is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_respond_thread_uses_parent_channel_config(
|
||||
self, mock_bot_user: MagicMock
|
||||
) -> None:
|
||||
"""Thread under channel uses parent channel's config."""
|
||||
mock_guild_config = MagicMock()
|
||||
mock_guild_config.enabled = True
|
||||
mock_guild_config.default_persona_id = 1
|
||||
|
||||
mock_channel_config = MagicMock()
|
||||
mock_channel_config.enabled = True
|
||||
mock_channel_config.require_bot_invocation = False
|
||||
mock_channel_config.thread_only_mode = False
|
||||
mock_channel_config.persona_override_id = 5 # Specific persona
|
||||
|
||||
# Create thread message
|
||||
thread = MagicMock(spec=discord.Thread)
|
||||
thread.id = 666666
|
||||
thread.parent = MagicMock(spec=discord.TextChannel)
|
||||
thread.parent.id = 111111111 # Parent channel ID
|
||||
|
||||
msg = MagicMock(spec=discord.Message)
|
||||
msg.guild = MagicMock()
|
||||
msg.guild.id = 987654321
|
||||
msg.channel = thread
|
||||
msg.mentions = []
|
||||
msg.reference = None
|
||||
|
||||
with patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_session_with_tenant"
|
||||
) as mock_session:
|
||||
mock_db = MagicMock()
|
||||
mock_session.return_value.__enter__ = MagicMock(return_value=mock_db)
|
||||
mock_session.return_value.__exit__ = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_guild_config_by_discord_id",
|
||||
return_value=mock_guild_config,
|
||||
),
|
||||
patch(
|
||||
"onyx.onyxbot.discord.handle_message.get_channel_config_by_discord_ids",
|
||||
return_value=mock_channel_config,
|
||||
),
|
||||
):
|
||||
result = await should_respond(msg, "tenant1", mock_bot_user)
|
||||
|
||||
assert result.should_respond is True
|
||||
# Should use parent's persona override
|
||||
assert result.persona_id == 5
|
||||
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
106
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
|
||||
message = (
|
||||
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert (
|
||||
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
|
||||
== normalized
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
|
||||
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert message == normalized
|
||||
|
||||
|
||||
def test_normalize_citation_link_handles_multiple_links() -> None:
|
||||
message = (
|
||||
"[[1]](https://example.com/(USA)%20Guide.pdf) "
|
||||
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
|
||||
)
|
||||
|
||||
normalized = _normalize_link_destinations(message)
|
||||
|
||||
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
|
||||
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
|
||||
message = (
|
||||
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
|
||||
)
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
rendered = decode_escapes(remove_slack_text_interactions(formatted))
|
||||
|
||||
assert (
|
||||
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
|
||||
in rendered
|
||||
)
|
||||
assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
|
||||
message = "Visit <https://example.com/page|Example Page> for details."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "<https://example.com/page|Example Page>" in formatted
|
||||
assert "<" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
|
||||
message = "```\n<https://example.com|click>\n```"
|
||||
|
||||
converted = _convert_slack_links_to_markdown(message)
|
||||
|
||||
assert "<https://example.com|click>" in converted
|
||||
|
||||
|
||||
def test_html_tags_stripped_outside_code_blocks() -> None:
|
||||
message = "Hello<br/>world ```<div>code</div>``` after"
|
||||
|
||||
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
|
||||
assert "<br" not in sanitized
|
||||
assert "<div>code</div>" in sanitized
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
|
||||
message = "Paragraph one.\n\nParagraph two."
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "Paragraph one.\n\nParagraph two." == formatted
|
||||
|
||||
|
||||
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
|
||||
message = "```python\nprint('hi')\n```"
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert formatted.endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
|
||||
message = 'She said "hello" & goodbye.'
|
||||
|
||||
formatted = format_slack_message(message)
|
||||
|
||||
assert "&" in formatted
|
||||
assert """ not in formatted
|
||||
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
57
backend/tests/unit/onyx/utils/test_telemetry.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.utils import telemetry as telemetry_utils
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_noop_when_not_multi_tenant(monkeypatch: Any) -> None:
|
||||
fetch_impl = Mock()
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", False)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_not_called()
|
||||
|
||||
|
||||
def test_mt_cloud_telemetry_calls_event_telemetry_when_multi_tenant(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
event_telemetry = Mock()
|
||||
fetch_impl = Mock(return_value=event_telemetry)
|
||||
monkeypatch.setattr(
|
||||
telemetry_utils,
|
||||
"fetch_versioned_implementation_with_fallback",
|
||||
fetch_impl,
|
||||
)
|
||||
# mt_cloud_telemetry reads the module-local imported symbol, so patch this path.
|
||||
monkeypatch.setattr("onyx.utils.telemetry.MULTI_TENANT", True)
|
||||
|
||||
telemetry_utils.mt_cloud_telemetry(
|
||||
tenant_id="tenant-1",
|
||||
distinct_id="user@example.com",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={"origin": "web"},
|
||||
)
|
||||
|
||||
fetch_impl.assert_called_once_with(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=telemetry_utils.noop_fallback,
|
||||
)
|
||||
event_telemetry.assert_called_once_with(
|
||||
"user@example.com",
|
||||
MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
{"origin": "web", "tenant_id": "tenant-1"},
|
||||
)
|
||||
@@ -221,6 +221,13 @@ services:
|
||||
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
||||
- ONYX_BOT_MAX_QPM=${ONYX_BOT_MAX_QPM:-}
|
||||
- ONYX_BOT_MAX_WAIT_TIME=${ONYX_BOT_MAX_WAIT_TIME:-}
|
||||
# Discord Bot Configuration (runs via supervisord, requires DISCORD_BOT_TOKEN to be set)
|
||||
# IMPORTANT: Only one Discord bot instance can run per token - do not scale background workers
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# Logging
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
|
||||
|
||||
@@ -63,6 +63,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -82,6 +82,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
|
||||
@@ -129,6 +129,11 @@ services:
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
- S3_AWS_SECRET_ACCESS_KEY=${S3_AWS_SECRET_ACCESS_KEY:-minioadmin}
|
||||
- DISCORD_BOT_TOKEN=${DISCORD_BOT_TOKEN:-}
|
||||
- DISCORD_BOT_INVOKE_CHAR=${DISCORD_BOT_INVOKE_CHAR:-!}
|
||||
# API Server connection for Discord bot message processing
|
||||
- API_SERVER_PROTOCOL=${API_SERVER_PROTOCOL:-http}
|
||||
- API_SERVER_HOST=${API_SERVER_HOST:-api_server}
|
||||
# PRODUCTION: Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
|
||||
@@ -77,6 +77,13 @@ MINIO_ROOT_PASSWORD=minioadmin
|
||||
## CORS origins for MCP clients (comma-separated list)
|
||||
# MCP_SERVER_CORS_ORIGINS=
|
||||
|
||||
## Discord Bot Configuration
|
||||
## The Discord bot allows users to interact with Onyx from Discord servers
|
||||
## Bot token from Discord Developer Portal (required to enable the bot)
|
||||
# DISCORD_BOT_TOKEN=
|
||||
## Command prefix for bot commands (default: "!")
|
||||
# DISCORD_BOT_INVOKE_CHAR=!
|
||||
|
||||
## Celery Configuration
|
||||
# CELERY_BROKER_POOL_LIMIT=
|
||||
# CELERY_WORKER_DOCFETCHING_CONCURRENCY=
|
||||
|
||||
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
98
deployment/helm/charts/onyx/templates/discordbot.yaml
Normal file
@@ -0,0 +1,98 @@
|
||||
{{- if .Values.discordbot.enabled }}
|
||||
# Discord bot MUST run as a single replica - Discord only allows one client connection per bot token.
|
||||
# Do NOT enable HPA or increase replicas. Message processing is offloaded to scalable API pods via HTTP.
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "onyx.fullname" . }}-discordbot
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 4 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 4 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
# CRITICAL: Discord bots cannot be horizontally scaled - only one WebSocket connection per token is allowed
|
||||
replicas: 1
|
||||
strategy:
|
||||
type: Recreate # Ensure old pod is terminated before new one starts to avoid duplicate connections
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "onyx.selectorLabels" . | nindent 6 }}
|
||||
{{- if .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml .Values.discordbot.deploymentLabels | nindent 6 }}
|
||||
{{- end }}
|
||||
template:
|
||||
metadata:
|
||||
annotations:
|
||||
checksum/config: {{ include (print $.Template.BasePath "/configmap.yaml") . | sha256sum }}
|
||||
{{- with .Values.discordbot.podAnnotations }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
labels:
|
||||
{{- include "onyx.labels" . | nindent 8 }}
|
||||
{{- with .Values.discordbot.deploymentLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.podLabels }}
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
spec:
|
||||
{{- with .Values.imagePullSecrets }}
|
||||
imagePullSecrets:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
serviceAccountName: {{ include "onyx.serviceAccountName" . }}
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.podSecurityContext | nindent 8 }}
|
||||
{{- with .Values.discordbot.nodeSelector }}
|
||||
nodeSelector:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.affinity }}
|
||||
affinity:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.tolerations }}
|
||||
tolerations:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
containers:
|
||||
- name: discordbot
|
||||
securityContext:
|
||||
{{- toYaml .Values.discordbot.securityContext | nindent 12 }}
|
||||
image: "{{ .Values.discordbot.image.repository }}:{{ .Values.discordbot.image.tag | default .Values.global.version }}"
|
||||
imagePullPolicy: {{ .Values.global.pullPolicy }}
|
||||
command: ["python", "onyx/onyxbot/discord/client.py"]
|
||||
resources:
|
||||
{{- toYaml .Values.discordbot.resources | nindent 12 }}
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
env:
|
||||
{{- include "onyx.envSecrets" . | nindent 12}}
|
||||
# Discord bot token - required for bot to connect
|
||||
{{- if .Values.discordbot.botToken }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
value: {{ .Values.discordbot.botToken | quote }}
|
||||
{{- end }}
|
||||
{{- if .Values.discordbot.botTokenSecretName }}
|
||||
- name: DISCORD_BOT_TOKEN
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ .Values.discordbot.botTokenSecretName }}
|
||||
key: {{ .Values.discordbot.botTokenSecretKey | default "token" }}
|
||||
{{- end }}
|
||||
# Command prefix for bot commands (default: "!")
|
||||
{{- if .Values.discordbot.invokeChar }}
|
||||
- name: DISCORD_BOT_INVOKE_CHAR
|
||||
value: {{ .Values.discordbot.invokeChar | quote }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumeMounts }}
|
||||
volumeMounts:
|
||||
{{- toYaml . | nindent 12 }}
|
||||
{{- end }}
|
||||
{{- with .Values.discordbot.volumes }}
|
||||
volumes:
|
||||
{{- toYaml . | nindent 8 }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
@@ -655,6 +655,44 @@ celery_worker_user_file_processing:
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
# Discord bot for Onyx
|
||||
# The bot offloads message processing to scalable API pods via HTTP requests.
|
||||
discordbot:
|
||||
enabled: false # Disabled by default - requires bot token configuration
|
||||
# Bot token can be provided directly or via a Kubernetes secret
|
||||
# Option 1: Direct token (not recommended for production)
|
||||
botToken: ""
|
||||
# Option 2: Reference a Kubernetes secret (recommended)
|
||||
botTokenSecretName: "" # Name of the secret containing the bot token
|
||||
botTokenSecretKey: "token" # Key within the secret (default: "token")
|
||||
# Command prefix for bot commands (default: "!")
|
||||
invokeChar: "!"
|
||||
image:
|
||||
repository: onyxdotapp/onyx-backend
|
||||
tag: "" # Overrides the image tag whose default is the chart appVersion.
|
||||
podAnnotations: {}
|
||||
podLabels:
|
||||
scope: onyx-backend
|
||||
app: discord-bot
|
||||
deploymentLabels:
|
||||
app: discord-bot
|
||||
podSecurityContext:
|
||||
{}
|
||||
securityContext:
|
||||
{}
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "512Mi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2000Mi"
|
||||
volumes: []
|
||||
volumeMounts: []
|
||||
nodeSelector: {}
|
||||
tolerations: []
|
||||
affinity: {}
|
||||
|
||||
slackbot:
|
||||
enabled: true
|
||||
replicaCount: 1
|
||||
@@ -1090,6 +1128,8 @@ configMap:
|
||||
ONYX_BOT_DISPLAY_ERROR_MSGS: ""
|
||||
ONYX_BOT_RESPOND_EVERY_CHANNEL: ""
|
||||
NOTIFY_SLACKBOT_NO_ANSWER: ""
|
||||
DISCORD_BOT_TOKEN: ""
|
||||
DISCORD_BOT_INVOKE_CHAR: ""
|
||||
# Logging
|
||||
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
|
||||
DISABLE_TELEMETRY: ""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user