Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-06 16:15:46 +00:00)

Compare commits: 138 commits
| Author | SHA1 | Date |
|---|---|---|
|  | 3f8ef8b465 |  |
|  | ed46504a1a |  |
|  | 7a24b34516 |  |
|  | 7a7ffa9051 |  |
|  | 3053ab518c |  |
|  | be38d3500f |  |
|  | 753a3bc093 |  |
|  | 2ba8fafe78 |  |
|  | b77b580ebd |  |
|  | 3eee98b932 |  |
|  | a97eb02fef |  |
|  | c5061495a2 |  |
|  | c20b0789ae |  |
|  | d99848717b |  |
|  | aaca55c415 |  |
|  | 9d7ffd1e4a |  |
|  | a249161827 |  |
|  | e126346a91 |  |
|  | a96682fa73 |  |
|  | 3920371d56 |  |
|  | e5a257345c |  |
|  | a49df511e2 |  |
|  | d5d2a8a1a6 |  |
|  | b2f46b264c |  |
|  | c6ad363fbd |  |
|  | e313119f9a |  |
|  | 3a2a542a03 |  |
|  | 413aeba4a1 |  |
|  | 46028aa2bb |  |
|  | 454943c4a6 |  |
|  | 87946266de |  |
|  | 144030c5ca |  |
|  | a557d76041 |  |
|  | 605e808158 |  |
|  | 8fec88c90d |  |
|  | e54969a693 |  |
|  | 1da2b2f28f |  |
|  | eb7b91e08e |  |
|  | 3339000968 |  |
|  | d9db849e94 |  |
|  | 046408359c |  |
|  | 4b8cca190f |  |
|  | 52a312a63b |  |
|  | 0594fd17de |  |
|  | fded81dc28 |  |
|  | 31db112de9 |  |
|  | a3e2da2c51 |  |
|  | f4d33bcc0d |  |
|  | 464d957494 |  |
|  | be12de9a44 |  |
|  | 3e4a1f8a09 |  |
|  | af9b7826ab |  |
|  | cb16eb13fc |  |
|  | 20a73bdd2e |  |
|  | 85cc2b99b7 |  |
|  | 1208a3ee2b |  |
|  | 900fcef9dd |  |
|  | d4ed25753b |  |
|  | 0ee58333b4 |  |
|  | 11b7e0d571 |  |
|  | a35831f328 |  |
|  | 048a6d5259 |  |
|  | e4bdb15910 |  |
|  | 3517d59286 |  |
|  | 4bc08e5d88 |  |
|  | 4bd080cf62 |  |
|  | b0a8625ffc |  |
|  | f94baf6143 |  |
|  | 9e1867638a |  |
|  | 5b6d7c9f0d |  |
|  | e5dcf31f10 |  |
|  | 8ca06ef3e7 |  |
|  | 6897dbd610 |  |
|  | 7f3cb77466 |  |
|  | 267042a5aa |  |
|  | d02b3ae6ac |  |
|  | 683c3f7a7e |  |
|  | 008b4d2288 |  |
|  | 8be261405a |  |
|  | 61f2c48ebc |  |
|  | dbde2e6d6d |  |
|  | 2860136214 |  |
|  | 49ec5994d3 |  |
|  | 8d5fb67f0f |  |
|  | 15d02f6e3c |  |
|  | e58974c419 |  |
|  | 6b66c07952 |  |
|  | cae058a3ac |  |
|  | aa3b21a191 |  |
|  | 7a07a78696 |  |
|  | a8db236e37 |  |
|  | 8a2e4ed36f |  |
|  | 216f2c95a7 |  |
|  | 67081efe08 |  |
|  | 9d40b8336f |  |
|  | 23f0033302 |  |
|  | 9011b76eb0 |  |
|  | 45e436bafc |  |
|  | 010bc36d61 |  |
|  | 468e488bdb |  |
|  | 9104c0ffce |  |
|  | d36a6bd0b4 |  |
|  | a3603c498c |  |
|  | 8f274e34c9 |  |
|  | 5c256760ff |  |
|  | 258e1372b3 |  |
|  | 83a543a265 |  |
|  | f9719d199d |  |
|  | 1c7bb6e56a |  |
|  | 982ad7d329 |  |
|  | f94292808b |  |
|  | 293553a2e2 |  |
|  | ba906ae6fa |  |
|  | c84c7a354e |  |
|  | 2187b0dd82 |  |
|  | d88a417bf9 |  |
|  | f2d32b0b3b |  |
|  | f89432009f |  |
|  | 8ab2bab34e |  |
|  | 59e0d62512 |  |
|  | a1471b16a5 |  |
|  | 9d3811cb58 |  |
|  | 3cd9505383 |  |
|  | d11829b393 |  |
|  | f6e068e914 |  |
|  | 0c84edd980 |  |
|  | 2b274a7683 |  |
|  | ddd91f2d71 |  |
|  | a7c7da0dfc |  |
|  | b00a3e8b5d |  |
|  | d77d1a48f1 |  |
|  | 7b4fc6729c |  |
|  | 1f113c86ef |  |
|  | 8e38ba3e21 |  |
|  | bb9708a64f |  |
|  | 8cae97e145 |  |
|  | 7e4abca224 |  |
|  | 233a91ea65 |  |
411 .github/workflows/deployment.yml (vendored)
@@ -8,7 +8,9 @@ on:

# Set restrictive default permissions for all jobs. Jobs that need more permissions
# should explicitly declare them.
permissions: {}
permissions:
# Required for OIDC authentication with AWS
id-token: write # zizmor: ignore[excessive-permissions]

env:
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -150,16 +152,30 @@ jobs:
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
runs-on: ubuntu-slim
timeout-minutes: 10
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true

- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: "• check-version-tag"
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
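The hunk above replaces repository-level secrets with values pulled at run time from AWS Secrets Manager after an OIDC role assumption. As a rough illustration of what the "Configure AWS credentials" plus "Get AWS Secrets" steps accomplish, the hedged Python sketch below fetches a secret and exposes its JSON fields as environment variables; the secret ID and prefix are copied from the workflow, but the helper itself is illustrative and not the action's actual implementation.

```python
# Hedged sketch of the "Get AWS Secrets" step: fetch a Secrets Manager secret
# and export its JSON fields as environment variables. Assumes the process
# already holds credentials from the OIDC-assumed role.
import json
import os

import boto3


def load_secret_as_env(secret_id: str, prefix: str) -> None:
    client = boto3.client("secretsmanager", region_name="us-east-2")
    value = client.get_secret_value(SecretId=secret_id)["SecretString"]
    try:
        fields = json.loads(value)  # mirrors parse-json-secrets: true
    except json.JSONDecodeError:
        fields = {"": value}  # plain-string secret
    for key, field_value in fields.items():
        env_name = f"{prefix}_{key}".strip("_").upper()
        os.environ[env_name] = field_value


# Same secret the notification job reads; name copied from the workflow.
load_secret_as_env("deploy/monitor-deployments-webhook", "MONITOR_DEPLOYMENTS_WEBHOOK")
```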
@@ -168,6 +184,7 @@ jobs:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
permissions:
id-token: write
contents: write
actions: read
strategy:
@@ -185,12 +202,33 @@ jobs:

runs-on: ${{ matrix.platform }}
timeout-minutes: 90
environment: release
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]

- name: Configure AWS credentials
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
APPLE_ID, deploy/apple-id
APPLE_PASSWORD, deploy/apple-password
APPLE_CERTIFICATE, deploy/apple-certificate
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
KEYCHAIN_PASSWORD, deploy/keychain-password
APPLE_TEAM_ID, deploy/apple-team-id
parse-json-secrets: true

- name: install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-')
run: |
@@ -285,15 +323,40 @@ jobs:

Write-Host "Versions set to: $VERSION"

- name: Import Apple Developer Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security set-keychain-settings -t 3600 -u build.keychain
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
security find-identity -v -p codesigning build.keychain

- name: Verify Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
echo "Certificate imported."

- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
APPLE_ID: ${{ env.APPLE_ID }}
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
with:
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: "See the assets to download this version and install."
releaseDraft: true
prerelease: false
assetNamePattern: "[name]_[arch][ext]"
args: ${{ matrix.args }}

build-web-amd64:
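The "Verify Certificate" step above pulls the signing identity name out of `security find-identity` output with grep and awk. Purely as an illustration of that parsing (the sample line below is made up, not taken from a real keychain), a small Python equivalent looks like this:

```python
# Illustrative re-implementation of the CERT_ID extraction in "Verify Certificate".
import re

sample = '  1) ABCDEF0123456789 "Developer ID Application: Example Corp (TEAMID1234)"'  # hypothetical

match = re.search(
    r'"((?:Developer ID Application|Apple Distribution|Apple Development)[^"]*)"',
    sample,
)
cert_id = match.group(1) if match else ""
print(f"CERT_ID={cert_id}")
```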
@@ -305,6 +368,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -317,6 +381,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -326,13 +404,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -363,6 +441,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -375,6 +454,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -384,13 +477,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -423,19 +516,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -471,6 +579,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -483,6 +592,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -492,13 +615,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -537,6 +660,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -549,6 +673,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -558,13 +696,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -605,19 +743,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -650,6 +803,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -662,6 +816,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -671,13 +839,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -707,6 +875,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -719,6 +888,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -728,13 +911,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -766,19 +949,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -815,6 +1013,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -827,6 +1026,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -836,15 +1049,15 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -879,6 +1092,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -891,6 +1105,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -900,15 +1128,15 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -944,19 +1172,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -994,11 +1237,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1014,8 +1272,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1034,11 +1292,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1054,8 +1327,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1074,6 +1347,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
@@ -1084,6 +1358,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1100,8 +1388,8 @@ jobs:
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1121,11 +1409,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1141,8 +1444,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1170,12 +1473,26 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Determine failed jobs
|
||||
id: failed-jobs
|
||||
shell: bash
|
||||
@@ -1241,7 +1558,7 @@ jobs:
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
|
||||
title: "🚨 Deployment Workflow Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
2 .github/workflows/docker-tag-beta.yml (vendored)
@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
2 .github/workflows/docker-tag-latest.yml (vendored)
@@ -21,7 +21,7 @@ jobs:
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
1 .github/workflows/helm-chart-releases.yml (vendored)
@@ -29,6 +29,7 @@ jobs:
run: |
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
2 .github/workflows/nightly-scan-licenses.yml (vendored)
@@ -94,7 +94,7 @@ jobs:

steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000

# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"

jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -125,11 +128,13 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
cache \
index \
opensearch \
code-interpreter

- name: Run migrations
@@ -158,7 +163,7 @@ jobs:
cd deployment/docker_compose

# Get list of running containers
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)

# Collect logs from each container
for container in $containers; do
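The compose hunk above adds an OpenSearch container (with `OPENSEARCH_ADMIN_PASSWORD` set in the workflow env) to the services started before tests. As a hedged sketch of how a test harness might wait for that container, the snippet below polls the cluster-health endpoint; the host, port, and self-signed-certificate handling are assumptions since the compose file itself is not shown here, while the admin password matches the workflow env value.

```python
# Hedged readiness probe for the OpenSearch service started by docker compose.
import time

import requests

ADMIN_AUTH = ("admin", "StrongPassword123!")  # password from the workflow env

for _ in range(60):
    try:
        resp = requests.get(
            "https://localhost:9200/_cluster/health",  # assumed host/port
            auth=ADMIN_AUTH,
            verify=False,  # assumed self-signed dev certificate
            timeout=5,
        )
        if resp.ok:
            print("cluster status:", resp.json().get("status"))
            break
    except requests.RequestException:
        pass
    time.sleep(5)
```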
8 .github/workflows/pr-helm-chart-testing.yml (vendored)
@@ -88,6 +88,7 @@ jobs:
echo "=== Adding Helm repositories ==="
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo add opensearch https://opensearch-project.github.io/helm-charts
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
@@ -180,6 +181,11 @@ jobs:
trap cleanup EXIT

# Run the actual installation with detailed logging
# Note that opensearch.enabled is true whereas others in this install
# are false. There is some work that needs to be done to get this
# entire step working in CI, enabling opensearch here is a small step
# in that direction. If this is causing issues, disabling it in this
# step should be ok in the short term.
echo "=== Starting ct install ==="
set +e
ct install --all \
@@ -187,6 +193,8 @@ jobs:
--set=nginx.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=opensearch.enabled=true \
--set=auth.opensearch.enabled=true \
--set=slackbot.enabled=false \
--set=postgresql.enabled=true \
--set=postgresql.nameOverride=cloudnative-pg \
6 .github/workflows/pr-integration-tests.yml (vendored)
@@ -103,7 +103,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -163,7 +163,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -208,7 +208,7 @@ jobs:
persist-credentials: false

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit

@@ -95,7 +95,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -155,7 +155,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
@@ -214,7 +214,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
6 .github/workflows/pr-playwright-tests.yml (vendored)
@@ -85,7 +85,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -146,7 +146,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
@@ -207,7 +207,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
3 .github/workflows/pr-python-checks.yml (vendored)
@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-

- name: Run MyPy
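The change above inserts the PR's base branch (falling back to the merge group's base ref, then to 'main') into the mypy cache key, so branches targeting different bases stop sharing one cache entry while the restore keys still allow a looser match. A small illustrative sketch of how that key and its fallbacks compose (values are made up):

```python
# Illustration of the new cache-key layout: OS, then target branch, then content hash.
def mypy_cache_key(runner_os: str, base_ref: str | None, files_hash: str) -> str:
    branch = base_ref or "main"
    return f"mypy-{runner_os}-{branch}-{files_hash}"


def restore_key_prefixes(runner_os: str, base_ref: str | None) -> list[str]:
    branch = base_ref or "main"
    # Fallback order mirrors restore-keys: same branch first, then any key for this OS.
    return [f"mypy-{runner_os}-{branch}-", f"mypy-{runner_os}-"]


print(mypy_cache_key("Linux", "release/v1.2", "abc123"))
print(restore_key_prefixes("Linux", None))
```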
2 .github/workflows/pr-python-model-tests.yml (vendored)
@@ -70,7 +70,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f

- name: Build and load
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
3 .gitignore (vendored)
@@ -1,5 +1,8 @@
# editors
.vscode
!/.vscode/env_template.txt
!/.vscode/launch.json
!/.vscode/tasks.template.jsonc
.zed
.cursor

@@ -74,6 +74,13 @@ repos:
# pass_filenames: true
# files: ^backend/.*\.py$

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-added-large-files
name: Check for added large files
args: ["--maxkb=1500"]

- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:
136 .vscode/launch.template.jsonc → .vscode/launch.json (vendored)
@@ -1,5 +1,3 @@
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */

{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
@@ -24,7 +22,7 @@
"Slack Bot",
"Celery primary",
"Celery light",
"Celery background",
"Celery heavy",
"Celery docfetching",
"Celery docprocessing",
"Celery beat"
@@ -151,6 +149,24 @@
},
"consoleTitle": "Slack Bot Console"
},
{
"name": "Discord Bot",
"consoleName": "Discord Bot",
"type": "debugpy",
"request": "launch",
"program": "onyx/onyxbot/discord/client.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"presentation": {
"group": "2"
},
"consoleTitle": "Discord Bot Console"
},
{
"name": "MCP Server",
"consoleName": "MCP Server",
@@ -579,6 +595,120 @@
"group": "3"
}
},
{
// Dummy entry used to label the group
"name": "--- Database ---",
"type": "node",
"request": "launch",
"presentation": {
"group": "4",
"order": 0
}
},
{
"name": "Restore seeded database dump",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore seeded database dump (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--fetch-seeded",
"--clean",
"--yes"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Create database snapshot",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"dump",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Clean restore database snapshot (destructive)",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"restore",
"--clean",
"--yes",
"backup.dump"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
"name": "Upgrade database to head revision",
"type": "node",
"request": "launch",
"runtimeExecutable": "uv",
"runtimeArgs": [
"run",
"--with",
"onyx-devtools",
"ods",
"db",
"upgrade"
],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal",
"presentation": {
"group": "4"
}
},
{
// script to generate the openapi schema
"name": "Onyx OpenAPI Schema Generator",
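All of the new Database launch configurations above shell out to the same devtools CLI via `uv run --with onyx-devtools ods db ...`. For running those operations outside VS Code, a hedged convenience wrapper might look like the following; the command-line arguments are copied from the configurations, but the wrapper itself is illustrative and not part of the repository.

```python
# Illustrative wrapper around the devtools commands used by the launch configs.
import subprocess


def ods_db(*args: str) -> None:
    # Same invocation shape as the launch configurations above.
    cmd = ["uv", "run", "--with", "onyx-devtools", "ods", "db", *args]
    subprocess.run(cmd, check=True)


# Equivalent to "Restore seeded database dump":
ods_db("restore", "--fetch-seeded", "--yes")
# Equivalent to "Create database snapshot":
ods_db("dump", "backup.dump")
# Equivalent to "Upgrade database to head revision":
ods_db("upgrade")
```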
@@ -37,10 +37,6 @@ CVE-2023-50868
CVE-2023-52425
CVE-2024-28757

# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
# No impact in our settings
CVE-2023-7104

# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193
@@ -89,12 +89,6 @@ RUN uv pip install --system --no-cache-dir --upgrade \
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed

# Pre-downloading tiktoken for setups with limited egress
RUN python -c "import tiktoken; \
tiktoken.get_encoding('cl100k_base')"
@@ -0,0 +1,42 @@
"""add_unique_constraint_to_inputprompt_prompt_user_id

Revision ID: 2c2430828bdf
Revises: fb80bdd256de
Create Date: 2026-01-20 16:01:54.314805

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "2c2430828bdf"
down_revision = "fb80bdd256de"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create unique constraint on (prompt, user_id) for user-owned prompts
    # This ensures each user can only have one shortcut with a given name
    op.create_unique_constraint(
        "uq_inputprompt_prompt_user_id",
        "inputprompt",
        ["prompt", "user_id"],
    )

    # Create partial unique index for public prompts (where user_id IS NULL)
    # PostgreSQL unique constraints don't enforce uniqueness for NULL values,
    # so we need a partial index to ensure public prompt names are also unique
    op.execute(
        """
        CREATE UNIQUE INDEX uq_inputprompt_prompt_public
        ON inputprompt (prompt)
        WHERE user_id IS NULL
        """
    )


def downgrade() -> None:
    op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
    op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")
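The migration above needs both a unique constraint and a partial unique index because PostgreSQL treats NULLs as distinct in unique constraints, so public prompts (user_id IS NULL) would otherwise never collide. A model-level mirror of the same rules, assuming an InputPrompt model roughly shaped like the "inputprompt" table (column types and the model itself are illustrative, only the constraint and index names are taken from the migration):

```python
# Hedged SQLAlchemy sketch of the constraints created by revision 2c2430828bdf.
from sqlalchemy import Column, Index, Integer, String, UniqueConstraint, text
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class InputPrompt(Base):
    __tablename__ = "inputprompt"

    id = Column(Integer, primary_key=True)
    prompt = Column(String, nullable=False)
    user_id = Column(String, nullable=True)  # type is an assumption

    __table_args__ = (
        # One shortcut with a given name per user.
        UniqueConstraint("prompt", "user_id", name="uq_inputprompt_prompt_user_id"),
        # NULL user_id rows never collide under the constraint above, so public
        # prompt names get their own partial unique index.
        Index(
            "uq_inputprompt_prompt_public",
            "prompt",
            unique=True,
            postgresql_where=text("user_id IS NULL"),
        ),
    )
```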
@@ -0,0 +1,29 @@
"""remove default prompt shortcuts

Revision ID: 41fa44bef321
Revises: 2c2430828bdf
Create Date: 2025-01-21

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "41fa44bef321"
down_revision = "2c2430828bdf"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Delete any user associations for the default prompts first (foreign key constraint)
    op.execute(
        "DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
    )
    # Delete the pre-seeded default prompt shortcuts (they have negative IDs)
    op.execute("DELETE FROM inputprompt WHERE id < 0")


def downgrade() -> None:
    # We don't restore the default prompts on downgrade
    pass
116 backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py (Normal file)
@@ -0,0 +1,116 @@
"""Add Discord bot tables

Revision ID: 8b5ce697290e
Revises: a1b2c3d4e5f7
Create Date: 2025-01-14

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "8b5ce697290e"
down_revision = "a1b2c3d4e5f7"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # DiscordBotConfig (singleton table - one per tenant)
    op.create_table(
        "discord_bot_config",
        sa.Column(
            "id",
            sa.String(),
            primary_key=True,
            server_default=sa.text("'SINGLETON'"),
        ),
        sa.Column("bot_token", sa.LargeBinary(), nullable=False),  # EncryptedString
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
    )

    # DiscordGuildConfig
    op.create_table(
        "discord_guild_config",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
        sa.Column("guild_name", sa.String(), nullable=True),
        sa.Column("registration_key", sa.String(), nullable=False, unique=True),
        sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "default_persona_id",
            sa.Integer(),
            sa.ForeignKey("persona.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column(
            "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
        ),
    )

    # DiscordChannelConfig
    op.create_table(
        "discord_channel_config",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column(
            "guild_config_id",
            sa.Integer(),
            sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("channel_id", sa.BigInteger(), nullable=False),
        sa.Column("channel_name", sa.String(), nullable=False),
        sa.Column(
            "channel_type",
            sa.String(20),
            server_default=sa.text("'text'"),
            nullable=False,
        ),
        sa.Column(
            "is_private",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.Column(
            "thread_only_mode",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.Column(
            "require_bot_invocation",
            sa.Boolean(),
            server_default=sa.text("true"),
            nullable=False,
        ),
        sa.Column(
            "persona_override_id",
            sa.Integer(),
            sa.ForeignKey("persona.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column(
            "enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
        ),
    )

    # Unique constraint: one config per channel per guild
    op.create_unique_constraint(
        "uq_discord_channel_guild_channel",
        "discord_channel_config",
        ["guild_config_id", "channel_id"],
    )


def downgrade() -> None:
    op.drop_table("discord_channel_config")
    op.drop_table("discord_guild_config")
    op.drop_table("discord_bot_config")
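The CHECK constraint together with the 'SINGLETON' server default is what makes discord_bot_config a one-row-per-tenant table. A hedged sketch of how application code might write that row (the helper name and upsert are illustrative; the diff does not show the actual accessors):

# Hypothetical helper, not repository code.
from sqlalchemy import text
from sqlalchemy.orm import Session

def upsert_discord_bot_token(db_session: Session, encrypted_token: bytes) -> None:
    # The CHECK (id = 'SINGLETON') constraint guarantees at most one row,
    # so an ON CONFLICT upsert keyed on the primary key is sufficient.
    db_session.execute(
        text(
            """
            INSERT INTO discord_bot_config (id, bot_token)
            VALUES ('SINGLETON', :token)
            ON CONFLICT (id) DO UPDATE SET bot_token = EXCLUDED.bot_token
            """
        ),
        {"token": encrypted_token},
    )
    db_session.commit()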
@@ -0,0 +1,47 @@
"""drop agent_search_metrics table

Revision ID: a1b2c3d4e5f7
Revises: 73e9983e5091
Create Date: 2026-01-17

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f7"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("agent__search_metrics")


def downgrade() -> None:
    op.create_table(
        "agent__search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.UUID(), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
@@ -1,42 +0,0 @@
"""Add SET NULL cascade to chat_session.persona_id foreign key

Revision ID: ac9c7b76419b
Revises: 73e9983e5091
Create Date: 2026-01-17 18:10:00.000000

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "ac9c7b76419b"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing foreign key constraint (no cascade behavior)
    op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
    # Recreate with SET NULL on delete, so deleting a persona sets
    # chat_session.persona_id to NULL instead of blocking the delete
    op.create_foreign_key(
        "fk_chat_session_persona_id",
        "chat_session",
        "persona",
        ["persona_id"],
        ["id"],
        ondelete="SET NULL",
    )


def downgrade() -> None:
    # Revert to original constraint without cascade behavior
    op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
    op.create_foreign_key(
        "fk_chat_session_persona_id",
        "chat_session",
        "persona",
        ["persona_id"],
        ["id"],
    )
@@ -0,0 +1,31 @@
"""add chat_background to user

Revision ID: fb80bdd256de
Revises: 8b5ce697290e
Create Date: 2026-01-16 16:15:59.222617

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "fb80bdd256de"
down_revision = "8b5ce697290e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "chat_background",
            sa.String(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("user", "chat_background")
@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(


STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")

# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
@@ -129,3 +128,8 @@ MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

GATED_TENANTS_KEY = "gated_tenants"

# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
    os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
)
@@ -16,6 +16,9 @@ from ee.onyx.server.enterprise_settings.api import (
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
    add_license_enforcement_middleware,
)
from ee.onyx.server.middleware.tenant_tracking import (
    add_api_server_tenant_id_middleware,
)
@@ -83,6 +86,10 @@ def get_application() -> FastAPI:
    if MULTI_TENANT:
        add_api_server_tenant_id_middleware(application, logger)

    # Add license enforcement middleware (runs after tenant tracking)
    # This blocks access when license is expired/gated
    add_license_enforcement_middleware(application, logger)

    if AUTH_TYPE == AuthType.CLOUD:
        # For Google OAuth, refresh tokens are requested by:
        # 1. Adding the right scopes
@@ -17,7 +17,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
from onyx.db.models import User
from onyx.document_index.factory import get_current_primary_default_document_index
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.factory import get_default_llm
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
@@ -42,11 +43,13 @@ def _run_single_search(
    document_index: DocumentIndex,
    user: User | None,
    db_session: Session,
    num_hits: int | None = None,
) -> list[InferenceChunk]:
    """Execute a single search query and return chunks."""
    chunk_search_request = ChunkSearchRequest(
        query=query,
        user_selected_filters=filters,
        limit=num_hits,
    )

    return search_pipeline(
@@ -72,7 +75,9 @@ def stream_search_query(
    Used by both streaming and non-streaming endpoints.
    """
    # Get document index
    document_index = get_current_primary_default_document_index(db_session)
    search_settings = get_current_search_settings(db_session)
    # This flow is for search so we do not get all indices.
    document_index = get_default_document_index(search_settings, None)

    # Determine queries to execute
    original_query = request.search_query
@@ -114,6 +119,7 @@ def stream_search_query(
            document_index=document_index,
            user=user,
            db_session=db_session,
            num_hits=request.num_hits,
        )
    else:
        # Multiple queries - run in parallel and merge with RRF
@@ -121,7 +127,14 @@ def stream_search_query(
        search_functions = [
            (
                _run_single_search,
                (query, request.filters, document_index, user, db_session),
                (
                    query,
                    request.filters,
                    document_index,
                    user,
                    db_session,
                    request.num_hits,
                ),
            )
            for query in all_executed_queries
        ]
@@ -168,6 +181,9 @@ def stream_search_query(
    # Merge chunks into sections
    sections = merge_individual_chunks(chunks)

    # Truncate to the requested number of hits
    sections = sections[: request.num_hits]

    # Apply LLM document selection if requested
    # num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
    # The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it
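The "merge with RRF" comment above refers to reciprocal rank fusion over the per-query result lists produced by query expansion. A minimal generic sketch of the idea (illustrative names, not the repository's merge implementation):

# Generic illustration of reciprocal rank fusion, not repository code.
from collections import defaultdict

def reciprocal_rank_fusion(result_lists: list[list[str]], k: int = 60) -> list[str]:
    """Fuse several ranked lists of document ids into one ranking."""
    scores: dict[str, float] = defaultdict(float)
    for results in result_lists:
        for rank, doc_id in enumerate(results):
            # Documents ranked highly in any list accumulate the most score.
            scores[doc_id] += 1.0 / (k + rank + 1)
    return sorted(scores, key=scores.__getitem__, reverse=True)

# Example: two query expansions returning overlapping hits.
fused = reciprocal_rank_fusion([["a", "b", "c"], ["b", "a", "d"]])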
@@ -10,6 +10,8 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
    ("/enterprise-settings/logo", {"GET"}),
    ("/enterprise-settings/logotype", {"GET"}),
    ("/enterprise-settings/custom-analytics-script", {"GET"}),
    # Stripe publishable key is safe to expose publicly
    ("/tenants/stripe-publishable-key", {"GET"}),
]
backend/ee/onyx/server/middleware/license_enforcement.py
@@ -0,0 +1,102 @@
"""Middleware to enforce license status application-wide."""

import logging
from collections.abc import Awaitable
from collections.abc import Callable

from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from redis.exceptions import RedisError

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.server.tenants.product_gating import is_tenant_gated
from onyx.server.settings.models import ApplicationStatus
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
# /auth - Log in/out (users can't fix billing if locked out of auth)
# /license - Fetch, upload, or check license status
# /health - Health checks for load balancers/orchestrators
# /me - Basic user info needed for UI rendering
# /settings, /enterprise-settings - View app status and branding
# /tenants/billing-* - Manage subscription to resolve gating
ALLOWED_PATH_PREFIXES = {
    "/auth",
    "/license",
    "/health",
    "/me",
    "/settings",
    "/enterprise-settings",
    "/tenants/billing-information",
    "/tenants/create-customer-portal-session",
    "/tenants/create-subscription-session",
}


def _is_path_allowed(path: str) -> bool:
    """Check if path is in allowlist (prefix match)."""
    return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)


def add_license_enforcement_middleware(
    app: FastAPI, logger: logging.LoggerAdapter
) -> None:
    logger.info("License enforcement middleware registered")

    @app.middleware("http")
    async def enforce_license(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Block requests when license is expired/gated."""
        if not LICENSE_ENFORCEMENT_ENABLED:
            return await call_next(request)

        path = request.url.path
        if path.startswith("/api"):
            path = path[4:]

        if _is_path_allowed(path):
            return await call_next(request)

        is_gated = False
        tenant_id = get_current_tenant_id()

        if MULTI_TENANT:
            try:
                is_gated = is_tenant_gated(tenant_id)
            except RedisError as e:
                logger.warning(f"Failed to check tenant gating status: {e}")
                # Fail open - don't block users due to Redis connectivity issues
                is_gated = False
        else:
            try:
                metadata = get_cached_license_metadata(tenant_id)
                if metadata:
                    if metadata.status == ApplicationStatus.GATED_ACCESS:
                        is_gated = True
                else:
                    # No license metadata = gated for self-hosted EE
                    is_gated = True
            except RedisError as e:
                logger.warning(f"Failed to check license metadata: {e}")
                # Fail open - don't block users due to Redis connectivity issues
                is_gated = False

        if is_gated:
            logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
            return JSONResponse(
                status_code=402,
                content={
                    "detail": {
                        "error": "license_expired",
                        "message": "Your subscription has expired. Please update your billing.",
                    }
                },
            )

        return await call_next(request)
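Because the middleware answers with HTTP 402 and a structured detail payload, callers can branch on the error code rather than scraping the message text. A hedged client-side sketch (the helper name and handling are illustrative, not part of this change):

# Hypothetical client handling, not repository code.
import httpx

def fetch_with_license_check(url: str) -> dict:
    response = httpx.get(url)
    if response.status_code == 402:
        detail = response.json().get("detail", {})
        if detail.get("error") == "license_expired":
            # Surface a billing prompt instead of a generic failure.
            raise RuntimeError(detail.get("message", "License expired"))
    response.raise_for_status()
    return response.json()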
@@ -32,6 +32,7 @@ class SendSearchQueryRequest(BaseModel):
    filters: BaseFilters | None = None
    num_docs_fed_to_llm_selection: int | None = None
    run_query_expansion: bool = False
    num_hits: int = 50

    include_content: bool = False
    stream: bool = False
backend/ee/onyx/server/settings/api.py
@@ -0,0 +1,54 @@
"""EE Settings API - provides license-aware settings override."""

from redis.exceptions import RedisError

from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

# Statuses that indicate a billing/license problem - propagate these to settings
_GATED_STATUSES = frozenset(
    {
        ApplicationStatus.GATED_ACCESS,
        ApplicationStatus.GRACE_PERIOD,
        ApplicationStatus.PAYMENT_REMINDER,
    }
)


def apply_license_status_to_settings(settings: Settings) -> Settings:
    """EE version: checks license status for self-hosted deployments.

    For self-hosted, looks up license metadata and overrides application_status
    if the license is missing or indicates a problem (expired, grace period, etc.).

    For multi-tenant (cloud), the settings already have the correct status
    from the control plane, so no override is needed.

    If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
    allowing the product to function normally without license checks.
    """
    if not LICENSE_ENFORCEMENT_ENABLED:
        return settings

    if MULTI_TENANT:
        return settings

    tenant_id = get_current_tenant_id()
    try:
        metadata = get_cached_license_metadata(tenant_id)
        if metadata and metadata.status in _GATED_STATUSES:
            settings.application_status = metadata.status
        elif not metadata:
            # No license = gated access for self-hosted EE
            settings.application_status = ApplicationStatus.GATED_ACCESS
    except RedisError as e:
        logger.warning(f"Failed to check license metadata for settings: {e}")

    return settings
@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal

import requests
import stripe

from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()


def fetch_stripe_checkout_session(tenant_id: str) -> str:
def fetch_stripe_checkout_session(
    tenant_id: str,
    billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
    params = {"tenant_id": tenant_id}
    response = requests.post(url, headers=headers, params=params)
    payload = {
        "tenant_id": tenant_id,
        "billing_period": billing_period,
    }
    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["sessionId"]

@@ -70,24 +76,46 @@ def fetch_billing_information(
    return BillingInformation(**response_data)


def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
    """
    Fetch a Stripe customer portal session URL from the control plane.
    NOTE: This is currently only used for multi-tenant (cloud) deployments.
    Self-hosted proxy endpoints will be added in a future phase.
    """
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
    payload = {"tenant_id": tenant_id}
    if return_url:
        payload["return_url"] = return_url
    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()
    return response.json()["url"]


def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
    """
    Send a request to the control service to register the number of users for a tenant.
    Update the number of seats for a tenant's subscription.
    Preserves the existing price (monthly, annual, or grandfathered).
    """

    if not STRIPE_PRICE_ID:
        raise Exception("STRIPE_PRICE_ID is not set")

    response = fetch_tenant_stripe_information(tenant_id)
    stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))

    subscription = stripe.Subscription.retrieve(stripe_subscription_id)
    subscription_item = subscription["items"]["data"][0]

    # Use existing price to preserve the customer's current plan
    current_price_id = subscription_item.price.id

    updated_subscription = stripe.Subscription.modify(
        stripe_subscription_id,
        items=[
            {
                "id": subscription["items"]["data"][0].id,
                "price": STRIPE_PRICE_ID,
                "id": subscription_item.id,
                "price": current_price_id,
                "quantity": number_of_users,
            }
        ],
@@ -1,33 +1,41 @@
import stripe
import asyncio

import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()

router = APIRouter(prefix="/tenants")

# Cache for Stripe publishable key to avoid hitting S3 on every request
_stripe_publishable_key_cache: str | None = None
_stripe_key_lock = asyncio.Lock()


@router.post("/product-gating")
def gate_product(
@@ -82,21 +90,17 @@ async def billing_information(
async def create_customer_portal_session(
    _: User = Depends(current_admin_user),
) -> dict:
    """
    Create a Stripe customer portal session via the control plane.
    NOTE: This is currently only used for multi-tenant (cloud) deployments.
    Self-hosted proxy endpoints will be added in a future phase.
    """
    tenant_id = get_current_tenant_id()
    return_url = f"{WEB_DOMAIN}/admin/billing"

    try:
        stripe_info = fetch_tenant_stripe_information(tenant_id)
        stripe_customer_id = stripe_info.get("stripe_customer_id")
        if not stripe_customer_id:
            raise HTTPException(status_code=400, detail="Stripe customer ID not found")
        logger.info(stripe_customer_id)

        portal_session = stripe.billing_portal.Session.create(
            customer=stripe_customer_id,
            return_url=f"{WEB_DOMAIN}/admin/billing",
        )
        logger.info(portal_session)
        return {"url": portal_session.url}
        portal_url = fetch_customer_portal_session(tenant_id, return_url)
        return {"url": portal_url}
    except Exception as e:
        logger.exception("Failed to create customer portal session")
        raise HTTPException(status_code=500, detail=str(e))
@@ -104,15 +108,82 @@ async def create_customer_portal_session(

@router.post("/create-subscription-session")
async def create_subscription_session(
    request: CreateSubscriptionSessionRequest | None = None,
    _: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
    try:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        if not tenant_id:
            raise HTTPException(status_code=400, detail="Tenant ID not found")
        session_id = fetch_stripe_checkout_session(tenant_id)

        billing_period = request.billing_period if request else "monthly"
        session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
        return SubscriptionSessionResponse(sessionId=session_id)

    except Exception as e:
        logger.exception("Failed to create resubscription session")
        logger.exception("Failed to create subscription session")
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/stripe-publishable-key")
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
    """
    Fetch the Stripe publishable key.
    Priority: env var override (for testing) > S3 bucket (production).
    This endpoint is public (no auth required) since publishable keys are safe to expose.
    The key is cached in memory to avoid hitting S3 on every request.
    """
    global _stripe_publishable_key_cache

    # Fast path: return cached value without lock
    if _stripe_publishable_key_cache:
        return StripePublishableKeyResponse(
            publishable_key=_stripe_publishable_key_cache
        )

    # Use lock to prevent concurrent S3 requests
    async with _stripe_key_lock:
        # Double-check after acquiring lock (another request may have populated cache)
        if _stripe_publishable_key_cache:
            return StripePublishableKeyResponse(
                publishable_key=_stripe_publishable_key_cache
            )

        # Check for env var override first (for local testing with pk_test_* keys)
        if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
            key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
            if not key.startswith("pk_"):
                raise HTTPException(
                    status_code=500,
                    detail="Invalid Stripe publishable key format",
                )
            _stripe_publishable_key_cache = key
            return StripePublishableKeyResponse(publishable_key=key)

        # Fall back to S3 bucket
        if not STRIPE_PUBLISHABLE_KEY_URL:
            raise HTTPException(
                status_code=500,
                detail="Stripe publishable key is not configured",
            )

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
                response.raise_for_status()
                key = response.text.strip()

            # Validate key format
            if not key.startswith("pk_"):
                raise HTTPException(
                    status_code=500,
                    detail="Invalid Stripe publishable key format",
                )

            _stripe_publishable_key_cache = key
            return StripePublishableKeyResponse(publishable_key=key)
        except httpx.HTTPError:
            raise HTTPException(
                status_code=500,
                detail="Failed to fetch Stripe publishable key",
            )
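The publishable-key endpoint above uses the classic async double-checked cache: a lock-free fast path, then a re-check under the lock before the expensive remote fetch. A stripped-down sketch of the same pattern with generic names (not this module's symbols):

# Generic illustration of async double-checked caching, not repository code.
import asyncio

_cache: str | None = None
_lock = asyncio.Lock()

async def get_value(fetch_remote) -> str:
    global _cache
    if _cache:  # fast path, no lock taken
        return _cache
    async with _lock:  # only one coroutine performs the fetch
        if _cache:  # another coroutine may have filled it meanwhile
            return _cache
        _cache = await fetch_remote()
        return _cache

The second check inside the lock is what prevents every waiting request from re-fetching once the first one finishes.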
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Literal

from pydantic import BaseModel

@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
    sessionId: str


class CreateSubscriptionSessionRequest(BaseModel):
    """Request to create a subscription checkout session."""

    billing_period: Literal["monthly", "annual"] = "monthly"


class TenantByDomainResponse(BaseModel):
    tenant_id: str
    number_of_users: int
@@ -98,3 +105,7 @@ class PendingUserSnapshot(BaseModel):

class ApproveUserRequest(BaseModel):
    email: str


class StripePublishableKeyResponse(BaseModel):
    publishable_key: str
@@ -65,3 +65,9 @@ def get_gated_tenants() -> set[str]:
    redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
    return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}


def is_tenant_gated(tenant_id: str) -> bool:
    """Fast O(1) check if tenant is in gated set (multi-tenant only)."""
    redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))
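is_tenant_gated only works if the gated_tenants Redis set is kept current by the write side (store_product_gating / overwrite_full_gated_set in this diff). A hedged sketch of that write path using plain redis-py calls (the helper name is illustrative):

# Hypothetical illustration of maintaining the gated-tenant set, not repository code.
from redis import Redis

GATED_TENANTS_KEY = "gated_tenants"

def set_tenant_gated(redis_client: Redis, tenant_id: str, gated: bool) -> None:
    # SADD/SREM keep the membership check O(1) for the request-path middleware.
    if gated:
        redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
    else:
        redis_client.srem(GATED_TENANTS_KEY, tenant_id)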
@@ -97,10 +97,14 @@ def get_access_for_documents(


def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
    """Returns a list of ACL entries that the user has access to. This is meant to be
    used downstream to filter out documents that the user does not have access to. The
    user should have access to a document if at least one entry in the document's ACL
    matches one entry in the returned set.
    """Returns a list of ACL entries that the user has access to.

    This is meant to be used downstream to filter out documents that the user
    does not have access to. The user should have access to a document if at
    least one entry in the document's ACL matches one entry in the returned set.

    NOTE: These strings must be formatted in the same way as the output of
    DocumentAccess::to_acl.
    """
    if user:
        return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
        )

    def to_acl(self) -> set[str]:
        # the acl's emitted by this function are prefixed by type
        # to get the native objects, access the member variables directly
        """Converts the access state to a set of formatted ACL strings.

        NOTE: When querying for documents, the supplied ACL filter strings must
        be formatted in the same way as this function.
        """
        acl_set: set[str] = set()
        for user_email in self.user_emails:
            if user_email:
@@ -11,6 +11,7 @@ from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
@@ -1456,6 +1457,9 @@ def get_default_admin_user_emails_() -> list[str]:


STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"


class OAuth2AuthorizeResponse(BaseModel):
@@ -1463,13 +1467,19 @@ class OAuth2AuthorizeResponse(BaseModel):


def generate_state_token(
    data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
    data: Dict[str, str],
    secret: SecretType,
    lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
) -> str:
    data["aud"] = STATE_TOKEN_AUDIENCE

    return generate_jwt(data, secret, lifetime_seconds)


def generate_csrf_token() -> str:
    return secrets.token_urlsafe(32)


# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
    oauth_client: BaseOAuth2,
@@ -1498,6 +1508,13 @@ def get_oauth_router(
    redirect_url: Optional[str] = None,
    associate_by_email: bool = False,
    is_verified_by_default: bool = False,
    *,
    csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
    csrf_token_cookie_path: str = "/",
    csrf_token_cookie_domain: Optional[str] = None,
    csrf_token_cookie_secure: Optional[bool] = None,
    csrf_token_cookie_httponly: bool = True,
    csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
) -> APIRouter:
    """Generate a router with the OAuth routes."""
    router = APIRouter()
@@ -1514,6 +1531,9 @@ def get_oauth_router(
        route_name=callback_route_name,
    )

    if csrf_token_cookie_secure is None:
        csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")

    @router.get(
        "/authorize",
        name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
@@ -1521,8 +1541,10 @@ def get_oauth_router(
    )
    async def authorize(
        request: Request,
        response: Response,
        redirect: bool = Query(False),
        scopes: List[str] = Query(None),
    ) -> OAuth2AuthorizeResponse:
    ) -> Response | OAuth2AuthorizeResponse:
        referral_source = request.cookies.get("referral_source", None)

        if redirect_url is not None:
@@ -1532,9 +1554,11 @@ def get_oauth_router(

        next_url = request.query_params.get("next", "/")

        csrf_token = generate_csrf_token()
        state_data: Dict[str, str] = {
            "next_url": next_url,
            "referral_source": referral_source or "default_referral",
            CSRF_TOKEN_KEY: csrf_token,
        }
        state = generate_state_token(state_data, state_secret)

@@ -1551,6 +1575,31 @@ def get_oauth_router(
            authorization_url, {"access_type": "offline", "prompt": "consent"}
        )

        if redirect:
            redirect_response = RedirectResponse(authorization_url, status_code=302)
            redirect_response.set_cookie(
                key=csrf_token_cookie_name,
                value=csrf_token,
                max_age=STATE_TOKEN_LIFETIME_SECONDS,
                path=csrf_token_cookie_path,
                domain=csrf_token_cookie_domain,
                secure=csrf_token_cookie_secure,
                httponly=csrf_token_cookie_httponly,
                samesite=csrf_token_cookie_samesite,
            )
            return redirect_response

        response.set_cookie(
            key=csrf_token_cookie_name,
            value=csrf_token,
            max_age=STATE_TOKEN_LIFETIME_SECONDS,
            path=csrf_token_cookie_path,
            domain=csrf_token_cookie_domain,
            secure=csrf_token_cookie_secure,
            httponly=csrf_token_cookie_httponly,
            samesite=csrf_token_cookie_samesite,
        )

        return OAuth2AuthorizeResponse(authorization_url=authorization_url)

    @log_function_time(print_only=True)
@@ -1600,7 +1649,33 @@ def get_oauth_router(
        try:
            state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
        except jwt.DecodeError:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=getattr(
                    ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
                ),
            )
        except jwt.ExpiredSignatureError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=getattr(
                    ErrorCode,
                    "ACCESS_TOKEN_ALREADY_EXPIRED",
                    "ACCESS_TOKEN_ALREADY_EXPIRED",
                ),
            )

        cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
        state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
        if (
            not cookie_csrf_token
            or not state_csrf_token
            or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
        ):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
            )

        next_url = state_data.get("next_url", "/")
        referral_source = state_data.get("referral_source", None)
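The change above implements the double-submit pattern: the same random token travels once inside the signed OAuth state JWT and once in a browser cookie, and the callback only proceeds when both copies match. A condensed sketch of the check with generic names (not the exact helpers in this file):

# Generic illustration of the double-submit CSRF check, not repository code.
import secrets

def csrf_check(cookie_token: str | None, state_token: str | None) -> bool:
    # compare_digest avoids timing side channels when comparing secrets.
    return bool(
        cookie_token
        and state_token
        and secrets.compare_digest(cookie_token, state_token)
    )

An attacker who forges the callback URL can supply the state JWT (it is visible in the redirect) but cannot set the victim's cookie, so the comparison fails and the callback is rejected.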
@@ -26,10 +26,13 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.document_index.opensearch.client import (
    wait_for_opensearch_with_timeout,
)
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
@@ -516,15 +519,17 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
    """Waits for Vespa to become ready subject to a timeout.
    Raises WorkerShutdown if the timeout is reached."""

    if ENABLE_OPENSEARCH_FOR_ONYX:
        # TODO(andrei): Do some similar liveness checking for OpenSearch.
        return

    if not wait_for_vespa_with_timeout():
        msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
        msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
        logger.error(msg)
        raise WorkerShutdown(msg)

    if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
        if not wait_for_opensearch_with_timeout():
            msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
            logger.error(msg)
            raise WorkerShutdown(msg)


# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):
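Both wait_for_vespa_with_timeout and wait_for_opensearch_with_timeout boil down to polling a readiness endpoint until a deadline passes. A hedged sketch of that shape (the URL, interval, and timeout values are illustrative, not the project's actual settings):

# Hypothetical readiness probe loop, not repository code.
import time
import httpx

def wait_for_service(health_url: str, timeout_s: float = 300, interval_s: float = 5) -> bool:
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if httpx.get(health_url, timeout=5).status_code == 200:
                return True
        except httpx.HTTPError:
            pass  # service not up yet; keep polling
        time.sleep(interval_s)
    return False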
@@ -87,7 +87,7 @@ from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
@@ -1436,7 +1436,7 @@ def _docprocessing_task(
        callback=callback,
    )

    document_index = get_default_document_index(
    document_indices = get_all_document_indices(
        index_attempt.search_settings,
        None,
        httpx_client=HttpxPool.get("vespa"),
@@ -1473,7 +1473,7 @@ def _docprocessing_task(
    # real work happens here!
    index_pipeline_result = run_indexing_pipeline(
        embedder=embedding_model,
        document_index=document_index,
        document_indices=document_indices,
        ignore_time_skip=True,  # Documents are already filtered during extraction
        db_session=db_session,
        tenant_id=tenant_id,
@@ -25,7 +25,7 @@ from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.relationships import delete_document_references_from_kg
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_pool import get_redis_client
@@ -97,13 +97,17 @@ def document_by_cc_pair_cleanup_task(
        action = "skip"

        active_search_settings = get_active_search_settings(db_session)
        doc_index = get_default_document_index(
        # This flow is for updates and deletion so we get all indices.
        document_indices = get_all_document_indices(
            active_search_settings.primary,
            active_search_settings.secondary,
            httpx_client=HttpxPool.get("vespa"),
        )

        retry_index = RetryDocumentIndex(doc_index)
        retry_document_indices: list[RetryDocumentIndex] = [
            RetryDocumentIndex(document_index)
            for document_index in document_indices
        ]

        count = get_document_connector_count(db_session, document_id)
        if count == 1:
@@ -113,11 +117,12 @@ def document_by_cc_pair_cleanup_task(

            chunk_count = fetch_chunk_count_for_document(document_id, db_session)

            _ = retry_index.delete_single(
                document_id,
                tenant_id=tenant_id,
                chunk_count=chunk_count,
            )
            for retry_document_index in retry_document_indices:
                _ = retry_document_index.delete_single(
                    document_id,
                    tenant_id=tenant_id,
                    chunk_count=chunk_count,
                )

            delete_document_references_from_kg(
                db_session=db_session,
@@ -155,14 +160,18 @@ def document_by_cc_pair_cleanup_task(
                hidden=doc.hidden,
            )

            # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
            retry_index.update_single(
                document_id,
                tenant_id=tenant_id,
                chunk_count=doc.chunk_count,
                fields=fields,
                user_fields=None,
            )
            for retry_document_index in retry_document_indices:
                # TODO(andrei): Previously there was a comment here saying
                # it was ok if a doc did not exist in the document index. I
                # don't agree with that claim, so keep an eye on this task
                # to see if this raises.
                retry_document_index.update_single(
                    document_id,
                    tenant_id=tenant_id,
                    chunk_count=doc.chunk_count,
                    fields=fields,
                    user_fields=None,
                )

            # there are still other cc_pair references to the doc, so just resync to Vespa
            delete_document_by_connector_credential_pair__no_commit(
@@ -12,6 +12,7 @@ from retry import retry
from sqlalchemy import select

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -32,7 +35,7 @@ from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.db.search_settings import get_active_search_settings
from onyx.db.search_settings import get_active_search_settings_list
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.file_store.file_store import get_default_file_store
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
    return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"


def _user_file_queued_key(user_file_id: str | UUID) -> str:
    """Key that exists while a process_single_user_file task is sitting in the queue.

    The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
    before enqueuing and the worker deletes it as its first action. This prevents
    the beat from adding duplicate tasks for files that already have a live task
    in flight.
    """
    return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"


def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
    return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
    """Scan for user files with PROCESSING status and enqueue per-file tasks.

    Uses direct Redis locks to avoid overlapping runs.
    Three mechanisms prevent queue runaway:

    1. **Queue depth backpressure** – if the broker queue already has more than
       USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
       entirely. Workers are clearly behind; adding more tasks would only make
       the backlog worse.

    2. **Per-file queued guard** – before enqueuing a task we set a short-lived
       Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
       already exists the file already has a live task in the queue, so we skip
       it. The worker deletes the key the moment it picks up the task so the
       next beat cycle can re-enqueue if the file is still PROCESSING.

    3. **Task expiry** – every enqueued task carries an `expires` value equal to
       CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
       the queue after that deadline, Celery discards it without touching the DB.
       This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
       Redis restart), stale tasks evict themselves rather than piling up forever.
    """
    task_logger.info("check_user_file_processing - Starting")
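Protection 2 in the docstring is the standard "set-if-absent with TTL" idiom. A small sketch of the guard life-cycle, with illustrative key prefix and enqueue call (the real code uses _user_file_queued_key and self.app.send_task):

# Illustrative sketch of the per-file queued guard, not repository code.
from redis import Redis

def try_enqueue(redis_client: Redis, user_file_id: str, ttl_s: int) -> bool:
    queued_key = f"userfilequeued:{user_file_id}"  # assumed prefix, for illustration only
    # SET NX EX is atomic: only the first caller within the TTL window wins.
    if not redis_client.set(queued_key, 1, ex=ttl_s, nx=True):
        return False  # a live task is already queued for this file
    try:
        send_task_to_broker(user_file_id)  # hypothetical enqueue call
    except Exception:
        redis_client.delete(queued_key)  # let the next beat cycle retry
        raise
    return True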
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
        return None

    enqueued = 0
    skipped_guard = 0
    try:
        # --- Protection 1: queue depth backpressure ---
        r_celery = self.app.broker_connection().channel().client  # type: ignore
        queue_len = celery_get_queue_length(
            OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
        )
        if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
            task_logger.warning(
                f"check_user_file_processing - Queue depth {queue_len} exceeds "
                f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
                f"tenant={tenant_id}"
            )
            return None

        with get_session_with_current_tenant() as db_session:
            user_file_ids = (
                db_session.execute(
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
            )

        for user_file_id in user_file_ids:
            self.app.send_task(
                OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
                kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
                queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
                priority=OnyxCeleryPriority.HIGH,
            # --- Protection 2: per-file queued guard ---
            queued_key = _user_file_queued_key(user_file_id)
            guard_set = redis_client.set(
                queued_key,
                1,
                ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
                nx=True,
            )
            if not guard_set:
                skipped_guard += 1
                continue

            # --- Protection 3: task expiry ---
            # If task submission fails, clear the guard immediately so the
            # next beat cycle can retry enqueuing this file.
            try:
                self.app.send_task(
                    OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
                    kwargs={
                        "user_file_id": str(user_file_id),
                        "tenant_id": tenant_id,
                    },
                    queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
                    priority=OnyxCeleryPriority.HIGH,
                    expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
                )
            except Exception:
                redis_client.delete(queued_key)
                raise
            enqueued += 1

    finally:
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
            lock.release()

    task_logger.info(
        f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
        f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
        f"tasks for tenant={tenant_id}"
    )
    return None
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
    start = time.monotonic()

    redis_client = get_redis_client(tenant_id=tenant_id)

    # Clear the "queued" guard set by the beat generator so that the next beat
    # cycle can re-enqueue this file if it is still in PROCESSING state after
    # this task completes or fails.
    redis_client.delete(_user_file_queued_key(user_file_id))

    file_lock: RedisLock = redis_client.lock(
        _user_file_lock_key(user_file_id),
        timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
@@ -244,7 +319,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
                search_settings=current_search_settings,
            )

            document_index = get_default_document_index(
            # This flow is for indexing so we get all indices.
            document_indices = get_all_document_indices(
                current_search_settings,
                None,
                httpx_client=HttpxPool.get("vespa"),
@@ -258,7 +334,7 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
            # real work happens here!
            index_pipeline_result = run_indexing_pipeline(
                embedder=embedding_model,
                document_index=document_index,
                document_indices=document_indices,
                ignore_time_skip=True,
                db_session=db_session,
                tenant_id=tenant_id,
@@ -412,12 +488,16 @@ def process_single_user_file_delete(
            httpx_init_vespa_pool(20)

            active_search_settings = get_active_search_settings(db_session)
            document_index = get_default_document_index(
            # This flow is for deletion so we get all indices.
            document_indices = get_all_document_indices(
                search_settings=active_search_settings.primary,
                secondary_search_settings=active_search_settings.secondary,
                httpx_client=HttpxPool.get("vespa"),
            )
            retry_index = RetryDocumentIndex(document_index)
            retry_document_indices: list[RetryDocumentIndex] = [
                RetryDocumentIndex(document_index)
                for document_index in document_indices
            ]
            index_name = active_search_settings.primary.index_name
            selection = f"{index_name}.document_id=='{user_file_id}'"

@@ -438,11 +518,12 @@ def process_single_user_file_delete(
            else:
                chunk_count = user_file.chunk_count

            retry_index.delete_single(
                doc_id=user_file_id,
                tenant_id=tenant_id,
                chunk_count=chunk_count,
            )
            for retry_document_index in retry_document_indices:
                retry_document_index.delete_single(
                    doc_id=user_file_id,
                    tenant_id=tenant_id,
                    chunk_count=chunk_count,
                )

            # 2) Delete the user-uploaded file content from filestore (blob + metadata)
            file_store = get_default_file_store()
@@ -564,12 +645,16 @@ def process_single_user_file_project_sync(
            httpx_init_vespa_pool(20)

            active_search_settings = get_active_search_settings(db_session)
            doc_index = get_default_document_index(
            # This flow is for updates so we get all indices.
            document_indices = get_all_document_indices(
                search_settings=active_search_settings.primary,
                secondary_search_settings=active_search_settings.secondary,
                httpx_client=HttpxPool.get("vespa"),
            )
            retry_index = RetryDocumentIndex(doc_index)
            retry_document_indices: list[RetryDocumentIndex] = [
                RetryDocumentIndex(document_index)
                for document_index in document_indices
            ]

            user_file = db_session.get(UserFile, _as_uuid(user_file_id))
            if not user_file:
@@ -579,13 +664,14 @@ def process_single_user_file_project_sync(
                return None

            project_ids = [project.id for project in user_file.projects]
            retry_index.update_single(
                doc_id=str(user_file.id),
                tenant_id=tenant_id,
                chunk_count=user_file.chunk_count,
                fields=None,
                user_fields=VespaDocumentUserFields(user_projects=project_ids),
            )
            for retry_document_index in retry_document_indices:
                retry_document_index.update_single(
                    doc_id=str(user_file.id),
                    tenant_id=tenant_id,
                    chunk_count=user_file.chunk_count,
                    fields=None,
                    user_fields=VespaDocumentUserFields(user_projects=project_ids),
                )

            task_logger.info(
                f"process_single_user_file_project_sync - User file id={user_file_id}"
@@ -21,6 +21,8 @@ from onyx.utils.logger import setup_logger
DOCUMENT_SYNC_PREFIX = "documentsync"
DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence"
DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset"
FENCE_TTL = 7 * 24 * 60 * 60  # 7 days - defensive TTL to prevent memory leaks
TASKSET_TTL = FENCE_TTL

logger = setup_logger()

@@ -50,7 +52,7 @@ def set_document_sync_fence(r: Redis, payload: int | None) -> None:
        r.delete(DOCUMENT_SYNC_FENCE_KEY)
        return

    r.set(DOCUMENT_SYNC_FENCE_KEY, payload)
    r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL)
    r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY)


@@ -110,6 +112,7 @@ def generate_document_sync_tasks(

    # Add to the tracking taskset in Redis BEFORE creating the celery task
    r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id)
    r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL)

    # Create the Celery task
    celery_app.send_task(
@@ -49,7 +49,7 @@ from onyx.db.search_settings import get_active_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.factory import get_all_document_indices
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -70,6 +70,8 @@ logger = setup_logger()

# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
# TODO(andrei): Rename all these kinds of functions from *vespa* to a more
# generic *document_index*.
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
    ignore_result=True,
@@ -465,13 +467,17 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
    try:
        with get_session_with_current_tenant() as db_session:
            active_search_settings = get_active_search_settings(db_session)
            doc_index = get_default_document_index(
            # This flow is for updates so we get all indices.
            document_indices = get_all_document_indices(
                search_settings=active_search_settings.primary,
                secondary_search_settings=active_search_settings.secondary,
                httpx_client=HttpxPool.get("vespa"),
            )

            retry_index = RetryDocumentIndex(doc_index)
            retry_document_indices: list[RetryDocumentIndex] = [
                RetryDocumentIndex(document_index)
                for document_index in document_indices
            ]

            doc = get_document(document_id, db_session)
            if not doc:
@@ -500,14 +506,18 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
                # aggregated_boost_factor=doc.aggregated_boost_factor,
            )

            # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
            retry_index.update_single(
                document_id,
                tenant_id=tenant_id,
                chunk_count=doc.chunk_count,
                fields=fields,
                user_fields=None,
            )
            for retry_document_index in retry_document_indices:
                # TODO(andrei): Previously there was a comment here saying
                # it was ok if a doc did not exist in the document index. I
                # don't agree with that claim, so keep an eye on this task
                # to see if this raises.
                retry_document_index.update_single(
                    document_id,
                    tenant_id=tenant_id,
                    chunk_count=doc.chunk_count,
                    fields=fields,
                    user_fields=None,
                )

            # update db last. Worst case = we crash right before this and
            # the sync might repeat again later
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -15,6 +16,11 @@ from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
# Full key: (document_id, chunk_ind, match_highlights)
|
||||
SearchDocKey = str | tuple[str, int, tuple[str, ...]]
|
||||
|
||||
|
||||
class ChatStateContainer:
|
||||
"""Container for accumulating state during LLM loop execution.
|
||||
@@ -40,6 +46,10 @@ class ChatStateContainer:
|
||||
# True if this turn is a clarification question (deep research flow)
|
||||
self.is_clarification: bool = False
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
# Search doc collection - maps dedup key to SearchDoc for all docs from tool calls
|
||||
self._all_search_docs: dict[SearchDocKey, SearchDoc] = {}
|
||||
# Track which citation numbers were actually emitted during streaming
|
||||
self._emitted_citations: set[int] = set()
|
||||
|
||||
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
|
||||
"""Add a tool call to the accumulated state."""
|
||||
@@ -91,6 +101,54 @@ class ChatStateContainer:
|
||||
with self._lock:
|
||||
return self.is_clarification
|
||||
|
||||
@staticmethod
|
||||
def create_search_doc_key(
|
||||
search_doc: SearchDoc, use_simple_key: bool = True
|
||||
) -> SearchDocKey:
|
||||
"""Create a unique key for a SearchDoc for deduplication.
|
||||
|
||||
Args:
|
||||
search_doc: The SearchDoc to create a key for
|
||||
use_simple_key: If True (default), use only document_id for deduplication.
|
||||
If False, include chunk_ind and match_highlights so that the same
|
||||
document/chunk with different highlights are stored separately.
|
||||
"""
|
||||
if use_simple_key:
|
||||
return search_doc.document_id
|
||||
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
|
||||
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
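A quick sketch of how the two key modes differ, using a tiny stand-in for SearchDoc (illustration only, not the real pydantic model):

```python
from dataclasses import dataclass

@dataclass
class _DocStub:  # minimal stand-in for SearchDoc, for illustration only
    document_id: str
    chunk_ind: int
    match_highlights: list[str]

doc_a = _DocStub("doc-1", 0, ["alpha"])
doc_b = _DocStub("doc-1", 0, ["beta"])

# Simple key: both rows collapse to a single entry.
print({d.document_id for d in (doc_a, doc_b)})
# {'doc-1'}

# Full key: the two highlight variants stay distinct.
print({(d.document_id, d.chunk_ind, tuple(sorted(d.match_highlights))) for d in (doc_a, doc_b)})
# {('doc-1', 0, ('alpha',)), ('doc-1', 0, ('beta',))}
```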
|
||||
|
||||
def add_search_docs(
|
||||
self, search_docs: list[SearchDoc], use_simple_key: bool = True
|
||||
) -> None:
|
||||
"""Add search docs to the accumulated collection with deduplication.
|
||||
|
||||
Args:
|
||||
search_docs: List of SearchDoc objects to add
|
||||
use_simple_key: If True (default), deduplicate by document_id only.
|
||||
If False, deduplicate by document_id + chunk_ind + match_highlights.
|
||||
"""
|
||||
with self._lock:
|
||||
for doc in search_docs:
|
||||
key = self.create_search_doc_key(doc, use_simple_key)
|
||||
if key not in self._all_search_docs:
|
||||
self._all_search_docs[key] = doc
|
||||
|
||||
def get_all_search_docs(self) -> dict[SearchDocKey, SearchDoc]:
|
||||
"""Thread-safe getter for all accumulated search docs (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._all_search_docs.copy()
|
||||
|
||||
def add_emitted_citation(self, citation_num: int) -> None:
|
||||
"""Add a citation number that was actually emitted during streaming."""
|
||||
with self._lock:
|
||||
self._emitted_citations.add(citation_num)
|
||||
|
||||
def get_emitted_citations(self) -> set[int]:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
|
||||
@@ -53,6 +53,50 @@ def update_citation_processor_from_tool_response(
|
||||
citation_processor.update_citation_mapping(citation_to_doc)
|
||||
|
||||
|
||||
def extract_citation_order_from_text(text: str) -> list[int]:
|
||||
"""Extract citation numbers from text in order of first appearance.
|
||||
|
||||
Parses citation patterns like [1], [1, 2], [[1]], 【1】 etc. and returns
|
||||
the citation numbers in the order they first appear in the text.
|
||||
|
||||
Args:
|
||||
text: The text containing citations
|
||||
|
||||
Returns:
|
||||
List of citation numbers in order of first appearance (no duplicates)
|
||||
"""
|
||||
# Same pattern used in collapse_citations and DynamicCitationProcessor
|
||||
# Group 2 captures the number in double bracket format: [[1]], 【【1】】
|
||||
# Group 4 captures the numbers in single bracket format: [1], [1, 2]
|
||||
citation_pattern = re.compile(
|
||||
r"([\[【[]{2}(\d+)[\]】]]{2})|([\[【[]([\d]+(?: *, *\d+)*)[\]】]])"
|
||||
)
|
||||
seen: set[int] = set()
|
||||
order: list[int] = []
|
||||
|
||||
for match in citation_pattern.finditer(text):
|
||||
# Group 2 is for double bracket single number, group 4 is for single bracket
|
||||
if match.group(2):
|
||||
nums_str = match.group(2)
|
||||
elif match.group(4):
|
||||
nums_str = match.group(4)
|
||||
else:
|
||||
continue
|
||||
|
||||
for num_str in nums_str.split(","):
|
||||
num_str = num_str.strip()
|
||||
if num_str:
|
||||
try:
|
||||
num = int(num_str)
|
||||
if num not in seen:
|
||||
seen.add(num)
|
||||
order.append(num)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return order
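As a rough illustration of the first-appearance ordering and de-duplication, here is a minimal sketch using a simplified pattern (plain single-bracket citations only, not the double-bracket or fullwidth forms handled above):

```python
import re

def citation_order_sketch(text: str) -> list[int]:
    seen: set[int] = set()
    order: list[int] = []
    # Simplified pattern: only [1] and [1, 2] style groups.
    for match in re.finditer(r"\[(\d+(?: *, *\d+)*)\]", text):
        for num_str in match.group(1).split(","):
            num = int(num_str.strip())
            if num not in seen:
                seen.add(num)
                order.append(num)
    return order

print(citation_order_sketch("Revenue grew [2], see also [1, 2] and [3]."))
# [2, 1, 3]
```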
|
||||
|
||||
|
||||
def collapse_citations(
|
||||
answer_text: str,
|
||||
existing_citation_mapping: CitationMapping,
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.citation_processor import CitationMode
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
@@ -38,11 +39,13 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_runner import run_tool_calls
|
||||
from onyx.tracing.framework.create import trace
|
||||
@@ -51,6 +54,78 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
fallback_extraction_attempted: bool,
|
||||
tool_defs: list[dict],
|
||||
turn_index: int,
|
||||
) -> tuple[LlmStepResult, bool]:
|
||||
"""Attempt to extract tool calls from response text as a fallback.
|
||||
|
||||
This is a last-resort fallback for low-quality LLMs or those that don't have
|
||||
tool calling from the serving layer. Also triggers if there's reasoning but
|
||||
no answer and no tool calls.
|
||||
|
||||
Args:
|
||||
llm_step_result: The result from the LLM step
|
||||
tool_choice: The tool choice option used for this step
|
||||
fallback_extraction_attempted: Whether fallback extraction was already attempted
|
||||
tool_defs: List of tool definitions
|
||||
turn_index: The current turn index for placement
|
||||
|
||||
Returns:
|
||||
Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call)
|
||||
"""
|
||||
if fallback_extraction_attempted:
|
||||
return llm_step_result, False
|
||||
|
||||
no_tool_calls = (
|
||||
not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0
|
||||
)
|
||||
reasoning_but_no_answer_or_tools = (
|
||||
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
|
||||
)
|
||||
should_try_fallback = (
|
||||
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
|
||||
) or reasoning_but_no_answer_or_tools
|
||||
|
||||
if not should_try_fallback:
|
||||
return llm_step_result, False
|
||||
|
||||
# Try to extract from answer first, then fall back to reasoning
|
||||
extracted_tool_calls: list[ToolCallKickoff] = []
|
||||
if llm_step_result.answer:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if not extracted_tool_calls and llm_step_result.reasoning:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.reasoning,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=llm_step_result.reasoning,
|
||||
answer=llm_step_result.answer,
|
||||
tool_calls=extracted_tool_calls,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
return llm_step_result, True
|
||||
|
||||
|
||||
# Hardcoded opinionated value; might break down to something like:
|
||||
# Cycle 1: Calls web_search for something
|
||||
# Cycle 2: Calls open_url for some results
|
||||
@@ -352,6 +427,7 @@ def run_llm_loop(
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
@@ -378,12 +454,16 @@ def run_llm_loop(
|
||||
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
# now that project files are loaded in.
|
||||
if persona and persona.replace_base_system_prompt and persona.system_prompt:
|
||||
if persona and persona.replace_base_system_prompt:
|
||||
# Handles the case where user has checked off the "Replace base system prompt" checkbox
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=persona.system_prompt,
|
||||
token_count=token_counter(persona.system_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
system_prompt = (
|
||||
ChatMessageSimple(
|
||||
message=persona.system_prompt,
|
||||
token_count=token_counter(persona.system_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
if persona.system_prompt
|
||||
else None
|
||||
)
|
||||
custom_agent_prompt_msg = None
|
||||
else:
|
||||
@@ -470,10 +550,11 @@ def run_llm_loop(
|
||||
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
# It also pre-processes the tool calls in preparation for running them
|
||||
tool_defs = [tool.tool_definition() for tool in final_tools]
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[tool.tool_definition() for tool in final_tools],
|
||||
tool_definitions=tool_defs,
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
@@ -488,6 +569,19 @@ def run_llm_loop(
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
# Fallback extraction for LLMs that don't support tool calling natively or are lower quality
|
||||
# and might incorrectly output tool calls in other channels
|
||||
llm_step_result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=tool_choice,
|
||||
fallback_extraction_attempted=fallback_extraction_attempted,
|
||||
tool_defs=tool_defs,
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
)
|
||||
if attempted:
|
||||
# To prevent the case of excessive looping with bad models, we only allow one fallback attempt
|
||||
fallback_extraction_attempted = True
|
||||
|
||||
# Save citation mapping after each LLM step for incremental state updates
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
@@ -523,6 +617,7 @@ def run_llm_loop(
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
@@ -561,8 +656,15 @@ def run_llm_loop(
|
||||
|
||||
# Extract search_docs if this is a search tool response
|
||||
search_docs = None
|
||||
displayed_docs = None
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
search_docs = tool_response.rich_response.search_docs
|
||||
displayed_docs = tool_response.rich_response.displayed_docs
|
||||
|
||||
# Add ALL search docs to state container for DB persistence
|
||||
if search_docs:
|
||||
state_container.add_search_docs(search_docs)
|
||||
|
||||
if gathered_documents:
|
||||
gathered_documents.extend(search_docs)
|
||||
else:
|
||||
@@ -580,6 +682,12 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
saved_response = (
|
||||
tool_response.rich_response
|
||||
if isinstance(tool_response.rich_response, str)
|
||||
else tool_response.llm_facing_response
|
||||
)
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
@@ -589,8 +697,8 @@ def run_llm_loop(
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
search_docs=search_docs,
|
||||
tool_call_response=saved_response,
|
||||
search_docs=displayed_docs or search_docs,
|
||||
generated_images=generated_images,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
@@ -645,7 +753,12 @@ def run_llm_loop(
|
||||
should_cite_documents = True
|
||||
|
||||
if not llm_step_result or not llm_step_result.answer:
|
||||
raise RuntimeError("LLM did not return an answer.")
|
||||
raise RuntimeError(
|
||||
"The LLM did not return an answer. "
|
||||
"Typically this is an issue with LLMs that do not support tool calling natively, "
|
||||
"or the model serving API is not configured correctly. "
|
||||
"This may also happen with models that are lower quality outputting invalid tool calls."
|
||||
)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -49,6 +50,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -278,6 +280,144 @@ def _extract_tool_call_kickoffs(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def extract_tool_calls_from_response_text(
|
||||
response_text: str | None,
|
||||
tool_definitions: list[dict],
|
||||
placement: Placement,
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
tool_definitions: List of tool definitions to match against
|
||||
placement: Placement information for the tool calls
|
||||
|
||||
Returns:
|
||||
List of ToolCallKickoff objects for any matched tool calls
|
||||
"""
|
||||
if not response_text or not tool_definitions:
|
||||
return []
|
||||
|
||||
# Build a map of tool names to their definitions
|
||||
tool_name_to_def: dict[str, dict] = {}
|
||||
for tool_def in tool_definitions:
|
||||
if tool_def.get("type") == "function" and "function" in tool_def:
|
||||
func_def = tool_def["function"]
|
||||
tool_name = func_def.get("name")
|
||||
if tool_name:
|
||||
tool_name_to_def[tool_name] = func_def
|
||||
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index = 0
|
||||
|
||||
for json_obj in json_objects:
|
||||
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
|
||||
if matched_tool_call:
|
||||
tool_name, tool_args = matched_tool_call
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
tab_index += 1
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
"""Try to match a JSON object to a tool definition.
|
||||
|
||||
Supports several formats:
|
||||
1. Direct tool call format: {"name": "tool_name", "arguments": {...}}
|
||||
2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}}
|
||||
3. Tool name as key: {"tool_name": {...arguments...}}
|
||||
4. Arguments matching a tool's parameter schema
|
||||
|
||||
Args:
|
||||
json_obj: The JSON object to match
|
||||
tool_name_to_def: Map of tool names to their function definitions
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_name, tool_args) if matched, None otherwise
|
||||
"""
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
if "function" in json_obj and isinstance(json_obj["function"], dict):
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
for tool_name in tool_name_to_def:
|
||||
if tool_name in json_obj:
|
||||
arguments = json_obj[tool_name]
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 4: Check if the JSON object matches a tool's parameter schema
|
||||
for tool_name, func_def in tool_name_to_def.items():
|
||||
params = func_def.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
required = params.get("required", [])
|
||||
|
||||
if not properties:
|
||||
continue
|
||||
|
||||
# Check if all required parameters are present (empty required = all optional)
|
||||
if all(req in json_obj for req in required):
|
||||
# Check if any of the tool's properties are in the JSON object
|
||||
matching_props = [prop for prop in properties if prop in json_obj]
|
||||
if matching_props:
|
||||
# Filter to only include known properties
|
||||
filtered_args = {k: v for k, v in json_obj.items() if k in properties}
|
||||
return (tool_name, filtered_args)
|
||||
|
||||
return None
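For reference, the four shapes above would all resolve to the same call, sketched here with a hypothetical `run_search` tool definition (assuming `_try_match_json_to_tool` is in scope):

```python
tool_name_to_def = {
    "run_search": {  # hypothetical tool, for illustration only
        "name": "run_search",
        "parameters": {
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    }
}

candidates = [
    {"name": "run_search", "arguments": {"query": "q3 revenue"}},                # format 1
    {"function": {"name": "run_search", "arguments": {"query": "q3 revenue"}}},  # format 2
    {"run_search": {"query": "q3 revenue"}},                                     # format 3
    {"query": "q3 revenue"},                                                     # format 4 (schema match)
]
for obj in candidates:
    print(_try_match_json_to_tool(obj, tool_name_to_def))
    # ('run_search', {'query': 'q3 revenue'}) in all four cases
```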
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
@@ -293,7 +433,7 @@ def translate_history_to_llm_format(
|
||||
|
||||
for idx, msg in enumerate(history):
|
||||
# if the message is being added to the history
|
||||
if msg.message_type in [
|
||||
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.ASSISTANT,
|
||||
@@ -720,6 +860,11 @@ def run_llm_step_pkt_generator(
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
@@ -846,6 +991,9 @@ def run_llm_step_pkt_generator(
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
|
||||
@@ -42,7 +42,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
@@ -86,10 +85,6 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -362,21 +357,20 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -744,27 +738,16 @@ def llm_loop_completion_handle(
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
citation_to_doc=state_container.citation_to_doc,
|
||||
tool_calls=state_container.tool_calls,
|
||||
all_search_docs=state_container.get_all_search_docs(),
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_message,
|
||||
is_clarification=state_container.is_clarification,
|
||||
emitted_citations=state_container.get_emitted_citations(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -28,6 +29,7 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@@ -178,8 +180,9 @@ def build_system_prompt(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
return system_prompt
|
||||
|
||||
@@ -193,6 +196,7 @@ def build_system_prompt(
|
||||
has_generate_image = any(
|
||||
isinstance(tool, ImageGenerationTool) for tool in tools
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -222,4 +226,7 @@ def build_system_prompt(
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -2,8 +2,9 @@ import json
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import SearchDocKey
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import add_search_docs_to_chat_message
|
||||
from onyx.db.chat import add_search_docs_to_tool_call
|
||||
@@ -19,22 +20,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_search_doc_key(search_doc: SearchDoc) -> tuple[str, int, tuple[str, ...]]:
|
||||
"""
|
||||
Create a unique key for a SearchDoc that accounts for different versions of the same
|
||||
document/chunk with different match_highlights.
|
||||
|
||||
Args:
|
||||
search_doc: The SearchDoc pydantic model to create a key for
|
||||
|
||||
Returns:
|
||||
A tuple of (document_id, chunk_ind, sorted match_highlights) that uniquely identifies
|
||||
this specific version of the document
|
||||
"""
|
||||
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
|
||||
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
|
||||
|
||||
|
||||
def _create_and_link_tool_calls(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
assistant_message: ChatMessage,
|
||||
@@ -154,38 +139,36 @@ def save_chat_turn(
|
||||
message_text: str,
|
||||
reasoning_tokens: str | None,
|
||||
tool_calls: list[ToolCallInfo],
|
||||
citation_docs_info: list[CitationDocInfo],
|
||||
citation_to_doc: dict[int, SearchDoc],
|
||||
all_search_docs: dict[SearchDocKey, SearchDoc],
|
||||
db_session: Session,
|
||||
assistant_message: ChatMessage,
|
||||
is_clarification: bool = False,
|
||||
emitted_citations: set[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save a chat turn by populating the assistant_message and creating related entities.
|
||||
|
||||
This function:
|
||||
1. Updates the ChatMessage with text, reasoning tokens, and token count
|
||||
2. Creates SearchDoc entries from ToolCall search_docs (for tool calls that returned documents)
|
||||
3. Collects all unique SearchDocs from all tool calls and links them to ChatMessage
|
||||
4. Builds citation mapping from citation_docs_info
|
||||
5. Links all unique SearchDocs from tool calls to the ChatMessage
|
||||
2. Creates DB SearchDoc entries from pre-deduplicated all_search_docs
|
||||
3. Builds tool_call -> search_doc mapping for displayed docs
|
||||
4. Builds citation mapping from citation_to_doc
|
||||
5. Links all unique SearchDocs to the ChatMessage
|
||||
6. Creates ToolCall entries and links SearchDocs to them
|
||||
7. Builds the citations mapping for the ChatMessage
|
||||
|
||||
Deduplication Logic:
|
||||
- SearchDocs are deduplicated using (document_id, chunk_ind, match_highlights) as the key
|
||||
- This ensures that the same document/chunk with different match_highlights (from different
|
||||
queries) are stored as separate SearchDoc entries
|
||||
- Each ToolCall and ChatMessage will map to the correct version of the SearchDoc that
|
||||
matches its specific query highlights
|
||||
|
||||
Args:
|
||||
message_text: The message content to save
|
||||
reasoning_tokens: Optional reasoning tokens for the message
|
||||
tool_calls: List of tool call information to create ToolCall entries (may include search_docs)
|
||||
citation_docs_info: List of citation document information for building citations mapping
|
||||
citation_to_doc: Mapping from citation number to SearchDoc for building citations
|
||||
all_search_docs: Pre-deduplicated search docs from ChatStateContainer
|
||||
db_session: Database session for persistence
|
||||
assistant_message: The ChatMessage object to populate (should already exist in DB)
|
||||
is_clarification: Whether this assistant message is a clarification question (deep research flow)
|
||||
emitted_citations: Set of citation numbers that were actually emitted during streaming.
|
||||
If provided, only citations in this set will be saved; others are filtered out.
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
@@ -200,53 +183,53 @@ def save_chat_turn(
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
# 2. Create SearchDoc entries from tool_calls
|
||||
# Build mapping from SearchDoc to DB SearchDoc ID
|
||||
# Use (document_id, chunk_ind, match_highlights) as key to avoid duplicates
|
||||
# while ensuring different versions with different highlights are stored separately
|
||||
search_doc_key_to_id: dict[tuple[str, int, tuple[str, ...]], int] = {}
|
||||
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
|
||||
# 2. Create DB SearchDoc entries from pre-deduplicated all_search_docs
|
||||
search_doc_key_to_id: dict[SearchDocKey, int] = {}
|
||||
for key, search_doc_py in all_search_docs.items():
|
||||
db_search_doc = create_db_search_doc(
|
||||
server_search_doc=search_doc_py,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
search_doc_key_to_id[key] = db_search_doc.id
|
||||
|
||||
# Process tool calls and their search docs
|
||||
# 3. Build tool_call -> search_doc mapping (for displayed docs in each tool call)
|
||||
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
|
||||
for tool_call_info in tool_calls:
|
||||
if tool_call_info.search_docs:
|
||||
search_doc_ids_for_tool: list[int] = []
|
||||
for search_doc_py in tool_call_info.search_docs:
|
||||
# Create a unique key for this SearchDoc version
|
||||
search_doc_key = _create_search_doc_key(search_doc_py)
|
||||
|
||||
# Check if we've already created this exact SearchDoc version
|
||||
if search_doc_key in search_doc_key_to_id:
|
||||
search_doc_ids_for_tool.append(search_doc_key_to_id[search_doc_key])
|
||||
key = ChatStateContainer.create_search_doc_key(search_doc_py)
|
||||
if key in search_doc_key_to_id:
|
||||
search_doc_ids_for_tool.append(search_doc_key_to_id[key])
|
||||
else:
|
||||
# Create new DB SearchDoc entry
|
||||
# Displayed doc not in all_search_docs - create it
|
||||
# This can happen if displayed_docs contains docs not in search_docs
|
||||
db_search_doc = create_db_search_doc(
|
||||
server_search_doc=search_doc_py,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
search_doc_key_to_id[search_doc_key] = db_search_doc.id
|
||||
search_doc_key_to_id[key] = db_search_doc.id
|
||||
search_doc_ids_for_tool.append(db_search_doc.id)
|
||||
|
||||
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
|
||||
set(search_doc_ids_for_tool)
|
||||
)
|
||||
|
||||
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
|
||||
# Use a set to deduplicate by ID (since we've already deduplicated by key above)
|
||||
all_search_doc_ids_set: set[int] = set()
|
||||
for search_doc_ids in tool_call_to_search_doc_ids.values():
|
||||
all_search_doc_ids_set.update(search_doc_ids)
|
||||
# Collect all search doc IDs for ChatMessage linking
|
||||
all_search_doc_ids_set: set[int] = set(search_doc_key_to_id.values())
|
||||
|
||||
# 4. Build citation mapping from citation_docs_info
|
||||
# 4. Build a citation mapping from the citation number to the saved DB SearchDoc ID
|
||||
# Only include citations that were actually emitted during streaming
|
||||
citation_number_to_search_doc_id: dict[int, int] = {}
|
||||
|
||||
for citation_doc_info in citation_docs_info:
|
||||
# Extract SearchDoc pydantic model
|
||||
search_doc_py = citation_doc_info.search_doc
|
||||
for citation_num, search_doc_py in citation_to_doc.items():
|
||||
# Skip citations that weren't actually emitted (if emitted_citations is provided)
|
||||
if emitted_citations is not None and citation_num not in emitted_citations:
|
||||
continue
|
||||
|
||||
# Create the unique key for this SearchDoc version
|
||||
search_doc_key = _create_search_doc_key(search_doc_py)
|
||||
search_doc_key = ChatStateContainer.create_search_doc_key(search_doc_py)
|
||||
|
||||
# Get the search doc ID (should already exist from processing tool_calls)
|
||||
if search_doc_key in search_doc_key_to_id:
|
||||
@@ -283,10 +266,7 @@ def save_chat_turn(
|
||||
all_search_doc_ids_set.add(db_search_doc_id)
|
||||
|
||||
# Build mapping from citation number to search doc ID
|
||||
if citation_doc_info.citation_number is not None:
|
||||
citation_number_to_search_doc_id[citation_doc_info.citation_number] = (
|
||||
db_search_doc_id
|
||||
)
|
||||
citation_number_to_search_doc_id[citation_num] = db_search_doc_id
|
||||
|
||||
# 5. Link all unique SearchDocs (from both tool calls and citations) to ChatMessage
|
||||
final_search_doc_ids: list[int] = list(all_search_doc_ids_set)
|
||||
@@ -306,23 +286,10 @@ def save_chat_turn(
|
||||
tool_call_to_search_doc_ids=tool_call_to_search_doc_ids,
|
||||
)
|
||||
|
||||
# 7. Build citations mapping from citation_docs_info
|
||||
# Any citation_doc_info with a citation_number appeared in the text and should be mapped
|
||||
citations: dict[int, int] = {}
|
||||
for citation_doc_info in citation_docs_info:
|
||||
if citation_doc_info.citation_number is not None:
|
||||
search_doc_id = citation_number_to_search_doc_id.get(
|
||||
citation_doc_info.citation_number
|
||||
)
|
||||
if search_doc_id is not None:
|
||||
citations[citation_doc_info.citation_number] = search_doc_id
|
||||
else:
|
||||
logger.warning(
|
||||
f"Citation number {citation_doc_info.citation_number} found in citation_docs_info "
|
||||
f"but no matching search doc ID in mapping"
|
||||
)
|
||||
|
||||
assistant_message.citations = citations if citations else None
|
||||
# 7. Build citations mapping - use the mapping we already built in step 4
|
||||
assistant_message.citations = (
|
||||
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
|
||||
)
|
||||
|
||||
# Finally save the messages, tool calls, and docs
|
||||
db_session.commit()
|
||||
|
||||
@@ -208,8 +208,19 @@ OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 920
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
|
||||
ENABLE_OPENSEARCH_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Given that the "base" config above is true, this enables whether we want to
|
||||
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
|
||||
# in the event we see issues with OpenSearch retrieval in our dev environments.
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
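A small sketch of how the two flags are intended to combine (mirroring the logic above; not a separate configuration surface):

```python
import os

def opensearch_flags() -> tuple[bool, bool]:
    # Retrieval can only be enabled when indexing is enabled.
    indexing = os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
    retrieval = indexing and (
        os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
    )
    return indexing, retrieval

# Setting only the retrieval flag has no effect:
os.environ["ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX"] = "true"
print(opensearch_flags())  # (False, False) -> index and retrieve from Vespa only
```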
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
@@ -738,6 +749,10 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
LOG_ONYX_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
|
||||
PROMPT_CACHE_CHAT_HISTORY = (
|
||||
os.environ.get("PROMPT_CACHE_CHAT_HISTORY", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
@@ -1011,3 +1026,19 @@ INSTANCE_TYPE = (
|
||||
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
|
||||
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
|
||||
)
|
||||
|
||||
|
||||
## Discord Bot Configuration
|
||||
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
|
||||
## Stripe Configuration
|
||||
# URL to fetch the Stripe publishable key from a public S3 bucket.
|
||||
# Publishable keys are safe to expose publicly - they can only initialize
|
||||
# Stripe.js and tokenize payment info, not make charges or access data.
|
||||
STRIPE_PUBLISHABLE_KEY_URL = (
|
||||
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
|
||||
)
|
||||
# Override for local testing with Stripe test keys (pk_test_*)
|
||||
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
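The intended precedence (local override first, otherwise the public S3 copy) might be resolved roughly as below; the helper is a sketch, not code from the repository:

```python
import os
import requests

STRIPE_PUBLISHABLE_KEY_URL = "https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"

def resolve_stripe_publishable_key() -> str:
    override = os.environ.get("STRIPE_PUBLISHABLE_KEY")
    if override:
        return override  # e.g. a pk_test_* key for local testing
    resp = requests.get(STRIPE_PUBLISHABLE_KEY_URL, timeout=10)
    resp.raise_for_status()
    return resp.text.strip()
```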
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
NUM_RETURNED_HITS = 50
|
||||
|
||||
@@ -93,6 +93,7 @@ SSL_CERT_FILE = "bundle.pem"
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
@@ -152,6 +153,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
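How these two limits are meant to work together when enqueuing, as a hedged sketch (the task name and kwargs here are illustrative, not the real ones):

```python
from onyx.configs.constants import (
    CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
    USER_FILE_PROCESSING_MAX_QUEUE_DEPTH,
)

def maybe_enqueue(celery_app, queue_depth: int, user_file_id: str) -> bool:
    if queue_depth >= USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
        return False  # workers are behind; let a later beat retry
    celery_app.send_task(
        "process_single_user_file",  # illustrative task name
        kwargs={"user_file_id": user_file_id},
        expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,  # stale duplicates are dropped by workers
    )
    return True
```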
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -340,6 +352,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
@@ -422,6 +435,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -31,6 +31,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,9 +46,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -60,8 +66,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -105,8 +127,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -141,8 +163,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -193,7 +215,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -212,11 +235,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
NOTE: Whatever formatting strategy is used here to generate a key-value
|
||||
string must be replicated when constructing query filters.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
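For example (a sketch only; the real INDEX_SEPARATOR value is defined elsewhere in the codebase, so a placeholder is used here):

```python
INDEX_SEPARATOR = "==="  # placeholder, not necessarily the real separator

def convert_metadata_sketch(metadata: dict[str, str | list[str]]) -> list[str]:
    pairs: list[str] = []
    for key, value in metadata.items():
        values = value if isinstance(value, list) else [value]
        pairs.extend(f"{key}{INDEX_SEPARATOR}{v}" for v in values)
    return pairs

print(convert_metadata_sketch({"team": ["eng", "sales"], "status": "open"}))
# ['team===eng', 'team===sales', 'status===open']
```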
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -30,20 +31,29 @@ from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import ID_FIELD
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _convert_to_metadata_value(value: Any) -> str | list[str]:
|
||||
"""Convert a Salesforce field value to a valid metadata value.
|
||||
|
||||
Document metadata expects str | list[str], but Salesforce returns
|
||||
various types (bool, float, int, etc.). This function ensures all
|
||||
values are properly converted to strings.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return [str(item) for item in value]
|
||||
return str(value)
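For instance, assuming the helper is in scope, the intended conversions look like:

```python
assert _convert_to_metadata_value(True) == "True"
assert _convert_to_metadata_value(12.5) == "12.5"
assert _convert_to_metadata_value(["a", 3]) == ["a", "3"]
```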
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
@@ -433,6 +443,88 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
# # gc.collect()
|
||||
# return all_types
|
||||
|
||||
def _yield_doc_batches(
|
||||
self,
|
||||
sf_db: OnyxSalesforceSQLite,
|
||||
type_to_processed: dict[str, int],
|
||||
changed_ids_to_type: dict[str, str],
|
||||
parent_types: set[str],
|
||||
increment_parents_changed: Callable[[], None],
|
||||
) -> GenerateDocumentsOutput:
|
||||
""" """
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
parent_object.data[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
increment_parents_changed()
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def _full_sync(
|
||||
self,
|
||||
temp_dir: str,
|
||||
@@ -443,8 +535,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
|
||||
changed_ids_to_type: dict[str, str] = {}
|
||||
parents_changed = 0
|
||||
examined_ids = 0
|
||||
@@ -492,9 +582,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
f"records={num_records}"
|
||||
)
|
||||
|
||||
# yield an empty list to keep the connector alive
|
||||
yield docs_to_yield
|
||||
|
||||
new_ids = sf_db.update_from_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
@@ -527,79 +614,17 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
)
|
||||
|
||||
# Step 3 - extract and index docs
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
last_log_time = 0.0
|
||||
|
||||
for (
|
||||
parent_type,
|
||||
parent_id,
|
||||
examined_ids,
|
||||
) in sf_db.get_changed_parent_ids_by_type(
|
||||
changed_ids=list(changed_ids_to_type.keys()),
|
||||
parent_types=ctx.parent_types,
|
||||
):
|
||||
now = time.monotonic()
|
||||
|
||||
processed = examined_ids - 1
|
||||
if now - last_log_time > SalesforceConnector.LOG_INTERVAL:
|
||||
logger.info(
|
||||
f"Processing stats: {type_to_processed} "
|
||||
f"file_size={sf_db.file_size} "
|
||||
f"processed={processed} "
|
||||
f"remaining={len(changed_ids_to_type) - processed}"
|
||||
)
|
||||
last_log_time = now
|
||||
|
||||
type_to_processed[parent_type] = (
|
||||
type_to_processed.get(parent_type, 0) + 1
|
||||
)
|
||||
|
||||
parent_object = sf_db.get_record(parent_id, parent_type)
|
||||
if not parent_object:
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
# use the db to create a document we can yield
|
||||
doc = convert_sf_object_to_doc(
|
||||
sf_db,
|
||||
sf_object=parent_object,
|
||||
sf_instance=self.sf_client.sf_instance,
|
||||
)
|
||||
|
||||
doc.metadata["object_type"] = parent_type
|
||||
|
||||
# Add default attributes to the metadata
|
||||
for (
|
||||
sf_attribute,
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items():
|
||||
if sf_attribute in parent_object.data:
|
||||
doc.metadata[canonical_attribute] = parent_object.data[
|
||||
sf_attribute
|
||||
]
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
docs_to_yield.append(doc)
|
||||
def increment_parents_changed() -> None:
|
||||
nonlocal parents_changed
|
||||
parents_changed += 1
|
||||
|
||||
# memory usage is sensitive to the input length, so we're yielding immediately
|
||||
# if the batch exceeds a certain byte length
|
||||
if (
|
||||
len(docs_to_yield) >= self.batch_size
|
||||
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
|
||||
):
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
docs_to_yield_bytes = 0
|
||||
|
||||
# observed a memory leak / size issue with the account table if we don't gc.collect here.
|
||||
gc.collect()
|
||||
|
||||
yield docs_to_yield
|
||||
yield from self._yield_doc_batches(
|
||||
sf_db,
|
||||
type_to_processed,
|
||||
changed_ids_to_type,
|
||||
ctx.parent_types,
|
||||
increment_parents_changed,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Unexpected exception")
|
||||
raise
|
||||
@@ -801,7 +826,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
canonical_attribute,
|
||||
) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items():
|
||||
if sf_attribute in record:
|
||||
doc.metadata[canonical_attribute] = record[sf_attribute]
|
||||
doc.metadata[canonical_attribute] = _convert_to_metadata_value(
|
||||
record[sf_attribute]
|
||||
)
|
||||
|
||||
doc_sizeof = sys.getsizeof(doc)
|
||||
docs_to_yield_bytes += doc_sizeof
|
||||
@@ -1088,36 +1115,21 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSyn
|
||||
return return_context
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
if MULTI_TENANT:
|
||||
# if multi tenant, we cannot expect the sqlite db to be cached/present
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._full_sync(temp_dir)
|
||||
|
||||
# nuke the db since we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
|
||||
os.remove(sqlite_db_path)
|
||||
return self._full_sync(BASE_DATA_PATH)
|
||||
# Always use a temp directory for SQLite - the database is rebuilt
|
||||
# from scratch each time via CSV downloads, so there's no caching benefit
|
||||
# from persisting it. Using temp dirs also avoids collisions between
|
||||
# multiple CC pairs and eliminates stale WAL/SHM file issues.
|
||||
# TODO(evan): make this thing checkpointed and persist/load db from filestore
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield from self._full_sync(temp_dir)
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
"""Poll source will synchronize updated parent objects one by one."""
|
||||
|
||||
if start == 0:
|
||||
# nuke the db if we're starting from scratch
|
||||
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
|
||||
if os.path.exists(sqlite_db_path):
|
||||
logger.info(
|
||||
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
|
||||
)
|
||||
os.remove(sqlite_db_path)
|
||||
|
||||
return self._delta_sync(BASE_DATA_PATH, start, end)
|
||||
|
||||
# Always use a temp directory - see comment in load_from_state()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return self._delta_sync(temp_dir, start, end)
|
||||
yield from self._delta_sync(temp_dir, start, end)
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import remove_sqlite_db_files
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
@@ -22,6 +23,9 @@ from shared_configs.utils import batch_list
logger = setup_logger()


SQLITE_DISK_IO_ERROR = "disk I/O error"


class OnyxSalesforceSQLite:
    """Notes on context management using 'with self.conn':

@@ -99,8 +103,37 @@ class OnyxSalesforceSQLite:
    def apply_schema(self) -> None:
        """Initialize the SQLite database with required tables if they don't exist.

        Non-destructive operation.
        Non-destructive operation. If a disk I/O error is encountered (often due
        to stale WAL/SHM files from a previous crash), this method will attempt
        to recover by removing the corrupted files and recreating the database.
        """
        try:
            self._apply_schema_impl()
        except sqlite3.OperationalError as e:
            if SQLITE_DISK_IO_ERROR not in str(e):
                raise

            logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}")
            self._recover_from_corruption()
            self._apply_schema_impl()

    def _recover_from_corruption(self) -> None:
        """Recover from SQLite corruption by removing all database files and reconnecting."""
        logger.info(f"Removing corrupted SQLite files: {self.filename}")

        # Close existing connection
        self.close()

        # Remove all SQLite files (main db, WAL, SHM)
        remove_sqlite_db_files(self.filename)

        # Reconnect - this will create a fresh database
        self.connect()

        logger.info("SQLite recovery complete, fresh database created")

    def _apply_schema_impl(self) -> None:
        """Internal implementation of apply_schema."""
        if self._conn is None:
            raise RuntimeError("Database connection is closed")

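The recovery flow above (try the schema, catch the disk I/O error, delete stale files, reconnect, retry) can also be seen in isolation. A minimal standalone sketch of the same pattern, with a hypothetical table and function name rather than Onyx code:

    import os
    import sqlite3

    DISK_IO_ERROR = "disk I/O error"

    def open_with_recovery(db_path: str) -> sqlite3.Connection:
        # On a disk I/O error (typically stale -wal/-shm files from a crash),
        # remove the database files and reopen a fresh connection.
        def _open() -> sqlite3.Connection:
            conn = sqlite3.connect(db_path)
            conn.execute("CREATE TABLE IF NOT EXISTS example (id TEXT PRIMARY KEY)")
            return conn

        try:
            return _open()
        except sqlite3.OperationalError as e:
            if DISK_IO_ERROR not in str(e):
                raise
            for path in (db_path, f"{db_path}-wal", f"{db_path}-shm"):
                if os.path.exists(path):
                    os.remove(path)
            return _open()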
@@ -41,6 +41,28 @@ def get_sqlite_db_path(directory: str) -> str:
    return os.path.join(directory, "salesforce_db.sqlite")


def remove_sqlite_db_files(db_path: str) -> None:
    """Remove SQLite database and all associated files (WAL, SHM).

    SQLite in WAL mode creates additional files:
    - .sqlite-wal: Write-ahead log
    - .sqlite-shm: Shared memory file

    If these files become stale (e.g., after a crash), they can cause
    'disk I/O error' when trying to open the database. This function
    ensures all related files are removed.
    """
    files_to_remove = [
        db_path,
        f"{db_path}-wal",
        f"{db_path}-shm",
    ]
    for file_path in files_to_remove:
        if os.path.exists(file_path):
            os.remove(file_path)


# NOTE: only used with shelves, deprecated at this point
def get_object_type_path(object_type: str) -> str:
    """Get the directory path for a specific object type."""
    type_dir = os.path.join(BASE_DATA_PATH, object_type)

@@ -15,6 +15,7 @@ from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.natural_language_processing.english_stopwords import ENGLISH_STOPWORDS_SET
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
|
||||
@@ -113,7 +114,7 @@ def is_recency_query(query: str) -> bool:
|
||||
if not has_recency_keyword:
|
||||
return False
|
||||
|
||||
# Get combined stop words (NLTK + Slack-specific)
|
||||
# Get combined stop words (English + Slack-specific)
|
||||
all_stop_words = _get_combined_stop_words()
|
||||
|
||||
# Extract content words (excluding stop words)
|
||||
@@ -488,7 +489,7 @@ def build_channel_override_query(channel_references: set[str], time_filter: str)
|
||||
return f"__CHANNEL_OVERRIDE__ {channel_filter}{time_filter}"
|
||||
|
||||
|
||||
# Slack-specific stop words (in addition to standard NLTK stop words)
|
||||
# Slack-specific stop words (in addition to standard English stop words)
|
||||
# These include Slack-specific terms and temporal/recency keywords
|
||||
SLACK_SPECIFIC_STOP_WORDS = frozenset(
|
||||
RECENCY_KEYWORDS
|
||||
@@ -508,27 +509,16 @@ SLACK_SPECIFIC_STOP_WORDS = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def _get_combined_stop_words() -> set[str]:
|
||||
"""Get combined NLTK + Slack-specific stop words.
|
||||
def _get_combined_stop_words() -> frozenset[str]:
|
||||
"""Get combined English + Slack-specific stop words.
|
||||
|
||||
Returns a set of stop words for filtering content words.
|
||||
Falls back to just Slack-specific stop words if NLTK is unavailable.
|
||||
Returns a frozenset of stop words for filtering content words.
|
||||
|
||||
Note: Currently only supports English stop words. Non-English queries
|
||||
may have suboptimal content word extraction. Future enhancement could
|
||||
detect query language and load appropriate stop words.
|
||||
"""
|
||||
try:
|
||||
from nltk.corpus import stopwords # type: ignore
|
||||
|
||||
# TODO: Support multiple languages - currently hardcoded to English
|
||||
# Could detect language or allow configuration
|
||||
nltk_stop_words = set(stopwords.words("english"))
|
||||
except Exception:
|
||||
# Fallback if NLTK not available
|
||||
nltk_stop_words = set()
|
||||
|
||||
return nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
|
||||
return ENGLISH_STOPWORDS_SET | SLACK_SPECIFIC_STOP_WORDS
|
||||
|
||||
|
||||
def extract_content_words_from_recency_query(
|
||||
@@ -536,7 +526,7 @@ def extract_content_words_from_recency_query(
|
||||
) -> list[str]:
|
||||
"""Extract meaningful content words from a recency query.
|
||||
|
||||
Filters out NLTK stop words, Slack-specific terms, channel references, and proper nouns.
|
||||
Filters out English stop words, Slack-specific terms, channel references, and proper nouns.
|
||||
|
||||
Args:
|
||||
query_text: The user's query text
|
||||
@@ -545,7 +535,7 @@ def extract_content_words_from_recency_query(
|
||||
Returns:
|
||||
List of content words (up to MAX_CONTENT_WORDS)
|
||||
"""
|
||||
# Get combined stop words (NLTK + Slack-specific)
|
||||
# Get combined stop words (English + Slack-specific)
|
||||
all_stop_words = _get_combined_stop_words()
|
||||
|
||||
words = query_text.split()
|
||||
@@ -567,6 +557,23 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -589,10 +596,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
rephrased_queries = [
|
||||
raw_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
# NOTE: These strings must be formatted in the same way as the output of
|
||||
# DocumentAccess::to_acl.
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@@ -144,10 +146,6 @@ class BasicChunkRequest(BaseModel):
|
||||
# In case some queries favor recency more than other queries.
|
||||
recency_bias_multiplier: float = 1.0
|
||||
|
||||
# Sometimes we may want to extract specific keywords from a more semantic query for
|
||||
# a better keyword search.
|
||||
query_keywords: list[str] | None = None # Not used currently
|
||||
|
||||
limit: int | None = None
|
||||
offset: int | None = None # This one is not set currently
|
||||
|
||||
@@ -166,6 +164,8 @@ class ChunkIndexRequest(BasicChunkRequest):
|
||||
# Calculated final filters
|
||||
filters: IndexFilters
|
||||
|
||||
query_keywords: list[str] | None = None
|
||||
|
||||
|
||||
class ContextExpansionType(str, Enum):
|
||||
NOT_RELEVANT = "not_relevant"
|
||||
@@ -372,6 +372,10 @@ class SearchDocsResponse(BaseModel):
|
||||
# document id is the most staightforward way.
|
||||
citation_mapping: dict[int, str]
|
||||
|
||||
# For cases where the frontend only needs to display a subset of the search docs
|
||||
# The whole list is typically still needed for later steps but this set should be saved separately
|
||||
displayed_docs: list[SearchDoc] | None = None
|
||||
|
||||
|
||||
class SavedSearchDoc(SearchDoc):
|
||||
db_doc_id: int
|
||||
@@ -430,11 +434,6 @@ class SavedSearchDoc(SearchDoc):
|
||||
return self_score < other_score
|
||||
|
||||
|
||||
class CitationDocInfo(BaseModel):
|
||||
search_doc: SearchDoc
|
||||
citation_number: int | None
|
||||
|
||||
|
||||
class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
"""Used for endpoints that need to return the actual contents of the retrieved
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -278,12 +279,16 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
|
||||
query_request = ChunkIndexRequest(
|
||||
query=chunk_search_request.query,
|
||||
hybrid_alpha=chunk_search_request.hybrid_alpha,
|
||||
recency_bias_multiplier=chunk_search_request.recency_bias_multiplier,
|
||||
query_keywords=chunk_search_request.query_keywords,
|
||||
query_keywords=query_keywords,
|
||||
filters=filters,
|
||||
limit=chunk_search_request.limit,
|
||||
offset=chunk_search_request.offset,
|
||||
)
|
||||
|
||||
retrieved_chunks = search_chunks(
|
||||
|
||||
@@ -23,45 +23,6 @@ from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _dedupe_chunks(
|
||||
chunks: list[InferenceChunk],
|
||||
) -> list[InferenceChunk]:
|
||||
used_chunks: dict[tuple[str, int], InferenceChunk] = {}
|
||||
for chunk in chunks:
|
||||
key = (chunk.document_id, chunk.chunk_id)
|
||||
if key not in used_chunks:
|
||||
used_chunks[key] = chunk
|
||||
else:
|
||||
stored_chunk_score = used_chunks[key].score or 0
|
||||
this_chunk_score = chunk.score or 0
|
||||
if stored_chunk_score < this_chunk_score:
|
||||
used_chunks[key] = chunk
|
||||
|
||||
return list(used_chunks.values())
|
||||
|
||||
|
||||
def download_nltk_data() -> None:
|
||||
import nltk # type: ignore[import-untyped]
|
||||
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
# "wordnet": "corpora/wordnet", # Not in use
|
||||
"punkt_tab": "tokenizers/punkt_tab",
|
||||
}
|
||||
|
||||
for resource_name, resource_path in resources.items():
|
||||
try:
|
||||
nltk.data.find(resource_path)
|
||||
logger.info(f"{resource_name} is already downloaded.")
|
||||
except LookupError:
|
||||
try:
|
||||
logger.info(f"Downloading {resource_name}...")
|
||||
nltk.download(resource_name, quiet=True)
|
||||
logger.info(f"{resource_name} downloaded successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download {resource_name}. Error: {e}")
|
||||
|
||||
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
|
||||
backend/onyx/db/discord_bot.py (new file, 451 lines)
@@ -0,0 +1,451 @@
|
||||
"""CRUD operations for Discord bot models."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DISCORD_SERVICE_API_KEY_NAME
|
||||
from onyx.db.api_key import insert_api_key
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import DiscordBotConfig
|
||||
from onyx.db.models import DiscordChannelConfig
|
||||
from onyx.db.models import DiscordGuildConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# === DiscordBotConfig ===
|
||||
|
||||
|
||||
def get_discord_bot_config(db_session: Session) -> DiscordBotConfig | None:
|
||||
"""Get the Discord bot config for this tenant (at most one)."""
|
||||
return db_session.scalar(select(DiscordBotConfig).limit(1))
|
||||
|
||||
|
||||
def create_discord_bot_config(
|
||||
db_session: Session,
|
||||
bot_token: str,
|
||||
) -> DiscordBotConfig:
|
||||
"""Create the Discord bot config. Raises ValueError if already exists.
|
||||
|
||||
The check constraint on id='SINGLETON' ensures only one config per tenant.
|
||||
"""
|
||||
existing = get_discord_bot_config(db_session)
|
||||
if existing:
|
||||
raise ValueError("Discord bot config already exists")
|
||||
|
||||
config = DiscordBotConfig(bot_token=bot_token)
|
||||
db_session.add(config)
|
||||
try:
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another request created the config concurrently
|
||||
db_session.rollback()
|
||||
raise ValueError("Discord bot config already exists")
|
||||
return config
|
||||
|
||||
|
||||
def delete_discord_bot_config(db_session: Session) -> bool:
|
||||
"""Delete the Discord bot config. Returns True if deleted."""
|
||||
result = db_session.execute(delete(DiscordBotConfig))
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# === Discord Service API Key ===
|
||||
|
||||
|
||||
def get_discord_service_api_key(db_session: Session) -> ApiKey | None:
|
||||
"""Get the Discord service API key if it exists."""
|
||||
return db_session.scalar(
|
||||
select(ApiKey).where(ApiKey.name == DISCORD_SERVICE_API_KEY_NAME)
|
||||
)
|
||||
|
||||
|
||||
def get_or_create_discord_service_api_key(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
) -> str:
|
||||
"""Get existing Discord service API key or create one.
|
||||
|
||||
The API key is used by the Discord bot to authenticate with the
|
||||
Onyx API pods when sending chat requests.
|
||||
|
||||
Args:
|
||||
db_session: Database session for the tenant.
|
||||
tenant_id: The tenant ID (used for logging/context).
|
||||
|
||||
Returns:
|
||||
The raw API key string (not hashed).
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API key creation fails.
|
||||
"""
|
||||
# Check for existing key
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
# Database only stores the hash, so we must regenerate to get the raw key.
|
||||
# This is safe since the Discord bot is the only consumer of this key.
|
||||
logger.debug(
|
||||
f"Found existing Discord service API key for tenant {tenant_id} that isn't in cache, "
|
||||
"regenerating to update cache"
|
||||
)
|
||||
new_api_key = generate_api_key(tenant_id)
|
||||
existing.hashed_api_key = hash_api_key(new_api_key)
|
||||
existing.api_key_display = build_displayable_api_key(new_api_key)
|
||||
db_session.flush()
|
||||
return new_api_key
|
||||
|
||||
# Create new API key
|
||||
logger.info(f"Creating Discord service API key for tenant {tenant_id}")
|
||||
api_key_args = APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Limited role is sufficient for chat requests
|
||||
)
|
||||
api_key_descriptor = insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=api_key_args,
|
||||
user_id=None, # Service account, no owner
|
||||
)
|
||||
|
||||
if not api_key_descriptor.api_key:
|
||||
raise RuntimeError(
|
||||
f"Failed to create Discord service API key for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
return api_key_descriptor.api_key
|
||||
|
||||
|
||||
def delete_discord_service_api_key(db_session: Session) -> bool:
|
||||
"""Delete the Discord service API key for a tenant.
|
||||
|
||||
Called when:
|
||||
- Bot config is deleted (self-hosted)
|
||||
- All guild configs are deleted (Cloud)
|
||||
|
||||
Args:
|
||||
db_session: Database session for the tenant.
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if it didn't exist.
|
||||
"""
|
||||
existing_key = get_discord_service_api_key(db_session)
|
||||
if not existing_key:
|
||||
return False
|
||||
|
||||
# Also delete the associated user
|
||||
api_key_user = db_session.scalar(
|
||||
select(User).where(User.id == existing_key.user_id) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
db_session.delete(existing_key)
|
||||
if api_key_user:
|
||||
db_session.delete(api_key_user)
|
||||
|
||||
db_session.flush()
|
||||
logger.info("Deleted Discord service API key")
|
||||
return True
|
||||
|
||||
|
||||
# === DiscordGuildConfig ===
|
||||
|
||||
|
||||
def get_guild_configs(
|
||||
db_session: Session,
|
||||
include_channels: bool = False,
|
||||
) -> list[DiscordGuildConfig]:
|
||||
"""Get all guild configs for this tenant."""
|
||||
stmt = select(DiscordGuildConfig)
|
||||
if include_channels:
|
||||
stmt = stmt.options(joinedload(DiscordGuildConfig.channels))
|
||||
return list(db_session.scalars(stmt).unique().all())
|
||||
|
||||
|
||||
def get_guild_config_by_internal_id(
|
||||
db_session: Session,
|
||||
internal_id: int,
|
||||
) -> DiscordGuildConfig | None:
|
||||
"""Get a specific guild config by its ID."""
|
||||
return db_session.scalar(
|
||||
select(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
|
||||
)
|
||||
|
||||
|
||||
def get_guild_config_by_discord_id(
|
||||
db_session: Session,
|
||||
guild_id: int,
|
||||
) -> DiscordGuildConfig | None:
|
||||
"""Get a guild config by Discord guild ID."""
|
||||
return db_session.scalar(
|
||||
select(DiscordGuildConfig).where(DiscordGuildConfig.guild_id == guild_id)
|
||||
)
|
||||
|
||||
|
||||
def get_guild_config_by_registration_key(
|
||||
db_session: Session,
|
||||
registration_key: str,
|
||||
) -> DiscordGuildConfig | None:
|
||||
"""Get a guild config by its registration key."""
|
||||
return db_session.scalar(
|
||||
select(DiscordGuildConfig).where(
|
||||
DiscordGuildConfig.registration_key == registration_key
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_guild_config(
|
||||
db_session: Session,
|
||||
registration_key: str,
|
||||
) -> DiscordGuildConfig:
|
||||
"""Create a new guild config with a registration key (guild_id=NULL)."""
|
||||
config = DiscordGuildConfig(registration_key=registration_key)
|
||||
db_session.add(config)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def register_guild(
|
||||
db_session: Session,
|
||||
config: DiscordGuildConfig,
|
||||
guild_id: int,
|
||||
guild_name: str,
|
||||
) -> DiscordGuildConfig:
|
||||
"""Complete registration by setting guild_id and guild_name."""
|
||||
config.guild_id = guild_id
|
||||
config.guild_name = guild_name
|
||||
config.registered_at = datetime.now(timezone.utc)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def update_guild_config(
|
||||
db_session: Session,
|
||||
config: DiscordGuildConfig,
|
||||
enabled: bool,
|
||||
default_persona_id: int | None = None,
|
||||
) -> DiscordGuildConfig:
|
||||
"""Update guild config fields."""
|
||||
config.enabled = enabled
|
||||
config.default_persona_id = default_persona_id
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def delete_guild_config(
|
||||
db_session: Session,
|
||||
internal_id: int,
|
||||
) -> bool:
|
||||
"""Delete guild config (cascades to channel configs). Returns True if deleted."""
|
||||
result = db_session.execute(
|
||||
delete(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id)
|
||||
)
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# === DiscordChannelConfig ===
|
||||
|
||||
|
||||
def get_channel_configs(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
) -> list[DiscordChannelConfig]:
|
||||
"""Get all channel configs for a guild."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(DiscordChannelConfig).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def get_channel_config_by_discord_ids(
|
||||
db_session: Session,
|
||||
guild_id: int,
|
||||
channel_id: int,
|
||||
) -> DiscordChannelConfig | None:
|
||||
"""Get a specific channel config by guild_id and channel_id."""
|
||||
return db_session.scalar(
|
||||
select(DiscordChannelConfig)
|
||||
.join(DiscordGuildConfig)
|
||||
.where(
|
||||
DiscordGuildConfig.guild_id == guild_id,
|
||||
DiscordChannelConfig.channel_id == channel_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_channel_config_by_internal_ids(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channel_config_id: int,
|
||||
) -> DiscordChannelConfig | None:
|
||||
"""Get a specific channel config by guild_config_id and channel_config_id"""
|
||||
return db_session.scalar(
|
||||
select(DiscordChannelConfig).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id,
|
||||
DiscordChannelConfig.id == channel_config_id,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def update_discord_channel_config(
|
||||
db_session: Session,
|
||||
config: DiscordChannelConfig,
|
||||
channel_name: str,
|
||||
thread_only_mode: bool,
|
||||
require_bot_invocation: bool,
|
||||
enabled: bool,
|
||||
persona_override_id: int | None = None,
|
||||
) -> DiscordChannelConfig:
|
||||
"""Update channel config fields."""
|
||||
config.channel_name = channel_name
|
||||
config.require_bot_invocation = require_bot_invocation
|
||||
config.persona_override_id = persona_override_id
|
||||
config.enabled = enabled
|
||||
config.thread_only_mode = thread_only_mode
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def delete_discord_channel_config(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channel_config_id: int,
|
||||
) -> bool:
|
||||
"""Delete a channel config. Returns True if deleted."""
|
||||
result = db_session.execute(
|
||||
delete(DiscordChannelConfig).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id,
|
||||
DiscordChannelConfig.id == channel_config_id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
return result.rowcount > 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def create_channel_config(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channel_view: DiscordChannelView,
|
||||
) -> DiscordChannelConfig:
|
||||
"""Create a new channel config with default settings (disabled by default, admin enables via UI)."""
|
||||
config = DiscordChannelConfig(
|
||||
guild_config_id=guild_config_id,
|
||||
channel_id=channel_view.channel_id,
|
||||
channel_name=channel_view.channel_name,
|
||||
channel_type=channel_view.channel_type,
|
||||
is_private=channel_view.is_private,
|
||||
)
|
||||
db_session.add(config)
|
||||
db_session.flush()
|
||||
return config
|
||||
|
||||
|
||||
def bulk_create_channel_configs(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
channels: list[DiscordChannelView],
|
||||
) -> list[DiscordChannelConfig]:
|
||||
"""Create multiple channel configs at once. Skips existing channels."""
|
||||
# Get existing channel IDs for this guild
|
||||
existing_channel_ids = set(
|
||||
db_session.scalars(
|
||||
select(DiscordChannelConfig.channel_id).where(
|
||||
DiscordChannelConfig.guild_config_id == guild_config_id
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
# Create configs for new channels only
|
||||
new_configs = []
|
||||
for channel_view in channels:
|
||||
if channel_view.channel_id not in existing_channel_ids:
|
||||
config = DiscordChannelConfig(
|
||||
guild_config_id=guild_config_id,
|
||||
channel_id=channel_view.channel_id,
|
||||
channel_name=channel_view.channel_name,
|
||||
channel_type=channel_view.channel_type,
|
||||
is_private=channel_view.is_private,
|
||||
)
|
||||
db_session.add(config)
|
||||
new_configs.append(config)
|
||||
|
||||
db_session.flush()
|
||||
return new_configs
|
||||
|
||||
|
||||
def sync_channel_configs(
|
||||
db_session: Session,
|
||||
guild_config_id: int,
|
||||
current_channels: list[DiscordChannelView],
|
||||
) -> tuple[int, int, int]:
|
||||
"""Sync channel configs with current Discord channels.
|
||||
|
||||
- Creates configs for new channels (disabled by default)
|
||||
- Removes configs for deleted channels
|
||||
- Updates names and types for existing channels if changed
|
||||
|
||||
Returns: (added_count, removed_count, updated_count)
|
||||
"""
|
||||
current_channel_map = {
|
||||
channel_view.channel_id: channel_view for channel_view in current_channels
|
||||
}
|
||||
current_channel_ids = set(current_channel_map.keys())
|
||||
|
||||
# Get existing configs
|
||||
existing_configs = get_channel_configs(db_session, guild_config_id)
|
||||
existing_channel_ids = {c.channel_id for c in existing_configs}
|
||||
|
||||
# Find channels to add, remove, and potentially update
|
||||
to_add = current_channel_ids - existing_channel_ids
|
||||
to_remove = existing_channel_ids - current_channel_ids
|
||||
|
||||
# Add new channels
|
||||
added_count = 0
|
||||
for channel_id in to_add:
|
||||
channel_view = current_channel_map[channel_id]
|
||||
create_channel_config(db_session, guild_config_id, channel_view)
|
||||
added_count += 1
|
||||
|
||||
# Remove deleted channels
|
||||
removed_count = 0
|
||||
for config in existing_configs:
|
||||
if config.channel_id in to_remove:
|
||||
db_session.delete(config)
|
||||
removed_count += 1
|
||||
|
||||
# Update names, types, and privacy for existing channels if changed
|
||||
updated_count = 0
|
||||
for config in existing_configs:
|
||||
if config.channel_id in current_channel_ids:
|
||||
channel_view = current_channel_map[config.channel_id]
|
||||
changed = False
|
||||
if config.channel_name != channel_view.channel_name:
|
||||
config.channel_name = channel_view.channel_name
|
||||
changed = True
|
||||
if config.channel_type != channel_view.channel_type:
|
||||
config.channel_type = channel_view.channel_type
|
||||
changed = True
|
||||
if config.is_private != channel_view.is_private:
|
||||
config.is_private = channel_view.is_private
|
||||
changed = True
|
||||
if changed:
|
||||
updated_count += 1
|
||||
|
||||
db_session.flush()
|
||||
return added_count, removed_count, updated_count
|
||||
@@ -3,6 +3,8 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -18,45 +20,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def insert_input_prompt_if_not_exists(
|
||||
user: User | None,
|
||||
input_prompt_id: int | None,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> InputPrompt:
|
||||
if input_prompt_id is not None:
|
||||
input_prompt = (
|
||||
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
|
||||
)
|
||||
else:
|
||||
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
|
||||
if user:
|
||||
query = query.filter(InputPrompt.user_id == user.id)
|
||||
else:
|
||||
query = query.filter(InputPrompt.user_id.is_(None))
|
||||
input_prompt = query.first()
|
||||
|
||||
if input_prompt is None:
|
||||
input_prompt = InputPrompt(
|
||||
id=input_prompt_id,
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=active,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def insert_input_prompt(
|
||||
prompt: str,
|
||||
content: str,
|
||||
@@ -64,16 +27,41 @@ def insert_input_prompt(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = InputPrompt(
|
||||
user_id = user.id if user else None
|
||||
|
||||
# Use atomic INSERT ... ON CONFLICT DO NOTHING with RETURNING
|
||||
# to avoid race conditions with the uniqueness check
|
||||
stmt = pg_insert(InputPrompt).values(
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=True,
|
||||
is_public=is_public,
|
||||
user_id=user.id if user is not None else None,
|
||||
user_id=user_id,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
# Use the appropriate constraint based on whether this is a user-owned or public prompt
|
||||
if user_id is not None:
|
||||
stmt = stmt.on_conflict_do_nothing(constraint="uq_inputprompt_prompt_user_id")
|
||||
else:
|
||||
# Partial unique indexes cannot be targeted by constraint name;
|
||||
# must use index_elements + index_where
|
||||
stmt = stmt.on_conflict_do_nothing(
|
||||
index_elements=[InputPrompt.prompt],
|
||||
index_where=InputPrompt.user_id.is_(None),
|
||||
)
|
||||
|
||||
stmt = stmt.returning(InputPrompt)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
input_prompt = result.scalar_one_or_none()
|
||||
|
||||
if input_prompt is None:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A prompt shortcut with the name '{prompt}' already exists",
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return input_prompt
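For reference, the public-prompt branch above leans on SQLAlchemy's on_conflict_do_nothing with index_elements/index_where, because PostgreSQL cannot target a partial unique index by constraint name. A minimal sketch against a hypothetical stand-in table (not the real InputPrompt model):

    from sqlalchemy import Column, Index, Integer, String, text
    from sqlalchemy.dialects.postgresql import insert as pg_insert
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class PromptExample(Base):  # hypothetical table mirroring the partial unique index
        __tablename__ = "prompt_example"
        id = Column(Integer, primary_key=True)
        prompt = Column(String, nullable=False)
        user_id = Column(String, nullable=True)
        __table_args__ = (
            Index(
                "uq_prompt_example_public",
                "prompt",
                unique=True,
                postgresql_where=text("user_id IS NULL"),
            ),
        )

    stmt = (
        pg_insert(PromptExample)
        .values(prompt="summarize", user_id=None)
        # A partial unique index has no constraint name to target, so it is
        # described via index_elements + index_where instead.
        .on_conflict_do_nothing(
            index_elements=[PromptExample.prompt],
            index_where=PromptExample.user_id.is_(None),
        )
        .returning(PromptExample)
    )
    # session.execute(stmt).scalar_one_or_none() is None when the prompt already exists.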
|
||||
|
||||
|
||||
@@ -98,23 +86,40 @@ def update_input_prompt(
|
||||
input_prompt.content = content
|
||||
input_prompt.active = active
|
||||
|
||||
db_session.commit()
|
||||
try:
|
||||
db_session.commit()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A prompt shortcut with the name '{prompt}' already exists",
|
||||
)
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def validate_user_prompt_authorization(
|
||||
user: User | None, input_prompt: InputPrompt
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the user is authorized to modify the given input prompt.
|
||||
Returns True only if the user owns the prompt.
|
||||
Returns False for public prompts (only admins can modify those),
|
||||
unless auth is disabled (then anyone can manage public prompts).
|
||||
"""
|
||||
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
|
||||
|
||||
if prompt.user_id is not None:
|
||||
if user is None:
|
||||
return False
|
||||
# Public prompts cannot be modified via the user API (unless auth is disabled)
|
||||
if prompt.is_public or prompt.user_id is None:
|
||||
return AUTH_TYPE == AuthType.DISABLED
|
||||
|
||||
user_details = UserInfo.from_model(user)
|
||||
if str(user_details.id) != str(prompt.user_id):
|
||||
return False
|
||||
return True
|
||||
# User must be logged in
|
||||
if user is None:
|
||||
return False
|
||||
|
||||
# User must own the prompt
|
||||
user_details = UserInfo.from_model(user)
|
||||
return str(user_details.id) == str(prompt.user_id)
|
||||
|
||||
|
||||
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
|
||||
|
||||
@@ -9,6 +9,9 @@ def get_memories(user: User | None, db_session: Session) -> list[str]:
|
||||
if user is None:
|
||||
return []
|
||||
|
||||
if not user.use_memories:
|
||||
return []
|
||||
|
||||
user_info = [
|
||||
f"User's name: {user.personal_name}" if user.personal_name else "",
|
||||
f"User's role: {user.personal_role}" if user.personal_role else "",
|
||||
|
||||
@@ -26,6 +26,7 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import BigInteger
|
||||
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
@@ -187,6 +188,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
nullable=True,
|
||||
default=None,
|
||||
)
|
||||
chat_background: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# personalization fields are exposed via the chat user settings "Personalization" tab
|
||||
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
@@ -2045,7 +2047,7 @@ class ChatSession(Base):
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# This chat created by OnyxBot
|
||||
@@ -2931,8 +2933,6 @@ class PersonaLabel(Base):
|
||||
"Persona",
|
||||
secondary=Persona__PersonaLabel.__table__,
|
||||
back_populates="labels",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -3038,6 +3038,124 @@ class SlackBot(Base):
|
||||
)
|
||||
|
||||
|
||||
class DiscordBotConfig(Base):
|
||||
"""Global Discord bot configuration (one per tenant).
|
||||
|
||||
Stores the bot token when not provided via DISCORD_BOT_TOKEN env var.
|
||||
Uses a fixed ID with check constraint to enforce only one row per tenant.
|
||||
"""
|
||||
|
||||
__tablename__ = "discord_bot_config"
|
||||
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'SINGLETON'")
|
||||
)
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
|
||||
|
||||
class DiscordGuildConfig(Base):
|
||||
"""Configuration for a Discord guild (server) connected to this tenant.
|
||||
|
||||
registration_key is a one-time key used to link a Discord server to this tenant.
|
||||
Format: discord_<tenant_id>.<random_token>
|
||||
guild_id is NULL until the Discord admin runs !register with the key.
|
||||
"""
|
||||
|
||||
__tablename__ = "discord_guild_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
# Discord snowflake - NULL until registered via command in Discord
|
||||
guild_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True, unique=True)
|
||||
guild_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
|
||||
# One-time registration key: discord_<tenant_id>.<random_token>
|
||||
registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
|
||||
registered_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Configuration
|
||||
default_persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
default_persona: Mapped["Persona | None"] = relationship(
|
||||
"Persona", foreign_keys=[default_persona_id]
|
||||
)
|
||||
channels: Mapped[list["DiscordChannelConfig"]] = relationship(
|
||||
back_populates="guild_config", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class DiscordChannelConfig(Base):
|
||||
"""Per-channel configuration for Discord bot behavior.
|
||||
|
||||
Used to whitelist specific channels and configure per-channel behavior.
|
||||
"""
|
||||
|
||||
__tablename__ = "discord_channel_config"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
guild_config_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("discord_guild_config.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Discord snowflake
|
||||
channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
channel_name: Mapped[str] = mapped_column(String(), nullable=False)
|
||||
|
||||
# Channel type from Discord (text, forum)
|
||||
channel_type: Mapped[str] = mapped_column(
|
||||
String(20), server_default=text("'text'"), nullable=False
|
||||
)
|
||||
|
||||
# True if @everyone cannot view the channel
|
||||
is_private: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# If true, bot only responds to messages in threads
|
||||
# Otherwise, will reply in channel
|
||||
thread_only_mode: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# If true (default), bot only responds when @mentioned
|
||||
# If false, bot responds to ALL messages in this channel
|
||||
require_bot_invocation: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("true"), nullable=False
|
||||
)
|
||||
|
||||
# Override the guild's default persona for this channel
|
||||
persona_override_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
|
||||
enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, server_default=text("false"), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
guild_config: Mapped["DiscordGuildConfig"] = relationship(back_populates="channels")
|
||||
persona_override: Mapped["Persona | None"] = relationship()
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"guild_config_id", "channel_id", name="uq_discord_channel_guild_channel"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Milestone(Base):
|
||||
# This table is used to track significant events for a deployment towards finding value
|
||||
# The table is currently not used for features but it may be used in the future to inform
|
||||
@@ -3115,25 +3233,6 @@ class FileRecord(Base):
|
||||
)
|
||||
|
||||
|
||||
class AgentSearchMetrics(Base):
|
||||
__tablename__ = "agent__search_metrics"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
agent_type: Mapped[str] = mapped_column(String)
|
||||
start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
base_duration_s: Mapped[float] = mapped_column(Float)
|
||||
full_duration_s: Mapped[float] = mapped_column(Float)
|
||||
base_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
refined_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
all_metrics: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
|
||||
|
||||
"""
|
||||
************************************************************************
|
||||
Enterprise Edition Models
|
||||
@@ -3526,6 +3625,18 @@ class InputPrompt(Base):
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
# Unique constraint on (prompt, user_id) for user-owned prompts
|
||||
UniqueConstraint("prompt", "user_id", name="uq_inputprompt_prompt_user_id"),
|
||||
# Partial unique index for public prompts (user_id IS NULL)
|
||||
Index(
|
||||
"uq_inputprompt_prompt_public",
|
||||
"prompt",
|
||||
unique=True,
|
||||
postgresql_where=text("user_id IS NULL"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt__User(Base):
|
||||
__tablename__ = "inputprompt__user"
|
||||
@@ -3534,7 +3645,7 @@ class InputPrompt__User(Base):
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
ForeignKey("user.id"), primary_key=True
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
@@ -917,7 +917,9 @@ def upsert_persona(
|
||||
existing_persona.icon_name = icon_name
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -80,39 +80,43 @@ def _perform_index_swap(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# remove the old index from the vector db
|
||||
document_index = get_default_document_index(new_search_settings, None)
|
||||
# This flow is for checking and possibly creating an index so we get all
|
||||
# indices.
|
||||
document_indices = get_all_document_indices(new_search_settings, None, None)
|
||||
|
||||
WAIT_SECONDS = 5
|
||||
|
||||
success = False
|
||||
for x in range(VESPA_NUM_ATTEMPTS_ON_STARTUP):
|
||||
try:
|
||||
logger.notice(
|
||||
f"Vespa index swap (attempt {x+1}/{VESPA_NUM_ATTEMPTS_ON_STARTUP})..."
|
||||
)
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=new_search_settings.embedding_precision,
|
||||
# just finished swap, no more secondary index
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
for document_index in document_indices:
|
||||
success = False
|
||||
for x in range(VESPA_NUM_ATTEMPTS_ON_STARTUP):
|
||||
try:
|
||||
logger.notice(
|
||||
f"Document index {document_index.__class__.__name__} swap (attempt {x+1}/{VESPA_NUM_ATTEMPTS_ON_STARTUP})..."
|
||||
)
|
||||
document_index.ensure_indices_exist(
|
||||
primary_embedding_dim=new_search_settings.final_embedding_dim,
|
||||
primary_embedding_precision=new_search_settings.embedding_precision,
|
||||
# just finished swap, no more secondary index
|
||||
secondary_index_embedding_dim=None,
|
||||
secondary_index_embedding_precision=None,
|
||||
)
|
||||
|
||||
logger.notice("Vespa index swap complete.")
|
||||
success = True
|
||||
break
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Vespa index swap did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
|
||||
)
|
||||
time.sleep(WAIT_SECONDS)
|
||||
logger.notice("Document index swap complete.")
|
||||
success = True
|
||||
break
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Document index swap for {document_index.__class__.__name__} did not succeed. "
|
||||
f"The document index services may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
|
||||
)
|
||||
time.sleep(WAIT_SECONDS)
|
||||
|
||||
if not success:
|
||||
logger.error(
|
||||
f"Vespa index swap did not succeed. Attempt limit reached. ({VESPA_NUM_ATTEMPTS_ON_STARTUP})"
|
||||
)
|
||||
return None
|
||||
if not success:
|
||||
logger.error(
|
||||
f"Document index swap for {document_index.__class__.__name__} did not succeed. "
|
||||
f"Attempt limit reached. ({VESPA_NUM_ATTEMPTS_ON_STARTUP})"
|
||||
)
|
||||
return None
|
||||
|
||||
return current_search_settings
|
||||
|
||||
|
||||
@@ -139,6 +139,20 @@ def update_user_theme_preference(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_chat_background(
|
||||
user_id: UUID,
|
||||
chat_background: str | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update user's chat background setting."""
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.values(chat_background=chat_background)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_user_personalization(
|
||||
user_id: UUID,
|
||||
*,
|
||||
|
||||
@@ -15,7 +15,9 @@ from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
@@ -327,6 +329,15 @@ def delete_user_from_db(
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
# Null out ownership on document sets and personas so they're
|
||||
# preserved for other users instead of being cascade-deleted
|
||||
db_session.query(DocumentSet).filter(
|
||||
DocumentSet.user_id == user_to_delete.id
|
||||
).update({DocumentSet.user_id: None})
|
||||
db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update(
|
||||
{Persona.user_id: None}
|
||||
)
|
||||
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
|
||||
@@ -40,3 +40,10 @@ class DocumentRow(BaseModel):
|
||||
class SortOrder(str, Enum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
|
||||
class DiscordChannelView(BaseModel):
|
||||
channel_id: int
|
||||
channel_name: str
|
||||
channel_type: str = "text" # text, forum
|
||||
is_private: bool = False # True if @everyone cannot view the channel
|
||||
|
||||
@@ -2,13 +2,18 @@ from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
|
||||
|
||||
def generate_enriched_content_for_chunk(chunk: DocMetadataAwareIndexChunk) -> str:
|
||||
def generate_enriched_content_for_chunk_text(chunk: DocMetadataAwareIndexChunk) -> str:
|
||||
return f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
|
||||
|
||||
|
||||
def generate_enriched_content_for_chunk_embedding(chunk: DocAwareChunk) -> str:
|
||||
return f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
|
||||
|
||||
|
||||
def cleanup_content_for_chunks(
|
||||
chunks: list[InferenceChunkUncleaned],
|
||||
) -> list[InferenceChunk]:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchOldDocumentIndex,
|
||||
@@ -17,17 +16,24 @@ def get_default_document_index(
|
||||
secondary_search_settings: SearchSettings | None,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> DocumentIndex:
|
||||
"""Primary index is the index that is used for querying/updating etc.
|
||||
Secondary index is for when both the currently used index and the upcoming
|
||||
index both need to be updated, updates are applied to both indices"""
|
||||
"""Gets the default document index from env vars.
|
||||
|
||||
To be used for retrieval only. Indexing should be done through both indices
|
||||
until Vespa is deprecated.
|
||||
|
||||
Pre-existing docstring for this function, although secondary indices are not
|
||||
currently supported:
|
||||
Primary index is the index that is used for querying/updating etc. Secondary
|
||||
index is for when both the currently used index and the upcoming index both
|
||||
need to be updated, updates are applied to both indices.
|
||||
"""
|
||||
secondary_index_name: str | None = None
|
||||
secondary_large_chunks_enabled: bool | None = None
|
||||
if secondary_search_settings:
|
||||
secondary_index_name = secondary_search_settings.index_name
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
if ENABLE_OPENSEARCH_FOR_ONYX:
|
||||
if ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX:
|
||||
return OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
@@ -47,12 +53,48 @@ def get_default_document_index(
|
||||
)
|
||||
|
||||
|
||||
def get_current_primary_default_document_index(db_session: Session) -> DocumentIndex:
|
||||
def get_all_document_indices(
|
||||
search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings | None,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> list[DocumentIndex]:
|
||||
"""Gets all document indices.
|
||||
|
||||
NOTE: Will only return an OpenSearch index interface if
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is True. This is so we don't break flows
|
||||
where we know it won't be enabled.
|
||||
|
||||
Used for indexing only. Until Vespa is deprecated we will index into both
|
||||
document indices. Retrieval is done through only one index however.
|
||||
|
||||
Large chunks and secondary indices are not currently supported so we
|
||||
hardcode appropriate values.
|
||||
"""
|
||||
TODO: Use redis to cache this or something
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return get_default_document_index(
|
||||
search_settings,
|
||||
None,
|
||||
vespa_document_index = VespaIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=(
|
||||
secondary_search_settings.index_name if secondary_search_settings else None
|
||||
),
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=(
|
||||
secondary_search_settings.large_chunks_enabled
|
||||
if secondary_search_settings
|
||||
else None
|
||||
),
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
opensearch_document_index = OpenSearchOldDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=False,
|
||||
secondary_large_chunks_enabled=None,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
result: list[DocumentIndex] = [vespa_document_index]
|
||||
if opensearch_document_index:
|
||||
result.append(opensearch_document_index)
|
||||
return result
|
||||
|
||||
@@ -28,8 +28,8 @@ of "minimum value clipping".
## On time decay and boosting
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
Same logic applies to additive boosting.

So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
being fetched and returned to the user. But there are other issues of including these:
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
contents. If there are lots of updates, this may miss
contents. If there are lots of updates, this may miss.
- There is not a good way to normalize this field, the best is to clip it on the bottom.
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
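To make this concrete, a minimal numeric sketch (invented scores, not from the codebase) of why a multiplicative penalty on raw, clustered similarity scores demotes a top hit below the whole cluster, while the same penalty after normalization only moves it toward the middle:

```python
raw_scores = [0.80, 0.78, 0.74, 0.70, 0.66, 0.62]  # cosine scores clustered in 0.6-0.8

# A 50% penalty on the raw top score drops it below every other candidate.
penalized_raw = raw_scores[0] * 0.5  # 0.40 -> now the worst match

# Normalize first (min-max shown for brevity; the doc argues z-score is preferable),
# then the same penalty lands the doc mid-pack instead of dead last.
lo, hi = min(raw_scores), max(raw_scores)
normalized = [(s - lo) / (hi - lo) for s in raw_scores]  # top -> 1.0, bottom -> 0.0
penalized_normalized = normalized[0] * 0.5  # 0.5 -> roughly 50th percentile

print(penalized_raw, penalized_normalized)
```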
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
@@ -569,6 +570,9 @@ class OpenSearchClient:
|
||||
def close(self) -> None:
|
||||
"""Closes the client.
|
||||
|
||||
TODO(andrei): Can we have some way to auto close when the client no
|
||||
longer has any references?
|
||||
|
||||
Raises:
|
||||
Exception: There was an error closing the client.
|
||||
"""
|
||||
@@ -596,3 +600,55 @@ class OpenSearchClient:
|
||||
)
|
||||
hits_second_layer: list[Any] = hits_first_layer.get("hits", [])
|
||||
return hits_second_layer
|
||||
|
||||
|
||||
def wait_for_opensearch_with_timeout(
|
||||
wait_interval_s: int = 5,
|
||||
wait_limit_s: int = 60,
|
||||
client: OpenSearchClient | None = None,
|
||||
) -> bool:
|
||||
"""Waits for OpenSearch to become ready subject to a timeout.
|
||||
|
||||
Will create a new dummy client if no client is provided. Will close this
|
||||
client at the end of the function. Will not close the client if it was
|
||||
supplied.
|
||||
|
||||
Args:
|
||||
wait_interval_s: The interval in seconds to wait between checks.
|
||||
Defaults to 5.
|
||||
wait_limit_s: The total timeout in seconds to wait for OpenSearch to
|
||||
become ready. Defaults to 60.
|
||||
client: The OpenSearch client to use for pinging. If None, a new dummy
|
||||
client will be created. Defaults to None.
|
||||
|
||||
Returns:
|
||||
True if OpenSearch is ready, False otherwise.
|
||||
"""
|
||||
made_client = False
|
||||
try:
|
||||
if client is None:
|
||||
# NOTE: index_name does not matter because we are only using this object
|
||||
# to ping.
|
||||
# TODO(andrei): Make this better.
|
||||
client = OpenSearchClient(index_name="")
|
||||
made_client = True
|
||||
time_start = time.monotonic()
|
||||
while True:
|
||||
if client.ping():
|
||||
logger.info("[OpenSearch] Readiness probe succeeded. Continuing...")
|
||||
return True
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > wait_limit_s:
|
||||
logger.info(
|
||||
f"[OpenSearch] Readiness probe did not succeed within the timeout "
|
||||
f"({wait_limit_s} seconds)."
|
||||
)
|
||||
return False
|
||||
logger.info(
|
||||
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
|
||||
)
|
||||
time.sleep(wait_interval_s)
|
||||
finally:
|
||||
if made_client:
|
||||
assert client is not None
|
||||
client.close()
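A possible call site for this readiness helper, as a sketch only; the module path in the import and the decision to raise on timeout are assumptions, not taken from this change:

```python
# Sketch: block startup until OpenSearch answers pings, fail fast otherwise.
from onyx.document_index.opensearch.opensearch_client import (  # assumed module path
    wait_for_opensearch_with_timeout,
)

if not wait_for_opensearch_with_timeout(wait_interval_s=5, wait_limit_s=120):
    raise RuntimeError("OpenSearch did not become ready within 120 seconds.")
```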
@@ -3,7 +3,9 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@@ -17,7 +19,7 @@ from onyx.db.enums import EmbeddingPrecision
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.document_index.chunk_content_enrichment import cleanup_content_for_chunks
|
||||
from onyx.document_index.chunk_content_enrichment import (
|
||||
generate_enriched_content_for_chunk,
|
||||
generate_enriched_content_for_chunk_text,
|
||||
)
|
||||
from onyx.document_index.interfaces import DocumentIndex as OldDocumentIndex
|
||||
from onyx.document_index.interfaces import (
|
||||
@@ -68,6 +70,18 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def generate_opensearch_filtered_access_control_list(
    access: DocumentAccess,
) -> list[str]:
    """Generates an access control list with PUBLIC_DOC_PAT removed.

    In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
    """
    access_control_list = access.to_acl()
    access_control_list.discard(PUBLIC_DOC_PAT)
    return list(access_control_list)
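A self-contained illustration of the filtering performed here; the constant's value and the ACL entries are placeholders (the real PUBLIC_DOC_PAT lives in onyx.configs.constants, and to_acl() returns type-prefixed strings):

```python
PUBLIC_DOC_PAT = "PUBLIC"  # placeholder value, for illustration only

# Stand-in for DocumentAccess.to_acl(), which returns type-prefixed entries.
acl = {"user_email:alice@example.com", "group:engineering", PUBLIC_DOC_PAT}

filtered_acl = set(acl)
filtered_acl.discard(PUBLIC_DOC_PAT)
# Public visibility is carried by the boolean public field in the schema instead.
print(sorted(filtered_acl))  # ['group:engineering', 'user_email:alice@example.com']
```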
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -140,19 +154,21 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
return DocumentChunk(
|
||||
document_id=chunk.source_document.id,
|
||||
chunk_index=chunk.chunk_id,
|
||||
title=chunk.source_document.title,
|
||||
# Use get_title_for_document_index to match the logic used when creating
|
||||
# the title_embedding in the embedder. This method falls back to
|
||||
# semantic_identifier when title is None (but not empty string).
|
||||
title=chunk.source_document.get_title_for_document_index(),
|
||||
title_vector=chunk.title_embedding,
|
||||
content=generate_enriched_content_for_chunk(chunk),
|
||||
content=generate_enriched_content_for_chunk_text(chunk),
|
||||
content_vector=chunk.embeddings.full_embedding,
|
||||
source_type=chunk.source_document.source.value,
|
||||
metadata_list=chunk.source_document.get_metadata_str_attributes(),
|
||||
metadata_suffix=chunk.metadata_suffix_keyword,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
chunk.access
|
||||
),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
image_file_id=chunk.image_file_id,
|
||||
@@ -421,6 +437,24 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def verify_and_create_index_if_necessary(
|
||||
self, embedding_dim: int, embedding_precision: EmbeddingPrecision
|
||||
) -> None:
|
||||
"""Verifies and creates the index if necessary.
|
||||
|
||||
Also puts the desired search pipeline state, creating the pipelines if
|
||||
they do not exist and updating them otherwise.
|
||||
|
||||
Args:
|
||||
embedding_dim: Vector dimensionality for the vector similarity part
|
||||
of the search.
|
||||
embedding_precision: Precision of the values of the vectors for the
|
||||
similarity part of the search.
|
||||
|
||||
Raises:
|
||||
RuntimeError: There was an error verifying or creating the index or
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary."
|
||||
)
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
@@ -450,6 +484,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks for index {self._index_name}."
|
||||
)
|
||||
# Set of doc IDs.
|
||||
unique_docs_to_be_indexed: set[str] = set()
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
@@ -494,6 +531,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def delete(self, document_id: str, chunk_count: int | None = None) -> int:
|
||||
"""Deletes all chunks for a given document.
|
||||
|
||||
Does nothing if the specified document ID does not exist.
|
||||
|
||||
TODO(andrei): Make this method require supplying source type.
|
||||
TODO(andrei): Consider implementing this method to delete on document
|
||||
chunk IDs vs querying for matching document chunks.
|
||||
@@ -510,6 +549,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
Returns:
|
||||
The number of chunks successfully deleted.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id=document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -523,6 +565,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""Updates some set of chunks.
|
||||
|
||||
NOTE: Will raise if the specified document chunks do not exist.
|
||||
NOTE: Requires document chunk count be known; will raise if it is not.
|
||||
NOTE: Each update request must have some field to update; if not it is
|
||||
assumed there is a bug in the caller and this will raise.
|
||||
@@ -539,14 +582,19 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
RuntimeError: Failed to update some or all of the chunks for the
|
||||
specified documents.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Updating {len(update_requests)} chunks for index {self._index_name}."
|
||||
)
|
||||
for update_request in update_requests:
|
||||
properties_to_update: dict[str, Any] = dict()
|
||||
# TODO(andrei): Nit but consider if we can use DocumentChunk
|
||||
# here so we don't have to think about passing in the
|
||||
# appropriate types into this dict.
|
||||
if update_request.access is not None:
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
|
||||
update_request.access.to_acl()
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
|
||||
generate_opensearch_filtered_access_control_list(
|
||||
update_request.access
|
||||
)
|
||||
)
|
||||
if update_request.document_sets is not None:
|
||||
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
|
||||
@@ -592,24 +640,27 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
# TODO(andrei): Remove this from the new interface at some point; we
|
||||
# should not be exposing this.
|
||||
batch_retrieval: bool = False,
|
||||
# TODO(andrei): Add a param for whether to retrieve hidden docs.
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
TODO(andrei): Consider implementing this method to retrieve on document
|
||||
chunk IDs vs querying for matching document chunks.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Retrieving {len(chunk_requests)} chunks for index {self._index_name}."
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
search_hits: list[SearchHit[DocumentChunk]] = []
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
max_chunk_size=chunk_request.max_chunk_size,
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
@@ -636,19 +687,21 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_hybrid_search_query(
|
||||
query_text=query,
|
||||
query_vector=query_embedding,
|
||||
num_candidates=1000, # TODO(andrei): Magic number.
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
|
||||
@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
def serialize_datetime_fields_to_epoch_seconds(
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
Serializes datetime fields to seconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
return int(value.timestamp())
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses seconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
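This hunk switches last_updated serialization from epoch milliseconds to epoch seconds, matching the epoch_second mapping further down. A quick self-contained round trip of the new convention:

```python
from datetime import datetime, timezone

original = datetime(2024, 5, 1, 12, 30, tzinfo=timezone.utc)

# Serialize: seconds since the Unix epoch (previously * 1000 for milliseconds).
epoch_seconds = int(original.timestamp())

# Validate/parse: back to an aware UTC datetime.
round_tripped = datetime.fromtimestamp(epoch_seconds, tz=timezone.utc)
assert round_tripped == original
```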
@@ -354,11 +353,9 @@ class DocumentSchema:
|
||||
},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
"format": "epoch_second",
|
||||
# For some reason date defaults to False, even though it
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
@@ -366,14 +363,21 @@ class DocumentSchema:
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
# is its own field.
|
||||
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
# should have no effect on queries.
|
||||
PUBLIC_FIELD_NAME: {"type": "boolean"},
|
||||
# Access control list for the doc, excluding public access,
|
||||
# which is covered above.
|
||||
# If a user's access set contains at least one entry from this
|
||||
# set, the user should be able to retrieve this document. This
|
||||
# only applies if public is set to false; public non-hidden
|
||||
# documents are always visible to anyone in a given tenancy
|
||||
# regardless of this field.
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# Whether the doc is hidden from search results. Should clobber
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
# Whether the doc is hidden from search results.
|
||||
# Should clobber all other access search filters, namely
|
||||
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
|
||||
# search implementations to guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
@@ -447,7 +451,6 @@ class DocumentSchema:
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
# The number and ordering of weights should match the query clauses. The values
@@ -91,6 +106,11 @@ assert (
# given search. This value is configurable in the index settings.
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000

# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
# that the document was last updated this many days ago for the purpose of time
# cutoff filtering during retrieval.
ASSUMED_DOCUMENT_AGE_DAYS = 90
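A small self-contained illustration of how this constant interacts with a time cutoff; it mirrors the comparison used in the time-cutoff filter added later in this change:

```python
from datetime import datetime, timedelta, timezone

ASSUMED_DOCUMENT_AGE_DAYS = 90
now = datetime.now(timezone.utc)

# A 180-day cutoff is older than the assumed age, so documents with no
# last_updated value are still allowed through the filter...
assert now - timedelta(days=180) < now - timedelta(days=ASSUMED_DOCUMENT_AGE_DAYS)

# ...while a 30-day cutoff is newer than the assumed age, so undated documents
# are treated as too old and filtered out.
assert not (now - timedelta(days=30) < now - timedelta(days=ASSUMED_DOCUMENT_AGE_DAYS))
```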
|
||||
class DocumentQuery:
|
||||
"""
|
||||
@@ -103,6 +123,8 @@ class DocumentQuery:
|
||||
def get_from_document_id_query(
|
||||
document_id: str,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
max_chunk_size: int,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -120,6 +142,8 @@ class DocumentQuery:
|
||||
document_id: Onyx document ID. Notably not an OpenSearch document
|
||||
ID, which points to what Onyx would refer to as a chunk.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the document retrieval query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
max_chunk_size: Document chunks are categorized by the maximum
|
||||
number of tokens they can hold. This parameter specifies the
|
||||
maximum size category of document chunks to retrieve.
|
||||
@@ -136,28 +160,21 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final ID search query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
filter_clauses.append(range_clause)
|
||||
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
max_chunk_size=max_chunk_size,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
final_get_ids_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
# We include this to make sure OpenSearch does not revert to
|
||||
@@ -195,15 +212,22 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final delete query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
# Delete hidden docs too.
|
||||
include_hidden=True,
|
||||
access_control_list=None,
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=document_id,
|
||||
)
|
||||
final_delete_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
}
|
||||
@@ -217,19 +241,25 @@ class DocumentQuery:
|
||||
num_candidates: int,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final hybrid search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
NOTE: This query can be directly supplied to the OpenSearch client, but
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
num_candidates: The number of candidates to consider for vector
|
||||
num_candidates: The number of neighbors to consider for vector
|
||||
similarity search. Generally more candidates improves search
|
||||
quality at the cost of performance.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the hybrid search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final hybrid search query.
|
||||
@@ -243,31 +273,47 @@ class DocumentQuery:
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, num_candidates
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
}
|
||||
}
|
||||
],
|
||||
# TODO(andrei): When revisiting our hybrid query logic see if
|
||||
# this needs to be nested one level down.
|
||||
"filter": hybrid_search_filters,
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Applied to all the sub-queries. Source:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# Does AND for each filter in the list.
|
||||
"filter": {"bool": {"filter": hybrid_search_filters}},
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: By default, hybrid search retrieves "size"-many results from
|
||||
# each OpenSearch shard before aggregation. Source:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
}
|
||||
|
||||
return final_hybrid_search_body
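For reference, a hand-written example of the request body this builder produces; field names are shorthand for the *_FIELD_NAME constants, the vector and highlight sections are abbreviated, and the whole thing is illustrative rather than copied from the code:

```python
example_hybrid_search_body = {
    "query": {
        "hybrid": {
            "queries": [
                {"knn": {"title_vector": {"vector": [0.1, 0.2], "k": 1000}}},
                {"knn": {"content_vector": {"vector": [0.1, 0.2], "k": 1000}}},
                {
                    "multi_match": {
                        "query": "quarterly report",
                        "fields": ["title^2", "title.keyword"],
                        "type": "best_fields",
                    }
                },
                {"match": {"content": {"query": "quarterly report"}}},
                {
                    "match_phrase": {
                        "content": {"query": "quarterly report", "boost": 1.5}
                    }
                },
            ],
            # ANDed filters: tenant, hidden, ACL, source type, time cutoff, ...
            "filter": {"bool": {"filter": [{"term": {"hidden": {"value": False}}}]}},
        }
    },
    "size": 50,
    "highlight": {"fields": {"content": {}}},  # abbreviated highlight config
}
```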
@staticmethod
|
||||
@@ -294,7 +340,8 @@ class DocumentQuery:
|
||||
pipeline.
|
||||
|
||||
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
|
||||
in a single hybrid query.
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -305,6 +352,7 @@ class DocumentQuery:
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the title.
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -313,6 +361,7 @@ class DocumentQuery:
|
||||
},
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the content.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -322,36 +371,273 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
# Either fuzzy match on the analyzed title (boosted 2x), or
|
||||
# exact match on exact title keywords (no OpenSearch
|
||||
# analysis done on the title). See
|
||||
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
# Returns the score of the best match of the fields above.
|
||||
# See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
|
||||
"type": "best_fields",
|
||||
}
|
||||
},
|
||||
# Fuzzy match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
|
||||
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
|
||||
# Exact match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
|
||||
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
|
||||
]
|
||||
|
||||
return hybrid_search_queries
|
||||
|
||||
@staticmethod
|
||||
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
|
||||
"""Returns filters for hybrid search.
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
include_hidden: bool,
|
||||
access_control_list: list[str] | None,
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
max_chunk_size: int | None = None,
|
||||
document_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns filters to be passed into the "filter" key of a search query.
|
||||
|
||||
For now only fetches public and not hidden documents.
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
|
||||
The return of this function is not sufficient to be directly supplied to
|
||||
the OpenSearch client. See get_hybrid_search_query.
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
access_control_list: Access control list for the documents to
|
||||
retrieve. If None, there is no restriction on the documents that
|
||||
can be retrieved. If not None, only public documents can be
|
||||
retrieved, or non-public documents where at least one acl
|
||||
provided here is present in the document's acl list.
|
||||
source_types: If supplied, only documents of one of these source
|
||||
types will be retrieved.
|
||||
tags: If supplied, only documents with an entry in their metadata
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
updated time, we assume some default age of
|
||||
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
|
||||
updated.
|
||||
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
|
||||
None, no minimum chunk index will be applied.
|
||||
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
|
||||
None, no maximum chunk index will be applied.
|
||||
max_chunk_size: The type of chunk to retrieve, specified by the
|
||||
maximum number of tokens it can hold. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
|
||||
TODO(andrei): Add ACL filters and stuff.
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
"""
|
||||
hybrid_search_filters: list[dict[str, Any]] = [
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
)
|
||||
return tag_filter
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
|
||||
}
|
||||
}
|
||||
)
|
||||
if time_cutoff < datetime.now(timezone.utc) - timedelta(
|
||||
days=ASSUMED_DOCUMENT_AGE_DAYS
|
||||
):
|
||||
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
|
||||
# ago, we include documents which have no
|
||||
# LAST_UPDATED_FIELD_NAME value.
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
|
||||
}
|
||||
}
|
||||
)
|
||||
return time_cutoff_filter
|
||||
|
||||
def _get_chunk_index_filter(
|
||||
min_chunk_index: int | None, max_chunk_index: int | None
|
||||
) -> dict[str, Any]:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
return range_clause
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
|
||||
|
||||
if access_control_list is not None:
|
||||
# If an access control list is provided, the caller can only
|
||||
# retrieve public documents, and non-public documents where at least
|
||||
# one acl provided here is present in the document's acl list. If
|
||||
# there is explicitly no list provided, we make no restrictions on
|
||||
# the documents that can be retrieved.
|
||||
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
|
||||
|
||||
if source_types:
|
||||
# If at least one source type is provided, the caller will only
|
||||
# retrieve documents whose source type is present in this input
|
||||
# list.
|
||||
filter_clauses.append(_get_source_type_filter(source_types))
|
||||
|
||||
if tags:
|
||||
# If at least one tag is provided, the caller will only retrieve
|
||||
# documents where at least one tag provided here is present in the
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
if document_sets:
|
||||
# If at least one document set is provided, the caller will only
|
||||
# retrieve documents where at least one document set provided here
|
||||
# is present in the document's document sets list.
|
||||
filter_clauses.append(_get_document_set_filter(document_sets))
|
||||
|
||||
if user_file_ids:
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs. Note that these IDs correspond to Onyx documents whereas
|
||||
# the entries retrieved from the document index correspond to Onyx
|
||||
# document chunks.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
# cutoff. For documents which do not have a value for
|
||||
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
|
||||
# purposes of time cutoff.
|
||||
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
filter_clauses.append(
|
||||
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
if max_chunk_size is not None:
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
)
|
||||
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
return hybrid_search_filters
|
||||
|
||||
return filter_clauses
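An illustrative example of the clause list this helper can return, here for a non-hidden, ACL-restricted, source- and time-filtered query; literal field names stand in for the *_FIELD_NAME constants and all values are invented:

```python
example_filter_clauses = [
    {"term": {"hidden": {"value": False}}},
    {
        "bool": {
            "should": [
                {"term": {"public": {"value": True}}},
                {"term": {"access_control_list": {"value": "user_email:alice@example.com"}}},
            ]
        }
    },
    {"bool": {"should": [{"term": {"source_type": {"value": "confluence"}}}]}},
    {
        "bool": {
            "should": [
                {"range": {"last_updated": {"gte": 1700000000}}},
                # A must_not/exists clause is appended when the cutoff is older than
                # ASSUMED_DOCUMENT_AGE_DAYS, so undated documents are not dropped.
            ]
        }
    },
]
```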
@staticmethod
|
||||
def _get_match_highlights_configuration() -> dict[str, Any]:
|
||||
@@ -378,4 +664,5 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return match_highlights_configuration
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.document_index.chunk_content_enrichment import (
|
||||
generate_enriched_content_for_chunk,
|
||||
generate_enriched_content_for_chunk_text,
|
||||
)
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
|
||||
@@ -186,7 +186,7 @@ def _index_vespa_chunk(
|
||||
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
|
||||
# natural language representation of the metadata section
|
||||
CONTENT: remove_invalid_unicode_chars(
|
||||
generate_enriched_content_for_chunk(chunk)
|
||||
generate_enriched_content_for_chunk_text(chunk)
|
||||
),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
|
||||
@@ -7,6 +7,9 @@ from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.document_index.chunk_content_enrichment import (
|
||||
generate_enriched_content_for_chunk_embedding,
|
||||
)
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
@@ -126,7 +129,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
if chunk.large_chunk_reference_ids:
|
||||
large_chunks_present = True
|
||||
chunk_text = (
|
||||
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
|
||||
generate_enriched_content_for_chunk_embedding(chunk)
|
||||
) or chunk.source_document.get_title_for_document_index()
|
||||
|
||||
if not chunk_text:
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
)
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
@@ -163,7 +164,7 @@ def index_doc_batch_with_handler(
|
||||
*,
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
document_indices: list[DocumentIndex],
|
||||
document_batch: list[Document],
|
||||
request_id: str | None,
|
||||
tenant_id: str,
|
||||
@@ -176,7 +177,7 @@ def index_doc_batch_with_handler(
|
||||
index_pipeline_result = index_doc_batch(
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
document_index=document_index,
|
||||
document_indices=document_indices,
|
||||
document_batch=document_batch,
|
||||
request_id=request_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -627,7 +628,7 @@ def index_doc_batch(
|
||||
document_batch: list[Document],
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
document_indices: list[DocumentIndex],
|
||||
request_id: str | None,
|
||||
tenant_id: str,
|
||||
adapter: IndexingBatchAdapter,
|
||||
@@ -743,47 +744,57 @@ def index_doc_batch(
|
||||
short_descriptor_log = str(short_descriptor_list)[:1024]
|
||||
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
|
||||
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
chunks=result.chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
)
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
|
||||
for document_index in document_indices:
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
chunks=result.chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
)
|
||||
|
||||
all_returned_doc_ids = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
all_returned_doc_ids: set[str] = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
f"This occured for document index {document_index.__class__.__name__}"
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
@@ -792,11 +803,15 @@ def index_doc_batch(
|
||||
result=result,
|
||||
)
|
||||
|
||||
assert primary_doc_idx_insertion_records is not None
|
||||
assert primary_doc_idx_vector_db_write_failures is not None
|
||||
return IndexingPipelineResult(
|
||||
new_docs=len([r for r in insertion_records if not r.already_existed]),
|
||||
new_docs=len(
|
||||
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
|
||||
),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(chunks_with_embeddings),
|
||||
failures=vector_db_write_failures + embedding_failures,
|
||||
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
|
||||
)
|
||||
|
||||
|
||||
@@ -805,7 +820,7 @@ def run_indexing_pipeline(
|
||||
document_batch: list[Document],
|
||||
request_id: str | None,
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
document_indices: list[DocumentIndex],
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
adapter: IndexingBatchAdapter,
|
||||
@@ -846,7 +861,7 @@ def run_indexing_pipeline(
|
||||
return index_doc_batch_with_handler(
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
document_index=document_index,
|
||||
document_indices=document_indices,
|
||||
document_batch=document_batch,
|
||||
request_id=request_id,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -41,6 +41,11 @@ alphanum_regex = re.compile(r"[^a-z0-9]+")
rem_email_regex = re.compile(r"(?<=\S)@([a-z0-9-]+)\.([a-z]{2,6})$")


def _ngrams(sequence: str, n: int) -> list[tuple[str, ...]]:
    """Generate n-grams from a sequence."""
    return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]
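The inlined helper replaces the lazy nltk import used before; a quick self-contained check of its behavior:

```python
def _ngrams(sequence: str, n: int) -> list[tuple[str, ...]]:
    """Generate n-grams from a sequence."""
    return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]

assert _ngrams("acme", 2) == [("a", "c"), ("c", "m"), ("m", "e")]
assert _ngrams("ab", 3) == []  # sequences shorter than n produce no n-grams
```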
|
||||
def _clean_name(entity_name: str) -> str:
|
||||
"""
|
||||
Clean an entity string by removing non-alphanumeric characters and email addresses.
|
||||
@@ -58,8 +63,6 @@ def _normalize_one_entity(
|
||||
attributes: dict[str, str],
|
||||
allowed_docs_temp_view_name: str | None = None,
|
||||
) -> str | None:
|
||||
from nltk import ngrams # type: ignore
|
||||
|
||||
"""
|
||||
Matches a single entity to the best matching entity of the same type.
|
||||
"""
|
||||
@@ -150,16 +153,16 @@ def _normalize_one_entity(
|
||||
|
||||
# step 2: do a weighted ngram analysis and damerau levenshtein distance to rerank
|
||||
n1, n2, n3 = (
|
||||
set(ngrams(cleaned_entity, 1)),
|
||||
set(ngrams(cleaned_entity, 2)),
|
||||
set(ngrams(cleaned_entity, 3)),
|
||||
set(_ngrams(cleaned_entity, 1)),
|
||||
set(_ngrams(cleaned_entity, 2)),
|
||||
set(_ngrams(cleaned_entity, 3)),
|
||||
)
|
||||
for i, (candidate_id_name, candidate_name, _) in enumerate(candidates):
|
||||
cleaned_candidate = _clean_name(candidate_name)
|
||||
h_n1, h_n2, h_n3 = (
|
||||
set(ngrams(cleaned_candidate, 1)),
|
||||
set(ngrams(cleaned_candidate, 2)),
|
||||
set(ngrams(cleaned_candidate, 3)),
|
||||
set(_ngrams(cleaned_candidate, 1)),
|
||||
set(_ngrams(cleaned_candidate, 2)),
|
||||
set(_ngrams(cleaned_candidate, 3)),
|
||||
)
|
||||
|
||||
# compute ngram overlap, renormalize scores if the names are too short for larger ngrams
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
    LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response  # type: ignore[method-assign]


def _patch_azure_responses_should_fake_stream() -> None:
    """
    Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.

    By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
    not in its database. This causes Azure custom model deployments to buffer the entire
    response before yielding, resulting in poor time-to-first-token.

    Azure's Responses API supports native streaming, so we override this to always use
    real streaming (SyncResponsesAPIStreamingIterator).
    """
    from litellm.llms.azure.responses.transformation import (
        AzureOpenAIResponsesAPIConfig,
    )

    if (
        getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
        == "_patched_should_fake_stream"
    ):
        return

    def _patched_should_fake_stream(
        self: Any,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Azure Responses API supports native streaming - never fake it
        return False

    _patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
    AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream  # type: ignore[method-assign]


def apply_monkey_patches() -> None:
    """
    Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
    - Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
    - Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
    - Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
    - Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
    - Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
    """
    _patch_ollama_transform_request()
    _patch_ollama_chunk_parser()
    _patch_openai_responses_chunk_parser()
    _patch_openai_responses_transform_response()
    _patch_azure_responses_should_fake_stream()


def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
@@ -54,11 +54,6 @@
|
||||
"model_vendor": "amazon",
|
||||
"model_version": "v1:0"
|
||||
},
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1465,11 +1460,6 @@
|
||||
"model_vendor": "mistral",
|
||||
"model_version": "v0:1"
|
||||
},
|
||||
"bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v1:0"
|
||||
},
|
||||
"chat-bison": {
|
||||
"display_name": "Chat Bison",
|
||||
"model_vendor": "google",
|
||||
@@ -1500,16 +1490,6 @@
|
||||
"model_vendor": "openai",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"claude-3-5-haiku-20241022": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"claude-3-5-haiku-latest": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"claude-3-5-sonnet-20240620": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1715,11 +1695,6 @@
|
||||
"model_vendor": "amazon",
|
||||
"model_version": "v1:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v1:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3251,15 +3226,6 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-5-haiku": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-5-haiku-20241022": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic"
|
||||
@@ -3774,11 +3740,6 @@
|
||||
"model_vendor": "amazon",
|
||||
"model_version": "1:0"
|
||||
},
|
||||
"us.anthropic.claude-3-5-haiku-20241022-v1:0": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3899,15 +3860,6 @@
|
||||
"model_vendor": "twelvelabs",
|
||||
"model_version": "v1:0"
|
||||
},
|
||||
"vertex_ai/claude-3-5-haiku": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"vertex_ai/claude-3-5-haiku@20241022": {
|
||||
"display_name": "Claude Haiku 3.5",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"vertex_ai/claude-3-5-sonnet": {
|
||||
"display_name": "Claude Sonnet 3.5",
|
||||
"model_vendor": "anthropic"
|
||||
|
||||
@@ -301,6 +301,12 @@ class LitellmLLM(LLM):
        )
        is_ollama = self._model_provider == LlmProviderNames.OLLAMA_CHAT
        is_mistral = self._model_provider == LlmProviderNames.MISTRAL
        is_vertex_ai = self._model_provider == LlmProviderNames.VERTEX_AI
        # Vertex Anthropic Opus 4.5 rejects output_config (LiteLLM maps reasoning_effort).
        # Keep this guard until LiteLLM/Vertex accept the field for this model.
        is_vertex_opus_4_5 = (
            is_vertex_ai and "claude-opus-4-5" in self.config.model_name.lower()
        )

        #########################
        # Build arguments
@@ -331,12 +337,16 @@ class LitellmLLM(LLM):
        # Temperature
        temperature = 1 if is_reasoning else self._temperature

        if stream:
        if stream and not is_vertex_opus_4_5:
            optional_kwargs["stream_options"] = {"include_usage": True}

        # Use configured default if not provided (if not set in env, low)
        reasoning_effort = reasoning_effort or ReasoningEffort(DEFAULT_REASONING_EFFORT)
        if is_reasoning and reasoning_effort != ReasoningEffort.OFF:
        if (
            is_reasoning
            and reasoning_effort != ReasoningEffort.OFF
            and not is_vertex_opus_4_5
        ):
            if is_openai_model:
                # OpenAI API does not accept reasoning params for GPT 5 chat models
                # (neither reasoning nor reasoning_effort are accepted)
@@ -96,6 +96,7 @@ from onyx.server.long_term_logs.long_term_logs_api import (
    router as long_term_logs_router,
)
from onyx.server.manage.administrative import router as admin_router
from onyx.server.manage.discord_bot.api import router as discord_bot_router
from onyx.server.manage.embedding.api import admin_router as embedding_admin_router
from onyx.server.manage.embedding.api import basic_router as embedding_router
from onyx.server.manage.get_state import router as state_router
@@ -380,6 +381,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
    include_router_with_global_prefix_prepended(
        application, slack_bot_management_router
    )
    include_router_with_global_prefix_prepended(application, discord_bot_router)
    include_router_with_global_prefix_prepended(application, persona_router)
    include_router_with_global_prefix_prepended(application, admin_persona_router)
    include_router_with_global_prefix_prepended(application, agents_router)
225  backend/onyx/natural_language_processing/english_stopwords.py  Normal file
@@ -0,0 +1,225 @@
import re

ENGLISH_STOPWORDS = [
"a",
|
||||
"about",
|
||||
"above",
|
||||
"after",
|
||||
"again",
|
||||
"against",
|
||||
"ain",
|
||||
"all",
|
||||
"am",
|
||||
"an",
|
||||
"and",
|
||||
"any",
|
||||
"are",
|
||||
"aren",
|
||||
"aren't",
|
||||
"as",
|
||||
"at",
|
||||
"be",
|
||||
"because",
|
||||
"been",
|
||||
"before",
|
||||
"being",
|
||||
"below",
|
||||
"between",
|
||||
"both",
|
||||
"but",
|
||||
"by",
|
||||
"can",
|
||||
"couldn",
|
||||
"couldn't",
|
||||
"d",
|
||||
"did",
|
||||
"didn",
|
||||
"didn't",
|
||||
"do",
|
||||
"does",
|
||||
"doesn",
|
||||
"doesn't",
|
||||
"doing",
|
||||
"don",
|
||||
"don't",
|
||||
"down",
|
||||
"during",
|
||||
"each",
|
||||
"few",
|
||||
"for",
|
||||
"from",
|
||||
"further",
|
||||
"had",
|
||||
"hadn",
|
||||
"hadn't",
|
||||
"has",
|
||||
"hasn",
|
||||
"hasn't",
|
||||
"have",
|
||||
"haven",
|
||||
"haven't",
|
||||
"having",
|
||||
"he",
|
||||
"he'd",
|
||||
"he'll",
|
||||
"he's",
|
||||
"her",
|
||||
"here",
|
||||
"hers",
|
||||
"herself",
|
||||
"him",
|
||||
"himself",
|
||||
"his",
|
||||
"how",
|
||||
"i",
|
||||
"i'd",
|
||||
"i'll",
|
||||
"i'm",
|
||||
"i've",
|
||||
"if",
|
||||
"in",
|
||||
"into",
|
||||
"is",
|
||||
"isn",
|
||||
"isn't",
|
||||
"it",
|
||||
"it'd",
|
||||
"it'll",
|
||||
"it's",
|
||||
"its",
|
||||
"itself",
|
||||
"just",
|
||||
"ll",
|
||||
"m",
|
||||
"ma",
|
||||
"me",
|
||||
"mightn",
|
||||
"mightn't",
|
||||
"more",
|
||||
"most",
|
||||
"mustn",
|
||||
"mustn't",
|
||||
"my",
|
||||
"myself",
|
||||
"needn",
|
||||
"needn't",
|
||||
"no",
|
||||
"nor",
|
||||
"not",
|
||||
"now",
|
||||
"o",
|
||||
"of",
|
||||
"off",
|
||||
"on",
|
||||
"once",
|
||||
"only",
|
||||
"or",
|
||||
"other",
|
||||
"our",
|
||||
"ours",
|
||||
"ourselves",
|
||||
"out",
|
||||
"over",
|
||||
"own",
|
||||
"re",
|
||||
"s",
|
||||
"same",
|
||||
"shan",
|
||||
"shan't",
|
||||
"she",
|
||||
"she'd",
|
||||
"she'll",
|
||||
"she's",
|
||||
"should",
|
||||
"should've",
|
||||
"shouldn",
|
||||
"shouldn't",
|
||||
"so",
|
||||
"some",
|
||||
"such",
|
||||
"t",
|
||||
"than",
|
||||
"that",
|
||||
"that'll",
|
||||
"the",
|
||||
"their",
|
||||
"theirs",
|
||||
"them",
|
||||
"themselves",
|
||||
"then",
|
||||
"there",
|
||||
"these",
|
||||
"they",
|
||||
"they'd",
|
||||
"they'll",
|
||||
"they're",
|
||||
"they've",
|
||||
"this",
|
||||
"those",
|
||||
"through",
|
||||
"to",
|
||||
"too",
|
||||
"under",
|
||||
"until",
|
||||
"up",
|
||||
"ve",
|
||||
"very",
|
||||
"was",
|
||||
"wasn",
|
||||
"wasn't",
|
||||
"we",
|
||||
"we'd",
|
||||
"we'll",
|
||||
"we're",
|
||||
"we've",
|
||||
"were",
|
||||
"weren",
|
||||
"weren't",
|
||||
"what",
|
||||
"when",
|
||||
"where",
|
||||
"which",
|
||||
"while",
|
||||
"who",
|
||||
"whom",
|
||||
"why",
|
||||
"will",
|
||||
"with",
|
||||
"won",
|
||||
"won't",
|
||||
"wouldn",
|
||||
"wouldn't",
|
||||
"y",
|
||||
"you",
|
||||
"you'd",
|
||||
"you'll",
|
||||
"you're",
|
||||
"you've",
|
||||
"your",
|
||||
"yours",
|
||||
"yourself",
|
||||
"yourselves",
|
||||
]

ENGLISH_STOPWORDS_SET = frozenset(ENGLISH_STOPWORDS)


def strip_stopwords(text: str) -> list[str]:
    """Remove English stopwords from text.

    Matching is case-insensitive and ignores leading/trailing punctuation
    on each word. Internal punctuation (like apostrophes in contractions)
    is preserved for matching, so "you're" matches the stopword "you're"
    but "youre" would not.
    """
    words = text.split()
    result = []

    for word in words:
        # Strip leading/trailing punctuation to get the core word for comparison
        # This preserves internal punctuation like apostrophes
        core = re.sub(r"^[^\w']+|[^\w']+$", "", word)
        if core.lower() not in ENGLISH_STOPWORDS_SET:
            result.append(word)

    return result
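
For reference, a quick usage sketch of the new helper (the expected output is derived from the implementation above):

```python
from onyx.natural_language_processing.english_stopwords import strip_stopwords

# Stopwords are matched case-insensitively and surrounding punctuation is ignored,
# but the surviving words keep their original form.
print(strip_stopwords("What is our deployment process?"))
# -> ['deployment', 'process?']
```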
287  backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md  Normal file
@@ -0,0 +1,287 @@
# Discord Bot Multitenant Architecture

This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.

## Overview

The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:

- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request

```
┌─────────────────────────────────────────────────────────────────────┐
│                          OnyxDiscordClient                          │
│                                                                     │
│  ┌─────────────────────────┐    ┌─────────────────────────────┐    │
│  │  DiscordCacheManager    │    │  OnyxAPIClient              │    │
│  │                         │    │                             │    │
│  │  guild_id → tenant_id   │───▶│  send_chat_message(         │    │
│  │  tenant_id → api_key    │    │      message,               │    │
│  │                         │    │      api_key=<per-tenant>,  │    │
│  └─────────────────────────┘    │      persona_id=...         │    │
│                                 │  )                          │    │
│                                 └─────────────────────────────┘    │
└─────────────────────────────────────────────────────────────────────┘
```

---

## Component Details

### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)

The `DiscordCacheManager` maintains two critical in-memory mappings:

```python
class DiscordCacheManager:
    _guild_tenants: dict[int, str]  # guild_id → tenant_id
    _api_keys: dict[str, str]       # tenant_id → api_key
    _lock: asyncio.Lock             # Concurrency control
```

#### Key Responsibilities

| Function | Purpose |
|----------|---------|
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
| `refresh_all()` | Full cache rebuild from database |
| `refresh_guild()` | Incremental update for single guild |

#### API Key Provisioning Strategy

API keys are **lazily provisioned** - only created when first needed:

```python
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
    needs_key = tenant_id not in self._api_keys

    with get_session_with_tenant(tenant_id) as db:
        # Load guild configs
        configs = get_guild_configs(db)
        guild_ids = [c.guild_id for c in configs if c.enabled]

        # Only provision API key if not already cached
        api_key = None
        if needs_key:
            api_key = get_or_create_discord_service_api_key(db, tenant_id)

        return guild_ids, api_key
```

This optimization avoids repeated database calls for API key generation.

#### Concurrency Control

All write operations acquire an async lock to prevent race conditions:

```python
async def refresh_all(self) -> None:
    async with self._lock:
        # Safe to modify _guild_tenants and _api_keys
        for tenant_id in get_all_tenant_ids():
            guild_ids, api_key = await self._load_tenant_data(tenant_id)
            # Update mappings...
```

Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
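
This works because `refresh_all()` builds the new mappings off to the side and publishes them with single reference assignments, so readers always see either the old snapshot or the new one. A simplified sketch of the pattern (class and method names here are condensed, not the exact `cache.py` code):

```python
import asyncio


class SnapshotCache:
    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._guild_tenants: dict[int, str] = {}

    async def publish(self, new_guild_tenants: dict[int, str]) -> None:
        async with self._lock:
            # Writers build the new mapping elsewhere, then swap the reference;
            # readers never observe a half-built dict.
            self._guild_tenants = new_guild_tenants

    def get_tenant(self, guild_id: int) -> str | None:
        # Lock-free read against whichever snapshot is currently published.
        return self._guild_tenants.get(guild_id)
```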
---

### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)

The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.

#### Key Design: Per-Request API Key Injection

```python
class OnyxAPIClient:
    async def send_chat_message(
        self,
        message: str,
        api_key: str,  # Injected per-request
        persona_id: int | None,
        ...
    ) -> ChatFullResponse:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",  # Tenant-specific auth
        }
        # Make request...
```

The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:

```python
# Same client, different tenants
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
```

---

## Coordination Flow

### Message Processing Pipeline

When a Discord message arrives, the client coordinates cache and API client:

```python
async def on_message(self, message: Message) -> None:
    guild_id = message.guild.id

    # Step 1: Cache lookup - guild → tenant
    tenant_id = self.cache.get_tenant(guild_id)
    if not tenant_id:
        return  # Guild not registered

    # Step 2: Cache lookup - tenant → API key
    api_key = self.cache.get_api_key(tenant_id)
    if not api_key:
        logger.warning(f"No API key for tenant {tenant_id}")
        return

    # Step 3: API call with tenant-specific credentials
    await process_chat_message(
        message=message,
        api_key=api_key,  # Tenant-specific
        persona_id=persona_id,  # Tenant-specific
        api_client=self.api_client,
    )
```

### Startup Sequence

```python
async def setup_hook(self) -> None:
    # 1. Initialize API client (create aiohttp session)
    await self.api_client.initialize()

    # 2. Populate cache with all tenants
    await self.cache.refresh_all()

    # 3. Start background refresh task
    self._cache_refresh_task = self.loop.create_task(
        self._periodic_cache_refresh()  # Every 60 seconds
    )
```

### Shutdown Sequence

```python
async def close(self) -> None:
    # 1. Cancel background refresh
    if self._cache_refresh_task:
        self._cache_refresh_task.cancel()

    # 2. Close Discord connection
    await super().close()

    # 3. Close API client session
    await self.api_client.close()

    # 4. Clear cache
    self.cache.clear()
```

---

## Tenant Isolation Mechanisms

### 1. Per-Tenant API Keys

Each tenant has a dedicated service API key:

```python
# backend/onyx/db/discord_bot.py
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
    existing = get_discord_service_api_key(db_session)
    if existing:
        return regenerate_key(existing)

    # Create LIMITED role key (chat-only permissions)
    return insert_api_key(
        db_session=db_session,
        api_key_args=APIKeyArgs(
            name=DISCORD_SERVICE_API_KEY_NAME,
            role=UserRole.LIMITED,  # Minimal permissions
        ),
        user_id=None,  # Service account (system-owned)
    ).api_key
```

### 2. Database Context Variables

The cache uses context variables for proper tenant-scoped DB sessions:

```python
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
    with get_session_with_tenant(tenant_id) as db:
        # All DB operations scoped to this tenant
        ...
finally:
    CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
```

### 3. Enterprise Gating Support

Gated tenants are filtered during cache refresh:

```python
gated_tenants = fetch_ee_implementation_or_noop(
    "onyx.server.tenants.product_gating",
    "get_gated_tenants",
    set(),
)()

for tenant_id in get_all_tenant_ids():
    if tenant_id in gated_tenants:
        continue  # Skip gated tenants
```

---

## Cache Refresh Strategy

| Trigger | Method | Scope |
|---------|--------|-------|
| Startup | `refresh_all()` | All tenants |
| Periodic (60s) | `refresh_all()` | All tenants |
| Guild registration | `refresh_guild()` | Single tenant |

### Error Handling

- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants; see the sketch below)
- **Missing API key**: Bot silently ignores messages from that guild
- **Network errors**: Logged, cache continues with stale data until next refresh
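
Tenant-level isolation inside the refresh loop looks roughly like this (condensed from `refresh_all()` in `cache.py`; gating, context-variable handling, and some logging are omitted):

```python
async def refresh_all(self) -> None:
    async with self._lock:
        new_guild_tenants: dict[int, str] = {}
        new_api_keys: dict[str, str] = {}
        for tenant_id in await asyncio.to_thread(get_all_tenant_ids):
            try:
                guild_ids, api_key = await self._load_tenant_data(tenant_id)
            except Exception as e:
                # One failing tenant is logged and skipped; its guilds simply
                # drop out of the new snapshot until a later refresh succeeds.
                logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
                continue
            if not guild_ids or not api_key:
                continue
            for guild_id in guild_ids:
                new_guild_tenants[guild_id] = tenant_id
            new_api_keys[tenant_id] = api_key
        self._guild_tenants = new_guild_tenants
        self._api_keys = new_api_keys
```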
---

## Key Design Insights

1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serve all tenants via dynamic API key injection.

2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.

3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).

4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.

5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.

6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.

---

## File References

| Component | Path |
|-----------|------|
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |
215  backend/onyx/onyxbot/discord/api_client.py  Normal file
@@ -0,0 +1,215 @@
|
||||
"""Async HTTP client for communicating with Onyx API pods."""
|
||||
|
||||
import aiohttp
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
|
||||
from onyx.onyxbot.discord.exceptions import APIConnectionError
|
||||
from onyx.onyxbot.discord.exceptions import APIResponseError
|
||||
from onyx.onyxbot.discord.exceptions import APITimeoutError
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxAPIClient:
|
||||
"""Async HTTP client for sending chat requests to Onyx API pods.
|
||||
|
||||
This client manages an aiohttp session for making non-blocking HTTP
|
||||
requests to the Onyx API server. It handles authentication with per-tenant
|
||||
API keys and multi-tenant routing.
|
||||
|
||||
Usage:
|
||||
client = OnyxAPIClient()
|
||||
await client.initialize()
|
||||
try:
|
||||
response = await client.send_chat_message(
    message="What is our deployment process?",
    api_key="dn_xxx...",
    persona_id=1,
)
print(response.answer)
|
||||
finally:
|
||||
await client.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = API_REQUEST_TIMEOUT,
|
||||
) -> None:
|
||||
"""Initialize the API client.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds.
|
||||
"""
|
||||
# Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL
|
||||
# TODO: Ideally, this override is only used when someone is launching an Onyx service independently
|
||||
self._base_url = build_api_server_url_for_http_requests(
|
||||
respect_env_override_if_set=True
|
||||
).rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._session: aiohttp.ClientSession | None = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Create the aiohttp session.
|
||||
|
||||
Must be called before making any requests. The session is created
|
||||
with a total timeout and connection timeout.
|
||||
"""
|
||||
if self._session is not None:
|
||||
logger.warning("API client session already initialized")
|
||||
return
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=self._timeout,
|
||||
connect=30, # 30 seconds to establish connection
|
||||
)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
logger.info(f"API client initialized with base URL: {self._base_url}")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the aiohttp session.
|
||||
|
||||
Should be called when shutting down the bot to properly release
|
||||
resources.
|
||||
"""
|
||||
if self._session is not None:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
logger.info("API client session closed")
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the session is initialized."""
|
||||
return self._session is not None
|
||||
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str,
|
||||
persona_id: int | None = None,
|
||||
) -> ChatFullResponse:
|
||||
"""Send a chat message to the Onyx API server and get a response.
|
||||
|
||||
This method sends a non-streaming chat request to the API server. The response
|
||||
contains the complete answer with any citations and metadata.
|
||||
|
||||
Args:
|
||||
message: The user's message to process.
|
||||
api_key: The API key for authentication.
|
||||
persona_id: Optional persona ID to use for the response.
|
||||
|
||||
Returns:
|
||||
ChatFullResponse containing the answer, citations, and metadata.
|
||||
|
||||
Raises:
|
||||
APIConnectionError: If unable to connect to the API.
|
||||
APITimeoutError: If the request times out.
|
||||
APIResponseError: If the API returns an error response.
|
||||
"""
|
||||
if self._session is None:
|
||||
raise APIConnectionError(
|
||||
"API client not initialized. Call initialize() first."
|
||||
)
|
||||
|
||||
url = f"{self._base_url}/chat/send-chat-message"
|
||||
|
||||
# Build request payload
|
||||
request = SendMessageRequest(
|
||||
message=message,
|
||||
stream=False,
|
||||
origin=MessageOrigin.DISCORDBOT,
|
||||
chat_session_info=ChatSessionCreationRequest(
|
||||
persona_id=persona_id if persona_id is not None else 0,
|
||||
),
|
||||
)
|
||||
|
||||
# Build headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request.model_dump(mode="json"),
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status == 401:
|
||||
raise APIResponseError(
|
||||
"Authentication failed - invalid API key",
|
||||
status_code=401,
|
||||
)
|
||||
elif response.status == 403:
|
||||
raise APIResponseError(
|
||||
"Access denied - insufficient permissions",
|
||||
status_code=403,
|
||||
)
|
||||
elif response.status == 404:
|
||||
raise APIResponseError(
|
||||
"API endpoint not found",
|
||||
status_code=404,
|
||||
)
|
||||
elif response.status >= 500:
|
||||
error_text = await response.text()
|
||||
raise APIResponseError(
|
||||
f"Server error: {error_text}",
|
||||
status_code=response.status,
|
||||
)
|
||||
elif response.status >= 400:
|
||||
error_text = await response.text()
|
||||
raise APIResponseError(
|
||||
f"Request error: {error_text}",
|
||||
status_code=response.status,
|
||||
)
|
||||
|
||||
# Parse successful response
|
||||
data = await response.json()
|
||||
response_obj = ChatFullResponse.model_validate(data)
|
||||
|
||||
if response_obj.error_msg:
|
||||
logger.warning(f"Chat API returned error: {response_obj.error_msg}")
|
||||
|
||||
return response_obj
|
||||
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error(f"Failed to connect to API: {e}")
|
||||
raise APIConnectionError(
|
||||
f"Failed to connect to API at {self._base_url}: {e}"
|
||||
) from e
|
||||
|
||||
except TimeoutError as e:
|
||||
logger.error(f"API request timed out after {self._timeout}s")
|
||||
raise APITimeoutError(
|
||||
f"Request timed out after {self._timeout} seconds"
|
||||
) from e
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"HTTP client error: {e}")
|
||||
raise APIConnectionError(f"HTTP client error: {e}") from e
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
"""Check if the API server is healthy.
|
||||
|
||||
Returns:
|
||||
True if the API server is reachable and healthy, False otherwise.
|
||||
"""
|
||||
if self._session is None:
|
||||
logger.warning("API client not initialized. Call initialize() first.")
|
||||
return False
|
||||
|
||||
try:
|
||||
url = f"{self._base_url}/health"
|
||||
async with self._session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as response:
|
||||
return response.status == 200
|
||||
except Exception as e:
|
||||
logger.warning(f"API server health check failed: {e}")
|
||||
return False
|
||||
154  backend/onyx/onyxbot/discord/cache.py  Normal file
@@ -0,0 +1,154 @@
|
||||
"""Multi-tenant cache for Discord bot guild-tenant mappings and API keys."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from onyx.db.discord_bot import get_guild_configs
|
||||
from onyx.db.discord_bot import get_or_create_discord_service_api_key
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.onyxbot.discord.exceptions import CacheError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DiscordCacheManager:
|
||||
"""Caches guild->tenant mappings and tenant->API key mappings.
|
||||
|
||||
Refreshed on startup, periodically (every 60s), and when guilds register.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._guild_tenants: dict[int, str] = {} # guild_id -> tenant_id
|
||||
self._api_keys: dict[str, str] = {} # tenant_id -> api_key
|
||||
self._lock = asyncio.Lock()
|
||||
self._initialized = False
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
async def refresh_all(self) -> None:
|
||||
"""Full cache refresh from all tenants."""
|
||||
async with self._lock:
|
||||
logger.info("Starting Discord cache refresh")
|
||||
|
||||
new_guild_tenants: dict[int, str] = {}
|
||||
new_api_keys: dict[str, str] = {}
|
||||
|
||||
try:
|
||||
gated = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
tenant_ids = await asyncio.to_thread(get_all_tenant_ids)
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated:
|
||||
continue
|
||||
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
if not guild_ids:
|
||||
logger.debug(f"No guilds found for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"Discord service API key missing for tenant that has registered guilds. "
|
||||
f"{tenant_id} will not be handled in this refresh cycle."
|
||||
)
|
||||
continue
|
||||
|
||||
for guild_id in guild_ids:
|
||||
new_guild_tenants[guild_id] = tenant_id
|
||||
|
||||
new_api_keys[tenant_id] = api_key
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to refresh tenant {tenant_id}: {e}")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
|
||||
self._guild_tenants = new_guild_tenants
|
||||
self._api_keys = new_api_keys
|
||||
self._initialized = True
|
||||
|
||||
logger.info(
|
||||
f"Cache refresh complete: {len(new_guild_tenants)} guilds, "
|
||||
f"{len(new_api_keys)} tenants"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cache refresh failed: {e}")
|
||||
raise CacheError(f"Failed to refresh cache: {e}") from e
|
||||
|
||||
async def refresh_guild(self, guild_id: int, tenant_id: str) -> None:
|
||||
"""Add a single guild to cache after registration."""
|
||||
async with self._lock:
|
||||
logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})")
|
||||
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
|
||||
if guild_id in guild_ids:
|
||||
self._guild_tenants[guild_id] = tenant_id
|
||||
if api_key:
|
||||
self._api_keys[tenant_id] = api_key
|
||||
logger.info(f"Cache updated for guild {guild_id}")
|
||||
else:
|
||||
logger.warning(f"Guild {guild_id} not found or disabled")
|
||||
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
"""Load guild IDs and provision API key if needed.
|
||||
|
||||
Returns:
|
||||
(active_guild_ids, api_key) - api_key is the cached key if available,
|
||||
otherwise a newly created key. Returns None if no guilds found.
|
||||
"""
|
||||
cached_key = self._api_keys.get(tenant_id)
|
||||
|
||||
def _sync() -> tuple[list[int], str | None]:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db:
|
||||
configs = get_guild_configs(db)
|
||||
guild_ids = [
|
||||
config.guild_id
|
||||
for config in configs
|
||||
if config.enabled and config.guild_id is not None
|
||||
]
|
||||
|
||||
if not guild_ids:
|
||||
return [], None
|
||||
|
||||
if not cached_key:
|
||||
new_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
db.commit()
|
||||
return guild_ids, new_key
|
||||
|
||||
return guild_ids, cached_key
|
||||
|
||||
return await asyncio.to_thread(_sync)
|
||||
|
||||
def get_tenant(self, guild_id: int) -> str | None:
|
||||
"""Get tenant ID for a guild."""
|
||||
return self._guild_tenants.get(guild_id)
|
||||
|
||||
def get_api_key(self, tenant_id: str) -> str | None:
|
||||
"""Get API key for a tenant."""
|
||||
return self._api_keys.get(tenant_id)
|
||||
|
||||
def remove_guild(self, guild_id: int) -> None:
|
||||
"""Remove a guild from cache."""
|
||||
self._guild_tenants.pop(guild_id, None)
|
||||
|
||||
def get_all_guild_ids(self) -> list[int]:
|
||||
"""Get all cached guild IDs."""
|
||||
return list(self._guild_tenants.keys())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all caches."""
|
||||
self._guild_tenants.clear()
|
||||
self._api_keys.clear()
|
||||
self._initialized = False
|
||||
232  backend/onyx/onyxbot/discord/client.py  Normal file
@@ -0,0 +1,232 @@
|
||||
"""Discord bot client with integrated message handling."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL
|
||||
from onyx.onyxbot.discord.handle_commands import handle_dm
|
||||
from onyx.onyxbot.discord.handle_commands import handle_registration_command
|
||||
from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command
|
||||
from onyx.onyxbot.discord.handle_message import process_chat_message
|
||||
from onyx.onyxbot.discord.handle_message import should_respond
|
||||
from onyx.onyxbot.discord.utils import get_bot_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxDiscordClient(commands.Bot):
|
||||
"""Discord bot client with integrated cache, API client, and message handling.
|
||||
|
||||
This client handles:
|
||||
- Guild registration via !register command
|
||||
- Message processing with persona-based responses
|
||||
- Thread context for conversation continuity
|
||||
- Multi-tenant support via cached API keys
|
||||
"""
|
||||
|
||||
def __init__(self, command_prefix: str = DISCORD_BOT_INVOKE_CHAR) -> None:
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
intents.members = True
|
||||
|
||||
super().__init__(command_prefix=command_prefix, intents=intents)
|
||||
|
||||
self.ready = False
|
||||
self.cache = DiscordCacheManager()
|
||||
self.api_client = OnyxAPIClient()
|
||||
self._cache_refresh_task: asyncio.Task | None = None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Lifecycle Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Called before on_ready. Initialize components."""
|
||||
logger.info("Initializing Discord bot components...")
|
||||
|
||||
# Initialize API client
|
||||
await self.api_client.initialize()
|
||||
|
||||
# Initial cache load
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# Start periodic cache refresh
|
||||
self._cache_refresh_task = self.loop.create_task(self._periodic_cache_refresh())
|
||||
|
||||
logger.info("Discord bot components initialized")
|
||||
|
||||
async def _periodic_cache_refresh(self) -> None:
|
||||
"""Background task to refresh cache periodically."""
|
||||
while not self.is_closed():
|
||||
await asyncio.sleep(CACHE_REFRESH_INTERVAL)
|
||||
try:
|
||||
await self.cache.refresh_all()
|
||||
except Exception as e:
|
||||
logger.error(f"Cache refresh failed: {e}")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
"""Bot connected and ready."""
|
||||
if self.ready:
|
||||
return
|
||||
|
||||
if not self.user:
|
||||
raise RuntimeError("Critical error: Discord Bot user not found")
|
||||
|
||||
logger.info(f"Discord Bot connected as {self.user} (ID: {self.user.id})")
|
||||
logger.info(f"Connected to {len(self.guilds)} guild(s)")
|
||||
logger.info(f"Cached {len(self.cache.get_all_guild_ids())} registered guild(s)")
|
||||
|
||||
self.ready = True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Graceful shutdown."""
|
||||
logger.info("Shutting down Discord bot...")
|
||||
|
||||
# Cancel cache refresh task
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
try:
|
||||
await self._cache_refresh_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close Discord connection first - stops new commands from triggering cache ops
|
||||
if not self.is_closed():
|
||||
await super().close()
|
||||
|
||||
# Close API client
|
||||
await self.api_client.close()
|
||||
|
||||
# Clear cache (safe now - no concurrent operations possible)
|
||||
self.cache.clear()
|
||||
|
||||
self.ready = False
|
||||
logger.info("Discord bot shutdown complete")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Message Handling
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
"""Main message handler."""
|
||||
# mypy
|
||||
if not self.user:
|
||||
raise RuntimeError("Critical error: Discord Bot user not found")
|
||||
|
||||
try:
|
||||
# Ignore bot messages
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Ignore thread starter messages (empty reference nodes that don't contain content)
|
||||
if message.type == discord.MessageType.thread_starter_message:
|
||||
return
|
||||
|
||||
# Handle DMs
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
await handle_dm(message)
|
||||
return
|
||||
|
||||
# Must have a guild
|
||||
if not message.guild or not message.guild.id:
|
||||
return
|
||||
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Check for registration command first
|
||||
if await handle_registration_command(message, self.cache):
|
||||
return
|
||||
|
||||
# Look up guild in cache
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
|
||||
# Check for sync-channels command (requires registered guild)
|
||||
if await handle_sync_channels_command(message, tenant_id, self):
|
||||
return
|
||||
|
||||
if not tenant_id:
|
||||
# Guild not registered, ignore
|
||||
return
|
||||
|
||||
# Get API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key cached for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Check if bot should respond
|
||||
should_respond_context = await should_respond(message, tenant_id, self.user)
|
||||
|
||||
if not should_respond_context.should_respond:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Processing message: '{message.content[:50]}' in "
|
||||
f"#{getattr(message.channel, 'name', 'unknown')} ({message.guild.name}), "
|
||||
f"persona_id={should_respond_context.persona_id}"
|
||||
)
|
||||
|
||||
# Process the message
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key,
|
||||
persona_id=should_respond_context.persona_id,
|
||||
thread_only_mode=should_respond_context.thread_only_mode,
|
||||
api_client=self.api_client,
|
||||
bot_user=self.user,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error processing message: {e}")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Entry Point
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for Discord bot."""
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
|
||||
logger.info("Starting Onyx Discord Bot...")
|
||||
|
||||
# Initialize the database engine (required before any DB operations)
|
||||
SqlEngine.init_engine(pool_size=20, max_overflow=5)
|
||||
|
||||
# Initialize EE features based on environment
|
||||
set_is_ee_based_on_env_variable()
|
||||
|
||||
counter = 0
|
||||
while True:
|
||||
token = get_bot_token()
|
||||
if not token:
|
||||
if counter % 180 == 0:
|
||||
logger.info(
|
||||
"Discord bot is dormant. Waiting for token configuration..."
|
||||
)
|
||||
counter += 1
|
||||
time.sleep(5)
|
||||
continue
|
||||
counter = 0
|
||||
bot = OnyxDiscordClient()
|
||||
|
||||
try:
|
||||
# bot.run() handles SIGINT/SIGTERM and calls close() automatically
|
||||
bot.run(token)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Fatal error in Discord bot")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
19  backend/onyx/onyxbot/discord/constants.py  Normal file
@@ -0,0 +1,19 @@
"""Discord bot constants."""

# API settings
API_REQUEST_TIMEOUT: int = 3 * 60  # 3 minutes

# Cache settings
CACHE_REFRESH_INTERVAL: int = 60  # 1 minute

# Message settings
MAX_MESSAGE_LENGTH: int = 2000  # Discord's character limit
MAX_CONTEXT_MESSAGES: int = 10  # Max messages to include in conversation context
# Note: Discord.py's add_reaction() requires unicode emoji, not :name: format
THINKING_EMOJI: str = "🤔"  # U+1F914 - Thinking Face
SUCCESS_EMOJI: str = "✅"  # U+2705 - White Heavy Check Mark
ERROR_EMOJI: str = "❌"  # U+274C - Cross Mark

# Command prefix
REGISTER_COMMAND: str = "register"
SYNC_CHANNELS_COMMAND: str = "sync-channels"
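
As an illustration of why `MAX_MESSAGE_LENGTH` exists, here is a hypothetical helper (not part of the files above) that chunks an answer to fit Discord's 2000-character limit:

```python
# Hypothetical helper for illustration only; the actual message handler may split
# responses differently (e.g. on sentence or code-block boundaries).
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH


def split_for_discord(answer: str) -> list[str]:
    """Split a long answer into chunks Discord will accept."""
    return [
        answer[i : i + MAX_MESSAGE_LENGTH]
        for i in range(0, len(answer), MAX_MESSAGE_LENGTH)
    ] or [""]
```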
37  backend/onyx/onyxbot/discord/exceptions.py  Normal file
@@ -0,0 +1,37 @@
"""Custom exception classes for Discord bot."""


class DiscordBotError(Exception):
    """Base exception for Discord bot errors."""


class RegistrationError(DiscordBotError):
    """Error during guild registration."""


class SyncChannelsError(DiscordBotError):
    """Error during channel sync."""


class APIError(DiscordBotError):
    """Base API error."""


class CacheError(DiscordBotError):
    """Error during cache operations."""


class APIConnectionError(APIError):
    """Failed to connect to API."""


class APITimeoutError(APIError):
    """Request timed out."""


class APIResponseError(APIError):
    """API returned an error response."""

    def __init__(self, message: str, status_code: int | None = None):
        super().__init__(message)
        self.status_code = status_code
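
A hypothetical sketch of how a caller might branch on this exception hierarchy (the `answer_or_apologize` helper below is illustrative, not part of the bot code):

```python
# Illustrative only: shows one way a message handler could map the exception
# hierarchy above to user-facing replies.
from onyx.onyxbot.discord.api_client import OnyxAPIClient
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError


async def answer_or_apologize(client: OnyxAPIClient, text: str, api_key: str) -> str:
    try:
        response = await client.send_chat_message(message=text, api_key=api_key)
        return response.answer
    except APITimeoutError:
        return "The request timed out, please try again."
    except APIResponseError as e:
        # status_code distinguishes auth problems (401/403) from server errors (5xx)
        return f"The API rejected the request (status {e.status_code})."
    except APIConnectionError:
        return "Could not reach the Onyx API server."
```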
437  backend/onyx/onyxbot/discord/handle_commands.py  Normal file
@@ -0,0 +1,437 @@
|
||||
"""Discord bot command handlers for registration and channel sync."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import discord
|
||||
|
||||
from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR
|
||||
from onyx.configs.constants import ONYX_DISCORD_URL
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import get_guild_config_by_discord_id
|
||||
from onyx.db.discord_bot import get_guild_config_by_internal_id
|
||||
from onyx.db.discord_bot import get_guild_config_by_registration_key
|
||||
from onyx.db.discord_bot import sync_channel_configs
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.utils import DiscordChannelView
|
||||
from onyx.onyxbot.discord.cache import DiscordCacheManager
|
||||
from onyx.onyxbot.discord.constants import REGISTER_COMMAND
|
||||
from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND
|
||||
from onyx.onyxbot.discord.exceptions import RegistrationError
|
||||
from onyx.onyxbot.discord.exceptions import SyncChannelsError
|
||||
from onyx.server.manage.discord_bot.utils import parse_discord_registration_key
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
async def handle_dm(message: discord.Message) -> None:
|
||||
"""Handle direct messages."""
|
||||
dm_response = (
|
||||
"**I can't respond to DMs** :sweat:\n\n"
|
||||
f"Please chat with me in a server channel, or join the official "
|
||||
f"[Onyx Discord]({ONYX_DISCORD_URL}) for help!"
|
||||
)
|
||||
await message.channel.send(dm_response)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Helper functions for error handling
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _try_dm_author(message: discord.Message, content: str) -> bool:
|
||||
"""Attempt to DM the message author. Returns True if successful."""
|
||||
logger.debug(f"Responding in Discord DM with {content}")
|
||||
try:
|
||||
await message.author.send(content)
|
||||
return True
|
||||
except (discord.Forbidden, discord.HTTPException) as e:
|
||||
# User has DMs disabled or other error
|
||||
logger.warning(f"Failed to DM author {message.author.id}: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error DMing author {message.author.id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _try_delete_message(message: discord.Message) -> bool:
|
||||
"""Attempt to delete a message. Returns True if successful."""
|
||||
logger.debug(f"Deleting potentially sensitive message {message.id}")
|
||||
try:
|
||||
await message.delete()
|
||||
return True
|
||||
except (discord.Forbidden, discord.HTTPException) as e:
|
||||
# Bot lacks permission or other error
|
||||
logger.warning(f"Failed to delete message {message.id}: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error deleting message {message.id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _try_react_x(message: discord.Message) -> bool:
|
||||
"""Attempt to react to a message with ❌. Returns True if successful."""
|
||||
try:
|
||||
await message.add_reaction("❌")
|
||||
return True
|
||||
except (discord.Forbidden, discord.HTTPException) as e:
|
||||
# Bot lacks permission or other error
|
||||
logger.warning(f"Failed to react to message {message.id}: {e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error reacting to message {message.id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Registration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_registration_command(
|
||||
message: discord.Message,
|
||||
cache: DiscordCacheManager,
|
||||
) -> bool:
|
||||
"""Handle !register command. Returns True if command was handled."""
|
||||
content = message.content.strip()
|
||||
|
||||
# Check for !register command
|
||||
if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{REGISTER_COMMAND}"):
|
||||
return False
|
||||
|
||||
# Must be in a server
|
||||
if not message.guild:
|
||||
await _try_dm_author(
|
||||
message, "This command can only be used in a server channel."
|
||||
)
|
||||
return True
|
||||
|
||||
guild_name = message.guild.name
|
||||
logger.info(f"Registration command received: {guild_name}")
|
||||
|
||||
try:
|
||||
# Parse the registration key
|
||||
parts = content.split(maxsplit=1)
|
||||
if len(parts) < 2:
|
||||
raise RegistrationError(
|
||||
"Invalid registration key format. Please check the key and try again."
|
||||
)
|
||||
|
||||
registration_key = parts[1].strip()
|
||||
|
||||
if not message.author or not isinstance(message.author, discord.Member):
|
||||
raise RegistrationError(
|
||||
"You need to be a server administrator to register the bot."
|
||||
)
|
||||
|
||||
# Check permissions - require admin or manage_guild
|
||||
if not message.author.guild_permissions.administrator:
|
||||
if not message.author.guild_permissions.manage_guild:
|
||||
raise RegistrationError(
|
||||
"You need **Administrator** or **Manage Server** permissions "
|
||||
"to register this bot."
|
||||
)
|
||||
|
||||
await _register_guild(message, registration_key, cache)
|
||||
logger.info(f"Registration successful: {guild_name}")
|
||||
await message.reply(
|
||||
":white_check_mark: **Successfully registered!**\n\n"
|
||||
"This server is now connected to Onyx. "
|
||||
"I'll respond to messages based on your server and channel settings set in Onyx."
|
||||
)
|
||||
except RegistrationError as e:
|
||||
logger.debug(f"Registration failed: {guild_name}, error={e}")
|
||||
await _try_dm_author(message, f":x: **Registration failed.**\n\n{e}")
|
||||
await _try_delete_message(message)
|
||||
except Exception:
|
||||
logger.exception(f"Registration failed unexpectedly: {guild_name}")
|
||||
await _try_dm_author(
|
||||
message,
|
||||
":x: **Registration failed.**\n\n"
|
||||
"An unexpected error occurred. Please try again later.",
|
||||
)
|
||||
await _try_delete_message(message)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _register_guild(
    message: discord.Message,
    registration_key: str,
    cache: DiscordCacheManager,
) -> None:
    """Register a guild with a registration key."""
    if not message.guild:
        # mypy, even though we already know that message.guild is not None
        raise RegistrationError("This command can only be used in a server.")

    logger.info(f"Guild '{message.guild.name}' attempting to register Discord bot")
    registration_key = registration_key.strip()

    # Parse tenant_id from registration key
    parsed = parse_discord_registration_key(registration_key)
    if parsed is None:
        raise RegistrationError(
            "Invalid registration key format. Please check the key and try again."
        )

    tenant_id = parsed

    logger.info(f"Parsed tenant_id {tenant_id} from registration key")

    # Check if this guild is already registered to any tenant
    guild_id = message.guild.id
    existing_tenant = cache.get_tenant(guild_id)
    if existing_tenant is not None:
        logger.warning(
            f"Guild {guild_id} is already registered to tenant {existing_tenant}"
        )
        raise RegistrationError(
            "This server is already registered.\n\n"
            "OnyxBot can only connect one Discord server to one Onyx workspace."
        )

    context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        guild = message.guild
        guild_name = guild.name

        # Collect all text channels from the guild
        channels = get_text_channels(guild)
        logger.info(f"Found {len(channels)} text channels in guild '{guild_name}'")

        # Validate and update in database
        def _sync_register() -> int:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                # Find the guild config by registration key
                config = get_guild_config_by_registration_key(db, registration_key)
                if not config:
                    raise RegistrationError(
                        "Registration key not found.\n\n"
                        "The key may have expired or been deleted. "
                        "Please generate a new one from the Onyx admin panel."
                    )

                # Check if already used
                if config.guild_id is not None:
                    raise RegistrationError(
                        "This registration key has already been used.\n\n"
                        "Each key can only be used once. "
                        "Please generate a new key from the Onyx admin panel."
                    )

                # Update the guild config
                config.guild_id = guild_id
                config.guild_name = guild_name
                config.registered_at = datetime.now(timezone.utc)

                # Create channel configs for all text channels
                bulk_create_channel_configs(db, config.id, channels)

                db.commit()
                return config.id

        await asyncio.to_thread(_sync_register)

        # Refresh cache for this guild
        await cache.refresh_guild(guild_id, tenant_id)

        logger.info(
            f"Guild '{guild_name}' registered with {len(channels)} channel configs"
        )
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)


def get_text_channels(guild: discord.Guild) -> list[DiscordChannelView]:
    """Get all text channels from a guild as DiscordChannelView objects."""
    channels: list[DiscordChannelView] = []
    for channel in guild.channels:
        # Include text channels and forum channels (where threads can be created)
        if isinstance(channel, (discord.TextChannel, discord.ForumChannel)):
            # Check if channel is private (not visible to @everyone)
            everyone_perms = channel.permissions_for(guild.default_role)
            is_private = not everyone_perms.view_channel

            logger.debug(
                f"Found channel: #{channel.name}, "
                f"type={channel.type.name}, is_private={is_private}"
            )

            channels.append(
                DiscordChannelView(
                    channel_id=channel.id,
                    channel_name=channel.name,
                    channel_type=channel.type.name,  # "text" or "forum"
                    is_private=is_private,
                )
            )

    logger.debug(f"Retrieved {len(channels)} channels from guild '{guild.name}'")
    return channels


# -------------------------------------------------------------------------
# Sync Channels
# -------------------------------------------------------------------------


async def handle_sync_channels_command(
    message: discord.Message,
    tenant_id: str | None,
    bot: discord.Client,
) -> bool:
    """Handle !sync-channels command. Returns True if command was handled."""
    content = message.content.strip()

    # Check for !sync-channels command
    if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{SYNC_CHANNELS_COMMAND}"):
        return False

    # Must be in a server
    if not message.guild:
        await _try_dm_author(
            message, "This command can only be used in a server channel."
        )
        return True

    guild_name = message.guild.name
    logger.info(f"Sync-channels command received: {guild_name}")

    try:
        # Must be registered
        if not tenant_id:
            raise SyncChannelsError(
                "This server is not registered. Please register it first."
            )

        # Check permissions - require admin or manage_guild
        if not message.author or not isinstance(message.author, discord.Member):
            raise SyncChannelsError(
                "You need to be a server administrator to sync channels."
            )

        if not message.author.guild_permissions.administrator:
            if not message.author.guild_permissions.manage_guild:
                raise SyncChannelsError(
                    "You need **Administrator** or **Manage Server** permissions "
                    "to sync channels."
                )

        # Get guild config ID
        def _get_guild_config_id() -> int | None:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                if not message.guild:
                    raise SyncChannelsError(
                        "Server not found. This shouldn't happen. Please contact Onyx support."
                    )
                config = get_guild_config_by_discord_id(db, message.guild.id)
                return config.id if config else None

        guild_config_id = await asyncio.to_thread(_get_guild_config_id)

        if not guild_config_id:
            raise SyncChannelsError(
                "Server config not found. This shouldn't happen. Please contact Onyx support."
            )

        # Perform the sync
        added, removed, updated = await sync_guild_channels(
            guild_config_id, tenant_id, bot
        )
        logger.info(
            f"Sync-channels successful: {guild_name}, "
            f"added={added}, removed={removed}, updated={updated}"
        )
        await message.reply(
            f":white_check_mark: **Channel sync complete!**\n\n"
            f"* **{added}** new channel(s) added\n"
            f"* **{removed}** deleted channel(s) removed\n"
            f"* **{updated}** channel name(s) updated\n\n"
            "New channels are disabled by default. Enable them in the Onyx admin panel."
        )
    except SyncChannelsError as e:
        logger.debug(f"Sync-channels failed: {guild_name}, error={e}")
        await _try_dm_author(message, f":x: **Channel sync failed.**\n\n{e}")
        await _try_react_x(message)
    except Exception:
        logger.exception(f"Sync-channels failed unexpectedly: {guild_name}")
        await _try_dm_author(
            message,
            ":x: **Channel sync failed.**\n\n"
            "An unexpected error occurred. Please try again later.",
        )
        await _try_react_x(message)

    return True


async def sync_guild_channels(
    guild_config_id: int,
    tenant_id: str,
    bot: discord.Client,
) -> tuple[int, int, int]:
    """Sync channel configs with current Discord channels for a guild.

    Fetches current channels from Discord and syncs with database:
    - Creates configs for new channels (disabled by default)
    - Removes configs for deleted channels
    - Updates names for existing channels if changed

    Args:
        guild_config_id: Internal ID of the guild config
        tenant_id: Tenant ID for database access
        bot: Discord bot client

    Returns:
        (added_count, removed_count, updated_count)

    Raises:
        ValueError: If guild config not found or guild not registered
    """
    context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        # Get guild_id from config
        def _get_guild_id() -> int | None:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                config = get_guild_config_by_internal_id(db, guild_config_id)
                if not config:
                    return None
                return config.guild_id

        guild_id = await asyncio.to_thread(_get_guild_id)

        if guild_id is None:
            raise ValueError(
                f"Guild config {guild_config_id} not found or not registered"
            )

        # Get the guild from Discord
        guild = bot.get_guild(guild_id)
        if not guild:
            raise ValueError(f"Guild {guild_id} not found in Discord cache")

        # Get current channels from Discord
        channels = get_text_channels(guild)
        logger.info(f"Syncing {len(channels)} channels for guild '{guild.name}'")

        # Sync with database
        def _sync() -> tuple[int, int, int]:
            with get_session_with_tenant(tenant_id=tenant_id) as db:
                added, removed, updated = sync_channel_configs(
                    db, guild_config_id, channels
                )
                db.commit()
                return added, removed, updated

        added, removed, updated = await asyncio.to_thread(_sync)

        logger.info(
            f"Channel sync complete for guild '{guild.name}': "
            f"added={added}, removed={removed}, updated={updated}"
        )

        return added, removed, updated

    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
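
Note: sync_channel_configs itself lives in the imported db layer and is not shown in this excerpt. A minimal sketch of the add/remove/update semantics described in the docstring above, assuming a plain set difference over channel IDs (the stand-in types below are not the real models):

```python
# Illustrative sketch only - the real sync_channel_configs may differ in details.
from dataclasses import dataclass


@dataclass
class ChannelView:  # stand-in for DiscordChannelView
    channel_id: int
    channel_name: str


@dataclass
class ChannelConfigRow:  # stand-in for the persisted channel config
    channel_id: int
    channel_name: str


def sync_sketch(
    existing: list[ChannelConfigRow], current: list[ChannelView]
) -> tuple[int, int, int]:
    """Return (added, removed, updated) using a set difference on channel IDs."""
    existing_by_id = {row.channel_id: row for row in existing}
    current_by_id = {ch.channel_id: ch for ch in current}

    added = len(current_by_id.keys() - existing_by_id.keys())
    removed = len(existing_by_id.keys() - current_by_id.keys())
    updated = sum(
        1
        for cid in existing_by_id.keys() & current_by_id.keys()
        if existing_by_id[cid].channel_name != current_by_id[cid].channel_name
    )
    return added, removed, updated
```
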
493  backend/onyx/onyxbot/discord/handle_message.py  Normal file
@@ -0,0 +1,493 @@
"""Discord bot message handling and response logic."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import discord
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.db.discord_bot import get_channel_config_by_discord_ids
|
||||
from onyx.db.discord_bot import get_guild_config_by_discord_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import DiscordChannelConfig
|
||||
from onyx.db.models import DiscordGuildConfig
|
||||
from onyx.onyxbot.discord.api_client import OnyxAPIClient
|
||||
from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES
|
||||
from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH
|
||||
from onyx.onyxbot.discord.constants import THINKING_EMOJI
|
||||
from onyx.onyxbot.discord.exceptions import APIError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Message types with actual content (excludes system notifications like "user joined")
|
||||
CONTENT_MESSAGE_TYPES = (
|
||||
discord.MessageType.default,
|
||||
discord.MessageType.reply,
|
||||
discord.MessageType.thread_starter_message,
|
||||
)
|
||||
|
||||
|
||||
class ShouldRespondContext(BaseModel):
    """Context for whether the bot should respond to a message."""

    should_respond: bool
    persona_id: int | None
    thread_only_mode: bool


# -------------------------------------------------------------------------
# Response Logic
# -------------------------------------------------------------------------


async def should_respond(
    message: discord.Message,
    tenant_id: str,
    bot_user: discord.ClientUser,
) -> ShouldRespondContext:
    """Determine if bot should respond and which persona to use."""
    if not message.guild:
        logger.warning("Received a message that isn't in a server.")
        return ShouldRespondContext(
            should_respond=False, persona_id=None, thread_only_mode=False
        )

    guild_id = message.guild.id
    channel_id = message.channel.id
    bot_mentioned = bot_user in message.mentions

    def _get_configs() -> tuple[DiscordGuildConfig | None, DiscordChannelConfig | None]:
        with get_session_with_tenant(tenant_id=tenant_id) as db:
            guild_config = get_guild_config_by_discord_id(db, guild_id)
            if not guild_config or not guild_config.enabled:
                return None, None

            # For threads, use parent channel ID
            actual_channel_id = channel_id
            if isinstance(message.channel, discord.Thread) and message.channel.parent:
                actual_channel_id = message.channel.parent.id

            channel_config = get_channel_config_by_discord_ids(
                db, guild_id, actual_channel_id
            )
            return guild_config, channel_config

    guild_config, channel_config = await asyncio.to_thread(_get_configs)

    if not guild_config or not channel_config or not channel_config.enabled:
        return ShouldRespondContext(
            should_respond=False, persona_id=None, thread_only_mode=False
        )

    # Determine persona (channel override or guild default)
    persona_id = channel_config.persona_override_id or guild_config.default_persona_id

    # Check mention requirement (with exceptions for implicit invocation)
    if channel_config.require_bot_invocation and not bot_mentioned:
        if not await check_implicit_invocation(message, bot_user):
            return ShouldRespondContext(
                should_respond=False, persona_id=None, thread_only_mode=False
            )

    return ShouldRespondContext(
        should_respond=True,
        persona_id=persona_id,
        thread_only_mode=channel_config.thread_only_mode,
    )


async def check_implicit_invocation(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> bool:
    """Check if the bot should respond without explicit mention.

    Returns True if:
    1. User is replying to a bot message
    2. User is in a thread owned by the bot
    3. User is in a thread created from a bot message
    """
    # Check if replying to a bot message
    if message.reference and message.reference.message_id:
        try:
            referenced_msg = await message.channel.fetch_message(
                message.reference.message_id
            )
            if referenced_msg.author.id == bot_user.id:
                logger.debug(
                    f"Implicit invocation via reply: '{message.content[:50]}...'"
                )
                return True
        except (discord.NotFound, discord.HTTPException):
            pass

    # Check thread-related conditions
    if isinstance(message.channel, discord.Thread):
        thread = message.channel

        # Bot owns the thread
        if thread.owner_id == bot_user.id:
            logger.debug(
                f"Implicit invocation via bot-owned thread: "
                f"'{message.content[:50]}...' in #{thread.name}"
            )
            return True

        # Thread was created from a bot message
        if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
            try:
                starter = await thread.parent.fetch_message(thread.id)
                if starter.author.id == bot_user.id:
                    logger.debug(
                        f"Implicit invocation via bot-started thread: "
                        f"'{message.content[:50]}...' in #{thread.name}"
                    )
                    return True
            except (discord.NotFound, discord.HTTPException):
                pass

    return False


# -------------------------------------------------------------------------
# Message Processing
# -------------------------------------------------------------------------


async def process_chat_message(
    message: discord.Message,
    api_key: str,
    persona_id: int | None,
    thread_only_mode: bool,
    api_client: OnyxAPIClient,
    bot_user: discord.ClientUser,
) -> None:
    """Process a message and send response."""
    try:
        await message.add_reaction(THINKING_EMOJI)
    except discord.DiscordException:
        logger.warning(
            f"Failed to add thinking reaction to message: '{message.content[:50]}...'"
        )

    try:
        # Build conversation context
        context = await _build_conversation_context(message, bot_user)

        # Prepare full message content
        parts = []
        if context:
            parts.append(context)
        if isinstance(message.channel, discord.Thread):
            if isinstance(message.channel.parent, discord.ForumChannel):
                parts.append(f"Forum post title: {message.channel.name}")
        parts.append(
            f"Current message from @{message.author.display_name}: {format_message_content(message)}"
        )

        # Send to API
        response = await api_client.send_chat_message(
            message="\n\n".join(parts),
            api_key=api_key,
            persona_id=persona_id,
        )

        # Format response with citations
        answer = response.answer or "I couldn't generate a response."
        answer = _append_citations(answer, response)

        await send_response(message, answer, thread_only_mode)

        try:
            await message.remove_reaction(THINKING_EMOJI, bot_user)
        except discord.DiscordException:
            pass

    except APIError as e:
        logger.error(f"API error processing message: {e}")
        await send_error_response(message, bot_user)
    except Exception as e:
        logger.exception(f"Error processing chat message: {e}")
        await send_error_response(message, bot_user)


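For reference, the string handed to send_chat_message is just the joined parts. A rough illustration, with every name and message invented, of what the payload looks like for a question asked inside a forum thread:

```python
# Illustrative only: roughly what "\n\n".join(parts) produces for a forum-thread message,
# given a context string previously built by _format_messages_as_context.
parts = [
    "You are a Discord bot named OnyxBot.\n"
    'Always assume that [user] is the same as the "Current message" author.'
    "Conversation history:\n"
    "---\n@alice: How do I rotate the API key?\nOnyxBot: Go to Admin > API Keys.\n---",
    "Forum post title: API key rotation",
    "Current message from @alice: Thanks! And where do I revoke old keys?",
]
print("\n\n".join(parts))
```
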
async def _build_conversation_context(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> str | None:
    """Build conversation context from thread history or reply chain."""
    if isinstance(message.channel, discord.Thread):
        return await _build_thread_context(message, bot_user)
    elif message.reference:
        return await _build_reply_chain_context(message, bot_user)
    return None


def _append_citations(answer: str, response: ChatFullResponse) -> str:
    """Append citation sources to the answer if present."""
    if not response.citation_info or not response.top_documents:
        return answer

    cited_docs: list[tuple[int, str, str | None]] = []
    for citation in response.citation_info:
        doc = next(
            (
                d
                for d in response.top_documents
                if d.document_id == citation.document_id
            ),
            None,
        )
        if doc:
            cited_docs.append(
                (
                    citation.citation_number,
                    doc.semantic_identifier or "Source",
                    doc.link,
                )
            )

    if not cited_docs:
        return answer

    cited_docs.sort(key=lambda x: x[0])
    citations = "\n\n**Sources:**\n"
    for num, name, link in cited_docs[:5]:
        if link:
            citations += f"{num}. [{name}](<{link}>)\n"
        else:
            citations += f"{num}. {name}\n"

    return answer + citations


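The footer this produces, sketched with a single invented citation (document name and URL are made up):

```python
# Illustrative only: the footer shape produced for one cited document.
cited_docs = [(1, "Admin Guide", "https://docs.example.com/admin")]  # (num, name, link)

citations = "\n\n**Sources:**\n"
for num, name, link in cited_docs[:5]:
    citations += f"{num}. [{name}](<{link}>)\n" if link else f"{num}. {name}\n"

print("Rotate keys from the admin panel [1]." + citations)
# Rotate keys from the admin panel [1].
#
# **Sources:**
# 1. [Admin Guide](<https://docs.example.com/admin>)
```
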
# -------------------------------------------------------------------------
# Context Building
# -------------------------------------------------------------------------


async def _build_reply_chain_context(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> str | None:
    """Build context by following the reply chain backwards."""
    if not message.reference or not message.reference.message_id:
        return None

    try:
        messages: list[discord.Message] = []
        current = message

        # Follow reply chain backwards up to MAX_CONTEXT_MESSAGES
        while (
            current.reference
            and current.reference.message_id
            and len(messages) < MAX_CONTEXT_MESSAGES
        ):
            try:
                parent = await message.channel.fetch_message(
                    current.reference.message_id
                )
                messages.append(parent)
                current = parent
            except (discord.NotFound, discord.HTTPException):
                break

        if not messages:
            return None

        messages.reverse()  # Chronological order

        logger.debug(
            f"Built reply chain context: {len(messages)} messages in #{getattr(message.channel, 'name', 'unknown')}"
        )

        return _format_messages_as_context(messages, bot_user)

    except Exception as e:
        logger.warning(f"Failed to build reply chain context: {e}")
        return None


async def _build_thread_context(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> str | None:
    """Build context from thread message history."""
    if not isinstance(message.channel, discord.Thread):
        return None

    try:
        thread = message.channel
        messages: list[discord.Message] = []

        # Fetch recent messages (excluding current)
        async for msg in thread.history(limit=MAX_CONTEXT_MESSAGES, oldest_first=False):
            if msg.id != message.id:
                messages.append(msg)

        # Include thread starter message and its reply chain if not already present
        if thread.parent and not isinstance(thread.parent, discord.ForumChannel):
            try:
                starter = await thread.parent.fetch_message(thread.id)
                if starter.id != message.id and not any(
                    m.id == starter.id for m in messages
                ):
                    messages.append(starter)

                    # Trace back through the starter's reply chain for more context
                    current = starter
                    while (
                        current.reference
                        and current.reference.message_id
                        and len(messages) < MAX_CONTEXT_MESSAGES
                    ):
                        try:
                            parent = await thread.parent.fetch_message(
                                current.reference.message_id
                            )
                            if not any(m.id == parent.id for m in messages):
                                messages.append(parent)
                            current = parent
                        except (discord.NotFound, discord.HTTPException):
                            break
            except (discord.NotFound, discord.HTTPException):
                pass

        if not messages:
            return None

        messages.sort(key=lambda m: m.id)  # Chronological order
        logger.debug(
            f"Built thread context: {len(messages)} messages in #{thread.name}"
        )

        return _format_messages_as_context(messages, bot_user)

    except Exception as e:
        logger.warning(f"Failed to build thread context: {e}")
        return None


def _format_messages_as_context(
    messages: list[discord.Message],
    bot_user: discord.ClientUser,
) -> str | None:
    """Format a list of messages into a conversation context string."""
    formatted = []
    for msg in messages:
        if msg.type not in CONTENT_MESSAGE_TYPES:
            continue

        sender = (
            "OnyxBot" if msg.author.id == bot_user.id else f"@{msg.author.display_name}"
        )
        formatted.append(f"{sender}: {format_message_content(msg)}")

    if not formatted:
        return None

    return (
        "You are a Discord bot named OnyxBot.\n"
        'Always assume that [user] is the same as the "Current message" author.'
        "Conversation history:\n"
        "---\n" + "\n".join(formatted) + "\n---"
    )


# -------------------------------------------------------------------------
# Message Formatting
# -------------------------------------------------------------------------


def format_message_content(message: discord.Message) -> str:
    """Format message content with readable mentions."""
    content = message.content

    for user in message.mentions:
        content = content.replace(f"<@{user.id}>", f"@{user.display_name}")
        content = content.replace(f"<@!{user.id}>", f"@{user.display_name}")

    for role in message.role_mentions:
        content = content.replace(f"<@&{role.id}>", f"@{role.name}")

    for channel in message.channel_mentions:
        content = content.replace(f"<#{channel.id}>", f"#{channel.name}")

    return content


# -------------------------------------------------------------------------
# Response Sending
# -------------------------------------------------------------------------


async def send_response(
    message: discord.Message,
    content: str,
    thread_only_mode: bool,
) -> None:
    """Send response based on thread_only_mode setting."""
    chunks = _split_message(content)

    if isinstance(message.channel, discord.Thread):
        for chunk in chunks:
            await message.channel.send(chunk)
    elif thread_only_mode:
        thread_name = f"OnyxBot <> {message.author.display_name}"[:100]
        thread = await message.create_thread(name=thread_name)
        for chunk in chunks:
            await thread.send(chunk)
    else:
        for i, chunk in enumerate(chunks):
            if i == 0:
                await message.reply(chunk)
            else:
                await message.channel.send(chunk)


def _split_message(content: str) -> list[str]:
    """Split content into chunks that fit Discord's message limit."""
    chunks = []
    while content:
        if len(content) <= MAX_MESSAGE_LENGTH:
            chunks.append(content)
            break

        # Find a good split point
        split_at = MAX_MESSAGE_LENGTH
        for sep in ["\n\n", "\n", ". ", " "]:
            idx = content.rfind(sep, 0, MAX_MESSAGE_LENGTH)
            if idx > MAX_MESSAGE_LENGTH // 2:
                split_at = idx + len(sep)
                break

        chunks.append(content[:split_at])
        content = content[split_at:]

    return chunks


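A standalone sketch of the same splitting heuristic, using a small limit in place of MAX_MESSAGE_LENGTH (whatever value the constants module defines) to show the preference for sentence boundaries:

```python
# Standalone sketch of the heuristic above with a tiny limit for readability.
def split_message(content: str, limit: int = 40) -> list[str]:
    chunks = []
    while content:
        if len(content) <= limit:
            chunks.append(content)
            break
        split_at = limit
        for sep in ["\n\n", "\n", ". ", " "]:
            idx = content.rfind(sep, 0, limit)
            if idx > limit // 2:  # prefer a natural boundary past the midpoint
                split_at = idx + len(sep)
                break
        chunks.append(content[:split_at])
        content = content[split_at:]
    return chunks


text = "First sentence is fairly long. Second sentence follows. Third one too."
print(split_message(text))
# ['First sentence is fairly long. ', 'Second sentence follows. Third one too.']
```
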
async def send_error_response(
    message: discord.Message,
    bot_user: discord.ClientUser,
) -> None:
    """Send error response and clean up reaction."""
    try:
        await message.remove_reaction(THINKING_EMOJI, bot_user)
    except discord.DiscordException:
        pass

    error_msg = "Sorry, I encountered an error processing your message. You may want to contact Onyx for support :sweat_smile:"

    try:
        if isinstance(message.channel, discord.Thread):
            await message.channel.send(error_msg)
        else:
            thread = await message.create_thread(
                name=f"Response to {message.author.display_name}"[:100]
            )
            await thread.send(error_msg)
    except discord.DiscordException:
        pass
39  backend/onyx/onyxbot/discord/utils.py  Normal file
@@ -0,0 +1,39 @@
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISCORD_BOT_TOKEN
from onyx.configs.constants import AuthType
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

logger = setup_logger()


def get_bot_token() -> str | None:
    """Get Discord bot token from env var or database.

    Priority:
    1. DISCORD_BOT_TOKEN env var (always takes precedence)
    2. For self-hosted: DiscordBotConfig in database (default tenant)
    3. For Cloud: should always have env var set

    Returns:
        Bot token string, or None if not configured.
    """
    # Environment variable takes precedence
    if DISCORD_BOT_TOKEN:
        return DISCORD_BOT_TOKEN

    # Cloud should always have env var; if not, return None
    if AUTH_TYPE == AuthType.CLOUD:
        logger.warning("Cloud deployment missing DISCORD_BOT_TOKEN env var")
        return None

    # Self-hosted: check database for bot config
    try:
        with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db:
            config = get_discord_bot_config(db)
    except Exception as e:
        logger.error(f"Failed to get bot token from database: {e}")
        return None
    return config.bot_token if config else None
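A rough sketch of how a bot entrypoint might consume get_bot_token; the entrypoint itself is not part of this diff, so the startup code below is illustrative only:

```python
# Illustrative startup sketch - the real entrypoint module is not shown in this diff.
import sys

import discord

from onyx.onyxbot.discord.utils import get_bot_token


def main() -> None:
    token = get_bot_token()
    if token is None:
        print("No Discord bot token configured; exiting.", file=sys.stderr)
        sys.exit(1)

    intents = discord.Intents.default()
    intents.message_content = True  # needed to read message text
    client = discord.Client(intents=intents)
    client.run(token)


if __name__ == "__main__":
    main()
```
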
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
    )

    citations_blocks = []
    document_blocks = []
    if answer.citation_info:
        citations_blocks = _build_citations_blocks(answer)
    else:
        document_blocks = _priority_ordered_documents_blocks(answer)

    citations_divider = [DividerBlock()] if citations_blocks else []
    buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@
        + ai_feedback_block
        + citations_divider
        + citations_blocks
        + document_blocks
        + buttons_divider
        + web_follow_up_block
        + follow_up_block

@@ -1,12 +1,149 @@
from mistune import Markdown # type: ignore[import-untyped]
from mistune import Renderer
import re
from collections.abc import Callable
from typing import Any

from mistune import create_markdown
from mistune import HTMLRenderer

# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
    r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
    re.IGNORECASE,
)

# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
_HTML_TAG_PATTERN = re.compile(
    r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)

# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")

# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")

# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")


def _sanitize_html(text: str) -> str:
    """Strip HTML tags from a text fragment.

    Block-level closing tags and <br> are converted to newlines.
    All other HTML tags are removed. Autolinks (<https://...>) are preserved.
    """
    text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
    text = _HTML_TAG_PATTERN.sub("", text)
    return text


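A quick illustration of what the two patterns do to a typical HTML fragment (block tags and <br> become newlines, autolinks survive); the sample text is invented:

```python
# Self-contained: the same two regexes as defined above.
import re

_HTML_NEWLINE_TAG_PATTERN = re.compile(
    r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>", re.IGNORECASE
)
_HTML_TAG_PATTERN = re.compile(r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>")

sample = "<p>Hello<br>world</p> see <https://example.com> for details"
out = _HTML_TAG_PATTERN.sub("", _HTML_NEWLINE_TAG_PATTERN.sub("\n", sample))
print(repr(out))
# 'Hello\nworld\n see <https://example.com> for details'
```
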
def _transform_outside_code_blocks(
    message: str, transform: Callable[[str], str]
) -> str:
    """Apply *transform* only to text outside fenced code blocks."""
    parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
    code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)

    result: list[str] = []
    for i, part in enumerate(parts):
        result.append(transform(part))
        if i < len(code_blocks):
            result.append(code_blocks[i])

    return "".join(result)


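For example, an uppercasing transform run through this helper leaves the fenced region untouched (the fence string is built with chr(96) only to keep this snippet's own formatting intact; assumes the function above is in scope):

```python
# Usage sketch: only text outside the fenced block is transformed.
fence = chr(96) * 3  # literal triple-backtick
msg = f"before {fence}\nkept as-is\n{fence} after"
out = _transform_outside_code_blocks(msg, str.upper)
assert out == f"BEFORE {fence}\nkept as-is\n{fence} AFTER"
```
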
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
    """Extract markdown link destination, allowing nested parentheses in the URL."""
    depth = 0
    i = start_idx

    while i < len(message):
        curr = message[i]
        if curr == "\\":
            i += 2
            continue

        if curr == "(":
            depth += 1
        elif curr == ")":
            if depth == 0:
                return message[start_idx:i], i
            depth -= 1
        i += 1

    return message[start_idx:], None


def _normalize_link_destinations(message: str) -> str:
    """Wrap markdown link URLs in angle brackets so the parser handles special chars safely.

    Markdown link syntax [text](url) breaks when the URL contains unescaped
    parentheses, spaces, or other special characters. Wrapping the URL in angle
    brackets — [text](<url>) — tells the parser to treat everything inside as
    a literal URL. This applies to all links, not just citations.
    """
    if "](" not in message:
        return message

    normalized_parts: list[str] = []
    cursor = 0

    while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
        normalized_parts.append(message[cursor : match.end()])
        destination_start = match.end()
        destination, end_idx = _extract_link_destination(message, destination_start)
        if end_idx is None:
            normalized_parts.append(message[destination_start:])
            return "".join(normalized_parts)

        already_wrapped = destination.startswith("<") and destination.endswith(">")
        if destination and not already_wrapped:
            destination = f"<{destination}>"

        normalized_parts.append(destination)
        normalized_parts.append(")")
        cursor = end_idx + 1

    normalized_parts.append(message[cursor:])
    return "".join(normalized_parts)


def _convert_slack_links_to_markdown(message: str) -> str:
    """Convert Slack-style <url|text> links to standard markdown [text](url).

    LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
    recognise it, so the angle brackets would be escaped by text() and Slack
    would render the link as literal text instead of a clickable link.
    """
    return _transform_outside_code_blocks(
        message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
    )


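Taken together, the two link helpers turn LLM-style output into something mistune can parse reliably; a quick before/after with made-up URLs (assumes the functions defined above are in scope):

```python
# Usage sketch: Slack-style link converted first, then all URLs angle-wrapped.
raw = "<https://example.com/docs|Docs> and [spec](https://example.com/a(b))"
step1 = _convert_slack_links_to_markdown(raw)
# '[Docs](https://example.com/docs) and [spec](https://example.com/a(b))'
step2 = _normalize_link_destinations(step1)
# '[Docs](<https://example.com/docs>) and [spec](<https://example.com/a(b)>)'
print(step2)
```
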
def format_slack_message(message: str | None) -> str:
    return Markdown(renderer=SlackRenderer()).render(message)
    if message is None:
        return ""
    message = _transform_outside_code_blocks(message, _sanitize_html)
    message = _convert_slack_links_to_markdown(message)
    normalized_message = _normalize_link_destinations(message)
    md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
    result = md(normalized_message)
    # With HTMLRenderer, result is always str (not AST list)
    assert isinstance(result, str)
    return result.rstrip("\n")


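Roughly, the rewritten format_slack_message renders standard markdown into Slack mrkdwn. The expected output in the comment below is approximate and depends on mistune's exact parsing; the input text is invented:

```python
text = "## Setup\n\n**Install** the _CLI_, then see [docs](https://example.com/install)."
print(format_slack_message(text))
# Approximate output:
# *Setup*
#
# *Install* the _CLI_, then see <https://example.com/install|docs>.
```
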
class SlackRenderer(Renderer):
class SlackRenderer(HTMLRenderer):
    """Renders markdown as Slack mrkdwn format instead of HTML.

    Overrides all HTMLRenderer methods that produce HTML tags to ensure
    no raw HTML ever appears in Slack messages.
    """

    SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}

    def escape_special(self, text: str) -> str:
@@ -14,52 +151,72 @@ class SlackRenderer(Renderer):
        text = text.replace(special, replacement)
        return text

    def header(self, text: str, level: int, raw: str | None = None) -> str:
        return f"*{text}*\n"
    def heading(self, text: str, level: int, **attrs: Any) -> str:  # noqa: ARG002
        return f"*{text}*\n\n"

    def emphasis(self, text: str) -> str:
        return f"_{text}_"

    def double_emphasis(self, text: str) -> str:
    def strong(self, text: str) -> str:
        return f"*{text}*"

    def strikethrough(self, text: str) -> str:
        return f"~{text}~"

    def list(self, body: str, ordered: bool = True) -> str:
        lines = body.split("\n")
    def list(self, text: str, ordered: bool, **attrs: Any) -> str:
        lines = text.split("\n")
        count = 0
        for i, line in enumerate(lines):
            if line.startswith("li: "):
                count += 1
                prefix = f"{count}. " if ordered else "• "
                lines[i] = f"{prefix}{line[4:]}"
        return "\n".join(lines)
        return "\n".join(lines) + "\n"

    def list_item(self, text: str) -> str:
        return f"li: {text}\n"

    def link(self, link: str, title: str | None, content: str | None) -> str:
        escaped_link = self.escape_special(link)
        if content:
            return f"<{escaped_link}|{content}>"
    def link(self, text: str, url: str, title: str | None = None) -> str:
        escaped_url = self.escape_special(url)
        if text:
            return f"<{escaped_url}|{text}>"
        if title:
            return f"<{escaped_link}|{title}>"
        return f"<{escaped_link}>"
            return f"<{escaped_url}|{title}>"
        return f"<{escaped_url}>"

    def image(self, src: str, title: str | None, text: str | None) -> str:
        escaped_src = self.escape_special(src)
    def image(self, text: str, url: str, title: str | None = None) -> str:
        escaped_url = self.escape_special(url)
        display_text = title or text
        return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
        return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>"

    def codespan(self, text: str) -> str:
        return f"`{text}`"

    def block_code(self, text: str, lang: str | None) -> str:
        return f"```\n{text}\n```\n"
    def block_code(self, code: str, info: str | None = None) -> str:  # noqa: ARG002
        return f"```\n{code.rstrip(chr(10))}\n```\n\n"

    def linebreak(self) -> str:
        return "\n"

    def thematic_break(self) -> str:
        return "---\n\n"

    def block_quote(self, text: str) -> str:
        lines = text.strip().split("\n")
        quoted = "\n".join(f">{line}" for line in lines)
        return quoted + "\n\n"

    def block_html(self, html: str) -> str:
        return _sanitize_html(html) + "\n\n"

    def block_error(self, text: str) -> str:
        return f"```\n{text}\n```\n\n"

    def text(self, text: str) -> str:
        # Only escape the three entities Slack recognizes: & < >
        # HTMLRenderer.text() also escapes " to &quot; which Slack renders
        # as literal &quot; text since Slack doesn't recognize that entity.
        return self.escape_special(text)

    def paragraph(self, text: str) -> str:
        return f"{text}\n"

    def autolink(self, link: str, is_email: bool) -> str:
        return link if is_email else self.link(link, None, None)
        return f"{text}\n\n"

Some files were not shown because too many files have changed in this diff.