mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-20 09:15:47 +00:00
Compare commits
247 Commits
thread_sen
...
test-tests
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1e92d8e8f | ||
|
|
0594fd17de | ||
|
|
fded81dc28 | ||
|
|
31db112de9 | ||
|
|
a3e2da2c51 | ||
|
|
f4d33bcc0d | ||
|
|
464d957494 | ||
|
|
be12de9a44 | ||
|
|
3e4a1f8a09 | ||
|
|
af9b7826ab | ||
|
|
cb16eb13fc | ||
|
|
20a73bdd2e | ||
|
|
85cc2b99b7 | ||
|
|
1208a3ee2b | ||
|
|
900fcef9dd | ||
|
|
d4ed25753b | ||
|
|
0ee58333b4 | ||
|
|
11b7e0d571 | ||
|
|
a35831f328 | ||
|
|
048a6d5259 | ||
|
|
e4bdb15910 | ||
|
|
3517d59286 | ||
|
|
4bc08e5d88 | ||
|
|
4bd080cf62 | ||
|
|
b0a8625ffc | ||
|
|
f94baf6143 | ||
|
|
9e1867638a | ||
|
|
5b6d7c9f0d | ||
|
|
e5dcf31f10 | ||
|
|
8ca06ef3e7 | ||
|
|
6897dbd610 | ||
|
|
7f3cb77466 | ||
|
|
267042a5aa | ||
|
|
d02b3ae6ac | ||
|
|
683c3f7a7e | ||
|
|
008b4d2288 | ||
|
|
8be261405a | ||
|
|
61f2c48ebc | ||
|
|
dbde2e6d6d | ||
|
|
2860136214 | ||
|
|
49ec5994d3 | ||
|
|
8d5fb67f0f | ||
|
|
15d02f6e3c | ||
|
|
e58974c419 | ||
|
|
6b66c07952 | ||
|
|
cae058a3ac | ||
|
|
aa3b21a191 | ||
|
|
7a07a78696 | ||
|
|
a8db236e37 | ||
|
|
8a2e4ed36f | ||
|
|
216f2c95a7 | ||
|
|
67081efe08 | ||
|
|
9d40b8336f | ||
|
|
23f0033302 | ||
|
|
9011b76eb0 | ||
|
|
45e436bafc | ||
|
|
010bc36d61 | ||
|
|
468e488bdb | ||
|
|
9104c0ffce | ||
|
|
d36a6bd0b4 | ||
|
|
a3603c498c | ||
|
|
8f274e34c9 | ||
|
|
5c256760ff | ||
|
|
258e1372b3 | ||
|
|
83a543a265 | ||
|
|
f9719d199d | ||
|
|
1c7bb6e56a | ||
|
|
982ad7d329 | ||
|
|
f94292808b | ||
|
|
293553a2e2 | ||
|
|
ba906ae6fa | ||
|
|
c84c7a354e | ||
|
|
2187b0dd82 | ||
|
|
d88a417bf9 | ||
|
|
f2d32b0b3b | ||
|
|
f89432009f | ||
|
|
8ab2bab34e | ||
|
|
59e0d62512 | ||
|
|
a1471b16a5 | ||
|
|
9d3811cb58 | ||
|
|
3cd9505383 | ||
|
|
d11829b393 | ||
|
|
f6e068e914 | ||
|
|
0c84edd980 | ||
|
|
2b274a7683 | ||
|
|
ddd91f2d71 | ||
|
|
a7c7da0dfc | ||
|
|
b00a3e8b5d | ||
|
|
d77d1a48f1 | ||
|
|
7b4fc6729c | ||
|
|
1f113c86ef | ||
|
|
8e38ba3e21 | ||
|
|
bb9708a64f | ||
|
|
8cae97e145 | ||
|
|
7e4abca224 | ||
|
|
233a91ea65 | ||
|
|
b30737b6b2 | ||
|
|
caf8b85ec2 | ||
|
|
1d13580b63 | ||
|
|
00390c53e0 | ||
|
|
66656df9e6 | ||
|
|
51d26d7e4c | ||
|
|
198ac8ccbc | ||
|
|
ee6d33f484 | ||
|
|
7bcb72d055 | ||
|
|
876894e097 | ||
|
|
7215f56b25 | ||
|
|
0fd1c34014 | ||
|
|
9e24b41b7b | ||
|
|
ab3853578b | ||
|
|
7db969d36a | ||
|
|
6cdeb71656 | ||
|
|
2c4b2c68b4 | ||
|
|
5301ee7cef | ||
|
|
f8e6716875 | ||
|
|
755c65fd8a | ||
|
|
90cf5f49e3 | ||
|
|
d4068c2b07 | ||
|
|
fd6fa43fe1 | ||
|
|
8d5013bf01 | ||
|
|
dabd7c6263 | ||
|
|
c8c0389675 | ||
|
|
9cfcfb12e1 | ||
|
|
786a0c2bd0 | ||
|
|
0cd8d3402b | ||
|
|
3fa397b24d | ||
|
|
e0a97230b8 | ||
|
|
7f1272117a | ||
|
|
79302f19be | ||
|
|
4a91e644d4 | ||
|
|
ca0318f16e | ||
|
|
be8e0b3a98 | ||
|
|
49c4814c70 | ||
|
|
2f945613a2 | ||
|
|
e9242ca3a8 | ||
|
|
a150de761a | ||
|
|
0e792ca6c9 | ||
|
|
6be467a4ac | ||
|
|
dd91bfcfe6 | ||
|
|
8a72291781 | ||
|
|
b2d71da4eb | ||
|
|
6e2f851c62 | ||
|
|
be078edcb4 | ||
|
|
194c54aca3 | ||
|
|
9fa7221e24 | ||
|
|
3a5c7ef8ee | ||
|
|
84458aa0bf | ||
|
|
de57bfa35f | ||
|
|
386f8f31ed | ||
|
|
376f04caea | ||
|
|
4b0a3c2b04 | ||
|
|
1bd9f9d9a6 | ||
|
|
4ac10abaea | ||
|
|
a66a283af4 | ||
|
|
bf5da04166 | ||
|
|
693487f855 | ||
|
|
d02a76d7d1 | ||
|
|
28e05c6e90 | ||
|
|
a18f546921 | ||
|
|
e98dea149e | ||
|
|
027c165794 | ||
|
|
14ebe912c8 | ||
|
|
a63b906789 | ||
|
|
92a68a3c22 | ||
|
|
95db4ed9c7 | ||
|
|
5134d60d48 | ||
|
|
651a54470d | ||
|
|
269d243b67 | ||
|
|
0286dd7da9 | ||
|
|
f3a0710d69 | ||
|
|
037c2aee3a | ||
|
|
9b2f3d234d | ||
|
|
7646399cd4 | ||
|
|
d913b93d10 | ||
|
|
8a0ce4c294 | ||
|
|
862c140763 | ||
|
|
47487f1940 | ||
|
|
e3471df940 | ||
|
|
fb33c815b3 | ||
|
|
5c6594be73 | ||
|
|
8d30a03d7f | ||
|
|
277428f579 | ||
|
|
9f8c0d4237 | ||
|
|
9ccbb6a04b | ||
|
|
58a943f782 | ||
|
|
9021c607f2 | ||
|
|
c03b0d80fd | ||
|
|
fcf0b316a4 | ||
|
|
157f672b4b | ||
|
|
51b9484b96 | ||
|
|
0c8f55c049 | ||
|
|
c7be2571d1 | ||
|
|
4948b6cca9 | ||
|
|
638ea5f316 | ||
|
|
6e3268ca75 | ||
|
|
d8921df60c | ||
|
|
693d9f5f69 | ||
|
|
02e17871cc | ||
|
|
209cfd00b0 | ||
|
|
cd36baa484 | ||
|
|
c78fe275af | ||
|
|
c935c4808f | ||
|
|
4ebcfef541 | ||
|
|
e320ef9d9c | ||
|
|
9e02438af5 | ||
|
|
177e097ddb | ||
|
|
9ecd47ec31 | ||
|
|
83f3d29b10 | ||
|
|
12e668cc0f | ||
|
|
afe8376d5e | ||
|
|
082ef3e096 | ||
|
|
cb2951a1c0 | ||
|
|
eda5598af5 | ||
|
|
0bbb4b6988 | ||
|
|
4768aadb20 | ||
|
|
e05e85e782 | ||
|
|
6408f61307 | ||
|
|
5a5cd51e4f | ||
|
|
7c047c47a0 | ||
|
|
22138bbb33 | ||
|
|
7cff1064a8 | ||
|
|
deeb6fdcd2 | ||
|
|
3e7f4e0aa5 | ||
|
|
ac73671e35 | ||
|
|
3c20d132e0 | ||
|
|
0e3e7eb4a2 | ||
|
|
c85aebe8ab | ||
|
|
a47e6a3146 | ||
|
|
1e61737e03 | ||
|
|
c7fc1cd5ae | ||
|
|
e2b60bf67c | ||
|
|
f4d4d14286 | ||
|
|
1c24bc6ea2 | ||
|
|
cacbd18dcd | ||
|
|
8527b83b15 | ||
|
|
33e37a1846 | ||
|
|
d454d8a878 | ||
|
|
00ad65a6a8 | ||
|
|
dac60d403c | ||
|
|
6256b2854d | ||
|
|
8acb8e191d | ||
|
|
8c4cbddc43 | ||
|
|
f6cd006bd6 | ||
|
|
0033934319 | ||
|
|
ff87b79d14 | ||
|
|
ebf18af7c9 | ||
|
|
cf67ae962c |
413
.github/workflows/deployment.yml
vendored
413
.github/workflows/deployment.yml
vendored
@@ -8,7 +8,9 @@ on:
|
||||
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions: {}
|
||||
permissions:
|
||||
# Required for OIDC authentication with AWS
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
@@ -150,16 +152,30 @@ jobs:
|
||||
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• check-version-tag"
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -168,6 +184,7 @@ jobs:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
@@ -185,12 +202,33 @@ jobs:
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
APPLE_ID, deploy/apple-id
|
||||
APPLE_PASSWORD, deploy/apple-password
|
||||
APPLE_CERTIFICATE, deploy/apple-certificate
|
||||
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
|
||||
KEYCHAIN_PASSWORD, deploy/keychain-password
|
||||
APPLE_TEAM_ID, deploy/apple-team-id
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
@@ -285,15 +323,40 @@ jobs:
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
|
||||
- name: Import Apple Developer Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
|
||||
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security set-keychain-settings -t 3600 -u build.keychain
|
||||
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security find-identity -v -p codesigning build.keychain
|
||||
|
||||
- name: Verify Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
|
||||
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
|
||||
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
|
||||
echo "Certificate imported."
|
||||
|
||||
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
APPLE_ID: ${{ env.APPLE_ID }}
|
||||
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
@@ -305,6 +368,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -317,6 +381,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -326,13 +404,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -363,6 +441,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -375,6 +454,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -384,13 +477,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -423,19 +516,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -471,6 +579,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -483,6 +592,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -492,13 +615,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -537,6 +660,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -549,6 +673,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -558,13 +696,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -605,19 +743,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -650,6 +803,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -662,6 +816,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -671,13 +839,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -707,6 +875,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -719,6 +888,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -728,13 +911,13 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -766,19 +949,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -815,6 +1013,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -827,6 +1026,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -836,15 +1049,15 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -879,6 +1092,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -891,6 +1105,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -900,15 +1128,15 @@ jobs:
|
||||
latest=false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
with:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -944,19 +1172,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -994,11 +1237,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1014,8 +1272,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1034,11 +1292,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1054,8 +1327,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1074,6 +1347,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
@@ -1084,6 +1358,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1100,8 +1388,8 @@ jobs:
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1121,11 +1409,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1141,8 +1444,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1170,12 +1473,26 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Determine failed jobs
|
||||
id: failed-jobs
|
||||
shell: bash
|
||||
@@ -1241,7 +1558,7 @@ jobs:
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
|
||||
title: "🚨 Deployment Workflow Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
2
.github/workflows/docker-tag-beta.yml
vendored
2
.github/workflows/docker-tag-beta.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
|
||||
2
.github/workflows/docker-tag-latest.yml
vendored
2
.github/workflows/docker-tag-latest.yml
vendored
@@ -21,7 +21,7 @@ jobs:
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
|
||||
1
.github/workflows/helm-chart-releases.yml
vendored
1
.github/workflows/helm-chart-releases.yml
vendored
@@ -29,6 +29,7 @@ jobs:
|
||||
run: |
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add opensearch https://opensearch-project.github.io/helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
2
.github/workflows/nightly-scan-licenses.yml
vendored
2
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -94,7 +94,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
|
||||
@@ -45,6 +45,9 @@ env:
|
||||
# TODO: debug why this is failing and enable
|
||||
CODE_INTERPRETER_BASE_URL: http://localhost:8000
|
||||
|
||||
# OpenSearch
|
||||
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
|
||||
|
||||
jobs:
|
||||
discover-test-dirs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
@@ -125,11 +128,13 @@ jobs:
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
-f docker-compose.opensearch.yml \
|
||||
up -d \
|
||||
minio \
|
||||
relational_db \
|
||||
cache \
|
||||
index \
|
||||
opensearch \
|
||||
code-interpreter
|
||||
|
||||
- name: Run migrations
|
||||
@@ -158,7 +163,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
@@ -172,7 +177,7 @@ jobs:
|
||||
|
||||
- name: Upload Docker logs
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.test-dir }}
|
||||
path: docker-logs/
|
||||
|
||||
8
.github/workflows/pr-helm-chart-testing.yml
vendored
8
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -88,6 +88,7 @@ jobs:
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add opensearch https://opensearch-project.github.io/helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
@@ -180,6 +181,11 @@ jobs:
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
# Note that opensearch.enabled is true whereas others in this install
|
||||
# are false. There is some work that needs to be done to get this
|
||||
# entire step working in CI, enabling opensearch here is a small step
|
||||
# in that direction. If this is causing issues, disabling it in this
|
||||
# step should be ok in the short term.
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
@@ -187,6 +193,8 @@ jobs:
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=opensearch.enabled=true \
|
||||
--set=auth.opensearch.enabled=true \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
|
||||
13
.github/workflows/pr-integration-tests.yml
vendored
13
.github/workflows/pr-integration-tests.yml
vendored
@@ -103,7 +103,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -163,7 +163,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -208,7 +208,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -310,8 +310,9 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
MCP_SERVER_ENABLED=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -438,7 +439,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -567,7 +568,7 @@ jobs:
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
|
||||
10
.github/workflows/pr-mit-integration-tests.yml
vendored
10
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -95,7 +95,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling Vespa, Redis, Postgres, and Minio images
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -214,7 +214,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling openapitools/openapi-generator-cli
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
@@ -301,7 +301,7 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -424,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
10
.github/workflows/pr-playwright-tests.yml
vendored
10
.github/workflows/pr-playwright-tests.yml
vendored
@@ -85,7 +85,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -146,7 +146,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
@@ -435,7 +435,7 @@ jobs:
|
||||
fi
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
@@ -455,7 +455,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
140
.github/workflows/pr-python-model-tests.yml
vendored
140
.github/workflows/pr-python-model-tests.yml
vendored
@@ -5,11 +5,6 @@ on:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to run the workflow on'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -31,7 +26,11 @@ env:
|
||||
jobs:
|
||||
model-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-model-check"
|
||||
- "extras=ecr-cache"
|
||||
timeout-minutes: 45
|
||||
|
||||
env:
|
||||
@@ -43,108 +42,87 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Python and Install Dependencies
|
||||
uses: ./.github/actions/setup-python-and-install-dependencies
|
||||
with:
|
||||
requirements: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Model Server Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-model-server:latest
|
||||
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
- name: Build and load
|
||||
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
|
||||
env:
|
||||
TAG: model-server-${{ github.run_id }}
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
load: true
|
||||
targets: model-server
|
||||
set: |
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
model-server.cache-from=type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
id: start_docker
|
||||
env:
|
||||
IMAGE_TAG: model-server-${{ github.run_id }}
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
up -d --wait \
|
||||
inference_model_server
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
REPO: ${{ github.repository }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
|
||||
$SLACK_WEBHOOK
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: model-check
|
||||
title: "🚨 Scheduled Model Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,5 +1,8 @@
|
||||
# editors
|
||||
.vscode
|
||||
!/.vscode/env_template.txt
|
||||
!/.vscode/launch.json
|
||||
!/.vscode/tasks.template.jsonc
|
||||
.zed
|
||||
.cursor
|
||||
|
||||
@@ -21,6 +24,7 @@ backend/tests/regression/search_quality/*.json
|
||||
backend/onyx/evals/data/
|
||||
backend/onyx/evals/one_off/*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# secret files
|
||||
.env
|
||||
|
||||
@@ -11,7 +11,6 @@ repos:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args:
|
||||
@@ -75,6 +74,13 @@ repos:
|
||||
# pass_filenames: true
|
||||
# files: ^backend/.*\.py$
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
name: Check for added large files
|
||||
args: ["--maxkb=1500"]
|
||||
|
||||
- repo: https://github.com/rhysd/actionlint
|
||||
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
|
||||
hooks:
|
||||
@@ -147,6 +153,22 @@ repos:
|
||||
pass_filenames: false
|
||||
files: \.tf$
|
||||
|
||||
- id: npm-install
|
||||
name: npm install
|
||||
description: "Automatically run 'npm install' after a checkout, pull or rebase"
|
||||
language: system
|
||||
entry: bash -c 'cd web && npm install --no-save'
|
||||
pass_filenames: false
|
||||
files: ^web/package(-lock)?\.json$
|
||||
stages: [post-checkout, post-merge, post-rewrite]
|
||||
- id: npm-install-check
|
||||
name: npm install --package-lock-only
|
||||
description: "Check the 'web/package-lock.json' is updated"
|
||||
language: system
|
||||
entry: bash -c 'cd web && npm install --package-lock-only'
|
||||
pass_filenames: false
|
||||
files: ^web/package(-lock)?\.json$
|
||||
|
||||
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
|
||||
# This is a preview package - if it breaks:
|
||||
# 1. Try updating: cd web && npm update @typescript/native-preview
|
||||
|
||||
6
.vscode/env_template.txt
vendored
6
.vscode/env_template.txt
vendored
@@ -17,12 +17,6 @@ LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
LOG_LEVEL=debug
|
||||
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to
|
||||
# answer generation.
|
||||
# This step is quite heavy on token usage so we disable it for dev generally.
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
/* Copy this file into '.vscode/launch.json' or merge its contents into your existing configurations. */
|
||||
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
@@ -24,7 +22,7 @@
|
||||
"Slack Bot",
|
||||
"Celery primary",
|
||||
"Celery light",
|
||||
"Celery background",
|
||||
"Celery heavy",
|
||||
"Celery docfetching",
|
||||
"Celery docprocessing",
|
||||
"Celery beat"
|
||||
@@ -579,6 +577,99 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Database ---",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"presentation": {
|
||||
"group": "4",
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--clean",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Create database snapshot",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"dump",
|
||||
"backup.dump"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore database snapshot (destructive)",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--clean",
|
||||
"--yes",
|
||||
"backup.dump"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Upgrade database to head revision",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"upgrade"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
// script to generate the openapi schema
|
||||
"name": "Onyx OpenAPI Schema Generator",
|
||||
259
CONTRIBUTING.md
259
CONTRIBUTING.md
@@ -1,262 +1,31 @@
|
||||
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
|
||||
|
||||
# Contributing to Onyx
|
||||
|
||||
Hey there! We are so excited that you're interested in Onyx.
|
||||
|
||||
As an open source project in a rapidly changing space, we welcome all contributions.
|
||||
|
||||
## 💃 Guidelines
|
||||
## Contribution Opportunities
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.
|
||||
|
||||
### Contribution Opportunities
|
||||
If you have your own feature that you would like to build please create an issue and community members can provide feedback and
|
||||
thumb it up if they feel a common need.
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
|
||||
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
|
||||
## Contributing Code
|
||||
Please reference the documents in contributing_guides folder to ensure that the code base is kept to a high standard.
|
||||
1. dev_setup.md (start here): gives you a guide to setting up a local development environment.
|
||||
2. contribution_process.md: how to ensure you are building valuable features that will get reviewed and merged.
|
||||
3. best_practices.md: before asking for reviews, ensure your changes meet the repo code quality standards.
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
Issues marked `good first issue` are an especially great place to start.
|
||||
|
||||
**Connectors** to other tools are another great place to contribute. For details on how, refer to this
|
||||
[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).
|
||||
|
||||
If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
|
||||
To contribute to this project, please follow the
|
||||
To contribute, please follow the
|
||||
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
|
||||
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
|
||||
|
||||
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
|
||||
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
### Getting Help 🙋
|
||||
## Getting Help 🙋
|
||||
We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
See you there!
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
We would love to see you there!
|
||||
|
||||
## Get Started 🚀
|
||||
|
||||
Onyx being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [MinIO](https://min.io/) (File Store)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.
|
||||
|
||||
### Local Set Up
|
||||
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
|
||||
#### Backend: Python requirements
|
||||
|
||||
Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
|
||||
|
||||
For convenience here's a command for it:
|
||||
|
||||
```bash
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
```
|
||||
|
||||
If using PowerShell, the command slightly differs:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
```bash
|
||||
uv run playwright install
|
||||
```
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
|
||||
to manage your Node installations. Once installed, you can run
|
||||
|
||||
```bash
|
||||
nvm install 22 && nvm use 22
|
||||
node -v # verify your active version
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
## Formatting and Linting
|
||||
|
||||
### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
|
||||
Then run:
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
|
||||
|
||||
### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
|
||||
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
|
||||
Re-stage your changes and commit again.
|
||||
|
||||
# Running the application for development
|
||||
|
||||
## Developing using VSCode Debugger (recommended)
|
||||
|
||||
**We highly recommend using VSCode debugger for development.**
|
||||
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
|
||||
|
||||
Otherwise, you can follow the instructions below to run the application for development.
|
||||
|
||||
## Manually running the application for development
|
||||
### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Next, start the model server which runs the local NLP models.
|
||||
Navigate to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Onyx, you will need to run the DB migrations for Postgres.
|
||||
After the first time, this is no longer required unless the DB models change.
|
||||
|
||||
Navigate to `onyx/backend` and with the venv active, run:
|
||||
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
Next, start the task queue which orchestrates the background jobs.
|
||||
Jobs that take more time are run async from the API server.
|
||||
|
||||
Still in `onyx/backend`, run:
|
||||
|
||||
```bash
|
||||
python ./scripts/dev_run_background_jobs.py
|
||||
```
|
||||
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='disabled'
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
#### Wrapping up
|
||||
|
||||
You should now have 4 servers running:
|
||||
|
||||
- Web server
|
||||
- Backend API
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
|
||||
|
||||
You've successfully set up a local Onyx instance! 🏁
|
||||
|
||||
#### Running the Onyx application in a container
|
||||
|
||||
You can run the full Onyx application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `onyx/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
|
||||
|
||||
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
|
||||
### Release Process
|
||||
|
||||
## Release Process
|
||||
Onyx loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
|
||||
@@ -37,10 +37,6 @@ CVE-2023-50868
|
||||
CVE-2023-52425
|
||||
CVE-2024-28757
|
||||
|
||||
# sqlite, only used by NLTK library to grab word lemmatizer and stopwords
|
||||
# No impact in our settings
|
||||
CVE-2023-7104
|
||||
|
||||
# libharfbuzz0b, O(n^2) growth, worst case is denial of service
|
||||
# Accept the risk
|
||||
CVE-2023-25193
|
||||
|
||||
@@ -89,12 +89,6 @@ RUN uv pip install --system --no-cache-dir --upgrade \
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('punkt_tab', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Pre-downloading tiktoken for setups with limited egress
|
||||
RUN python -c "import tiktoken; \
|
||||
tiktoken.get_encoding('cl100k_base')"
|
||||
|
||||
@@ -225,7 +225,6 @@ def do_run_migrations(
|
||||
) -> None:
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
@@ -309,6 +308,7 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
@@ -346,6 +346,7 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""add_unique_constraint_to_inputprompt_prompt_user_id
|
||||
|
||||
Revision ID: 2c2430828bdf
|
||||
Revises: fb80bdd256de
|
||||
Create Date: 2026-01-20 16:01:54.314805
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2c2430828bdf"
|
||||
down_revision = "fb80bdd256de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create unique constraint on (prompt, user_id) for user-owned prompts
|
||||
# This ensures each user can only have one shortcut with a given name
|
||||
op.create_unique_constraint(
|
||||
"uq_inputprompt_prompt_user_id",
|
||||
"inputprompt",
|
||||
["prompt", "user_id"],
|
||||
)
|
||||
|
||||
# Create partial unique index for public prompts (where user_id IS NULL)
|
||||
# PostgreSQL unique constraints don't enforce uniqueness for NULL values,
|
||||
# so we need a partial index to ensure public prompt names are also unique
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX uq_inputprompt_prompt_public
|
||||
ON inputprompt (prompt)
|
||||
WHERE user_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public")
|
||||
op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique")
|
||||
@@ -0,0 +1,29 @@
|
||||
"""remove default prompt shortcuts
|
||||
|
||||
Revision ID: 41fa44bef321
|
||||
Revises: 2c2430828bdf
|
||||
Create Date: 2025-01-21
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "41fa44bef321"
|
||||
down_revision = "2c2430828bdf"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Delete any user associations for the default prompts first (foreign key constraint)
|
||||
op.execute(
|
||||
"DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
|
||||
)
|
||||
# Delete the pre-seeded default prompt shortcuts (they have negative IDs)
|
||||
op.execute("DELETE FROM inputprompt WHERE id < 0")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't restore the default prompts on downgrade
|
||||
pass
|
||||
@@ -85,103 +85,122 @@ class UserRow(NamedTuple):
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
try:
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# Add tools to the unified assistant
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
@@ -190,191 +209,148 @@ def upgrade() -> None:
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
{"tool_id": web_search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
)
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""add_search_query_table
|
||||
|
||||
Revision ID: 73e9983e5091
|
||||
Revises: d1b637d7050a
|
||||
Create Date: 2026-01-14 14:16:52.837489
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "73e9983e5091"
|
||||
down_revision = "d1b637d7050a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"search_query",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("query", sa.String(), nullable=False),
|
||||
sa.Column("query_expansions", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index("ix_search_query_user_id", "search_query", ["user_id"])
|
||||
op.create_index("ix_search_query_created_at", "search_query", ["created_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_search_query_created_at", table_name="search_query")
|
||||
op.drop_index("ix_search_query_user_id", table_name="search_query")
|
||||
op.drop_table("search_query")
|
||||
@@ -10,8 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.enums import RecencyBiasSetting, SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""notifications constraint, sort index, and cleanup old notifications
|
||||
|
||||
Revision ID: 8405ca81cc83
|
||||
Revises: a3c1a7904cd0
|
||||
Create Date: 2026-01-07 16:43:44.855156
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8405ca81cc83"
|
||||
down_revision = "a3c1a7904cd0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create unique index for notification deduplication.
|
||||
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
|
||||
#
|
||||
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
|
||||
# in unique constraints, but we want NULL == NULL for deduplication).
|
||||
# The '{}' represents an empty JSONB object as the NULL replacement.
|
||||
|
||||
# Clean up legacy notifications first
|
||||
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
|
||||
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
|
||||
"""
|
||||
)
|
||||
|
||||
# Create index for efficient notification sorting by user
|
||||
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
|
||||
ON notification (user_id, dismissed, first_shown DESC)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")
|
||||
116
backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
Normal file
116
backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Add Discord bot tables
|
||||
|
||||
Revision ID: 8b5ce697290e
|
||||
Revises: a1b2c3d4e5f7
|
||||
Create Date: 2025-01-14
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8b5ce697290e"
|
||||
down_revision = "a1b2c3d4e5f7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# DiscordBotConfig (singleton table - one per tenant)
|
||||
op.create_table(
|
||||
"discord_bot_config",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.String(),
|
||||
primary_key=True,
|
||||
server_default=sa.text("'SINGLETON'"),
|
||||
),
|
||||
sa.Column("bot_token", sa.LargeBinary(), nullable=False), # EncryptedString
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
|
||||
)
|
||||
|
||||
# DiscordGuildConfig
|
||||
op.create_table(
|
||||
"discord_guild_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
|
||||
sa.Column("guild_name", sa.String(), nullable=True),
|
||||
sa.Column("registration_key", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"default_persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
# DiscordChannelConfig
|
||||
op.create_table(
|
||||
"discord_channel_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column(
|
||||
"guild_config_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("channel_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("channel_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"channel_type",
|
||||
sa.String(20),
|
||||
server_default=sa.text("'text'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_private",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"thread_only_mode",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"require_bot_invocation",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"persona_override_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
# Unique constraint: one config per channel per guild
|
||||
op.create_unique_constraint(
|
||||
"uq_discord_channel_guild_channel",
|
||||
"discord_channel_config",
|
||||
["guild_config_id", "channel_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("discord_channel_config")
|
||||
op.drop_table("discord_guild_config")
|
||||
op.drop_table("discord_bot_config")
|
||||
@@ -42,20 +42,13 @@ TOOL_DESCRIPTIONS = {
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""drop agent_search_metrics table
|
||||
|
||||
Revision ID: a1b2c3d4e5f7
|
||||
Revises: 73e9983e5091
|
||||
Create Date: 2026-01-17
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f7"
|
||||
down_revision = "73e9983e5091"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("agent__search_metrics")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"agent__search_metrics",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("agent_type", sa.String(), nullable=False),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("base_duration_s", sa.Float(), nullable=False),
|
||||
sa.Column("full_duration_s", sa.Float(), nullable=False),
|
||||
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -7,7 +7,6 @@ Create Date: 2025-12-18 16:00:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -19,7 +18,7 @@ depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"name": "ResearchAgent",
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
|
||||
@@ -70,80 +70,66 @@ BUILT_IN_TOOLS = [
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
|
||||
try:
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""sync_exa_api_key_to_content_provider
|
||||
|
||||
Revision ID: d1b637d7050a
|
||||
Revises: d25168c2beee
|
||||
Create Date: 2026-01-09 15:54:15.646249
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d1b637d7050a"
|
||||
down_revision = "d25168c2beee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Exa uses a shared API key between search and content providers.
|
||||
# For existing Exa search providers with API keys, create the corresponding
|
||||
# content provider if it doesn't exist yet.
|
||||
connection = op.get_bind()
|
||||
|
||||
# Check if Exa search provider exists with an API key
|
||||
result = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT api_key FROM internet_search_provider
|
||||
WHERE provider_type = 'exa' AND api_key IS NOT NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
api_key = row[0]
|
||||
# Create Exa content provider with the shared key
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO internet_content_provider
|
||||
(name, provider_type, api_key, is_active)
|
||||
VALUES ('Exa', 'exa', :api_key, false)
|
||||
ON CONFLICT (name) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"api_key": api_key},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the Exa content provider that was created by this migration
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM internet_content_provider
|
||||
WHERE provider_type = 'exa'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""tool_name_consistency
|
||||
|
||||
Revision ID: d25168c2beee
|
||||
Revises: 8405ca81cc83
|
||||
Create Date: 2026-01-11 17:54:40.135777
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d25168c2beee"
|
||||
down_revision = "8405ca81cc83"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Currently the seeded tools have the in_code_tool_id == name
|
||||
CURRENT_TOOL_NAME_MAPPING = [
|
||||
"SearchTool",
|
||||
"WebSearchTool",
|
||||
"ImageGenerationTool",
|
||||
"PythonTool",
|
||||
"OpenURLTool",
|
||||
"KnowledgeGraphTool",
|
||||
"ResearchAgent",
|
||||
]
|
||||
|
||||
# Mapping of in_code_tool_id -> name
|
||||
# These are the expected names that we want in the database
|
||||
EXPECTED_TOOL_NAME_MAPPING = {
|
||||
"SearchTool": "internal_search",
|
||||
"WebSearchTool": "web_search",
|
||||
"ImageGenerationTool": "generate_image",
|
||||
"PythonTool": "python",
|
||||
"OpenURLTool": "open_url",
|
||||
"KnowledgeGraphTool": "run_kg_search",
|
||||
"ResearchAgent": "research_agent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Mapping of in_code_tool_id to the NAME constant from each tool class
|
||||
# These match the .name property of each tool implementation
|
||||
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
|
||||
|
||||
# Update the name column for each tool based on its in_code_tool_id
|
||||
for in_code_tool_id, expected_name in tool_name_mapping.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :expected_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"expected_name": expected_name,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Reverse the migration by setting name back to in_code_tool_id
|
||||
# This matches the original pattern where name was the class name
|
||||
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :current_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"current_name": in_code_tool_id,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add chat_background to user
|
||||
|
||||
Revision ID: fb80bdd256de
|
||||
Revises: 8b5ce697290e
|
||||
Create Date: 2026-01-16 16:15:59.222617
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "fb80bdd256de"
|
||||
down_revision = "8b5ce697290e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"chat_background",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "chat_background")
|
||||
@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
@@ -129,3 +128,8 @@ MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
# License enforcement - when True, blocks API access for gated/expired licenses
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -3,30 +3,42 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
def update_persona_access(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
"""Updates the access settings for a persona including public status, user shares,
|
||||
and group shares.
|
||||
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
NOTE: This function batches all updates. If we don't dedupe the inputs,
|
||||
the commit will exception.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
@@ -41,11 +53,13 @@ def make_persona_private(
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids:
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
64
backend/ee/onyx/db/search.py
Normal file
64
backend/ee/onyx/db/search.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import SearchQuery
|
||||
|
||||
|
||||
def create_search_query(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
query: str,
|
||||
query_expansions: list[str] | None = None,
|
||||
) -> SearchQuery:
|
||||
"""Create and persist a `SearchQuery` row.
|
||||
|
||||
Notes:
|
||||
- `SearchQuery.id` is a UUID PK without a server-side default, so we generate it.
|
||||
- `created_at` is filled by the DB (server_default=now()).
|
||||
"""
|
||||
search_query = SearchQuery(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
query=query,
|
||||
query_expansions=query_expansions,
|
||||
)
|
||||
db_session.add(search_query)
|
||||
db_session.commit()
|
||||
db_session.refresh(search_query)
|
||||
return search_query
|
||||
|
||||
|
||||
def fetch_search_queries_for_user(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
filter_days: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[SearchQuery]:
|
||||
"""Fetch `SearchQuery` rows for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID.
|
||||
filter_days: Optional time filter. If provided, only rows created within
|
||||
the last `filter_days` days are returned.
|
||||
limit: Optional max number of rows to return.
|
||||
"""
|
||||
if filter_days is not None and filter_days <= 0:
|
||||
raise ValueError("filter_days must be > 0")
|
||||
|
||||
stmt = select(SearchQuery).where(SearchQuery.user_id == user_id)
|
||||
|
||||
if filter_days is not None and filter_days > 0:
|
||||
cutoff = get_db_current_time(db_session) - timedelta(days=filter_days)
|
||||
stmt = stmt.where(SearchQuery.created_at >= cutoff)
|
||||
|
||||
stmt = stmt.order_by(SearchQuery.created_at.desc())
|
||||
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
@@ -16,16 +16,17 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
add_license_enforcement_middleware,
|
||||
)
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
)
|
||||
from ee.onyx.server.oauth.api import router as ee_oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.query_backend import (
|
||||
basic_router as ee_query_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
@@ -85,6 +86,10 @@ def get_application() -> FastAPI:
|
||||
if MULTI_TENANT:
|
||||
add_api_server_tenant_id_middleware(application, logger)
|
||||
|
||||
# Add license enforcement middleware (runs after tenant tracking)
|
||||
# This blocks access when license is expired/gated
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
@@ -124,7 +129,7 @@ def get_application() -> FastAPI:
|
||||
# EE only backend APIs
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, search_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
|
||||
27
backend/ee/onyx/prompts/query_expansion.py
Normal file
27
backend/ee/onyx/prompts/query_expansion.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Single message is likely most reliable and generally better for this task
|
||||
# No final reminders at the end since the user query is expected to be short
|
||||
# If it is not short, it should go into the chat flow so we do not need to account for this.
|
||||
KEYWORD_EXPANSION_PROMPT = """
|
||||
Generate a set of keyword-only queries to help find relevant documents for the provided query. \
|
||||
These queries will be passed to a bm25-based keyword search engine. \
|
||||
Provide a single query per line (where each query consists of one or more keywords). \
|
||||
The queries must be purely keywords and not contain any filler natural language. \
|
||||
The each query should have as few keywords as necessary to represent the user's search intent. \
|
||||
If there are no useful expansions, simply return the original query with no additional keyword queries. \
|
||||
CRITICAL: Do not include any additional formatting, comments, or anything aside from the keyword queries.
|
||||
|
||||
The user query is:
|
||||
{user_query}
|
||||
""".strip()
|
||||
|
||||
|
||||
QUERY_TYPE_PROMPT = """
|
||||
Determine if the provided query is better suited for a keyword search or a semantic search.
|
||||
Respond with "keyword" or "semantic" literally and nothing else.
|
||||
Do not provide any additional text or reasoning to your response.
|
||||
|
||||
CRITICAL: It must only be 1 single word - EITHER "keyword" or "semantic".
|
||||
|
||||
The user query is:
|
||||
{user_query}
|
||||
""".strip()
|
||||
42
backend/ee/onyx/prompts/search_flow_classification.py
Normal file
42
backend/ee/onyx/prompts/search_flow_classification.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
SEARCH_CLASS = "search"
|
||||
CHAT_CLASS = "chat"
|
||||
|
||||
# Will note that with many larger LLMs the latency on running this prompt via third party APIs is as high as 2 seconds which is too slow for many
|
||||
# use cases.
|
||||
SEARCH_CHAT_PROMPT = f"""
|
||||
Determine if the following query is better suited for a search UI or a chat UI. Respond with "{SEARCH_CLASS}" or "{CHAT_CLASS}" literally and nothing else. \
|
||||
Do not provide any additional text or reasoning to your response. CRITICAL, IT MUST ONLY BE 1 SINGLE WORD - EITHER "{SEARCH_CLASS}" or "{CHAT_CLASS}".
|
||||
|
||||
# Classification Guidelines:
|
||||
## {SEARCH_CLASS}
|
||||
- If the query consists entirely of keywords or query doesn't require any answer from the AI
|
||||
- If the query is a short statement that seems like a search query rather than a question
|
||||
- If the query feels nonsensical or is a short phrase that possibly describes a document or information that could be found in a internal document
|
||||
|
||||
### Examples of {SEARCH_CLASS} queries:
|
||||
- Find me the document that goes over the onboarding process for a new hire
|
||||
- Pull requests since last week
|
||||
- Sales Runbook AMEA Region
|
||||
- Procurement process
|
||||
- Retrieve the PRD for project X
|
||||
|
||||
## {CHAT_CLASS}
|
||||
- If the query is asking a question that requires an answer rather than a document
|
||||
- If the query is asking for a solution, suggestion, or general help
|
||||
- If the query is seeking information that is on the web and likely not in a company internal document
|
||||
- If the query should be answered without any context from additional documents or searches
|
||||
|
||||
### Examples of {CHAT_CLASS} queries:
|
||||
- What led us to win the deal with company X? (seeking answer)
|
||||
- Google Drive not sync-ing files to my computer (seeking solution)
|
||||
- Review my email: <whatever the email is> (general help)
|
||||
- Write me a script to... (general help)
|
||||
- Cheap flights Europe to Tokyo (information likely found on the web, not internal)
|
||||
|
||||
# User Query:
|
||||
{{user_query}}
|
||||
|
||||
REMEMBER TO ONLY RESPOND WITH "{SEARCH_CLASS}" OR "{CHAT_CLASS}" AND NOTHING ELSE.
|
||||
""".strip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
286
backend/ee/onyx/search/process_search_query.py
Normal file
286
backend/ee/onyx/search/process_search_query.py
Normal file
@@ -0,0 +1,286 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.search import create_search_query
|
||||
from ee.onyx.secondary_llm_flows.query_expansion import expand_keywords
|
||||
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import LLMSelectedDocsPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchDocsPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchQueriesPacket
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
weighted_reciprocal_rank_fusion,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# This is just a heuristic that also happens to work well for the UI/UX
|
||||
# Users would not find it useful to see a huge list of suggested docs
|
||||
# but more than 1 is also likely good as many questions may target more than 1 doc.
|
||||
TARGET_NUM_SECTIONS_FOR_LLM_SELECTION = 3
|
||||
|
||||
|
||||
def _run_single_search(
|
||||
query: str,
|
||||
filters: BaseFilters | None,
|
||||
document_index: DocumentIndex,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
num_hits: int | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
limit=num_hits,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def stream_search_query(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> Generator[
|
||||
SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Core search function that yields streaming packets.
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
keyword_expansions: list[str] = []
|
||||
|
||||
if request.run_query_expansion:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
keyword_expansions = expand_keywords(
|
||||
user_query=original_query,
|
||||
llm=llm,
|
||||
)
|
||||
if keyword_expansions:
|
||||
logger.debug(
|
||||
f"Query expansion generated {len(keyword_expansions)} keyword queries"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Query expansion failed: {e}; using original query only.")
|
||||
keyword_expansions = []
|
||||
|
||||
# Build list of all executed queries for tracking
|
||||
all_executed_queries = [original_query] + keyword_expansions
|
||||
|
||||
# TODO remove this check, user should not be None
|
||||
if user is not None:
|
||||
create_search_query(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
query=request.search_query,
|
||||
query_expansions=keyword_expansions if keyword_expansions else None,
|
||||
)
|
||||
|
||||
# Execute search(es)
|
||||
if not keyword_expansions:
|
||||
# Single query (original only) - no threading needed
|
||||
chunks = _run_single_search(
|
||||
query=original_query,
|
||||
filters=request.filters,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
num_hits=request.num_hits,
|
||||
)
|
||||
else:
|
||||
# Multiple queries - run in parallel and merge with RRF
|
||||
# First query is the original (semantic), rest are keyword expansions
|
||||
search_functions = [
|
||||
(
|
||||
_run_single_search,
|
||||
(
|
||||
query,
|
||||
request.filters,
|
||||
document_index,
|
||||
user,
|
||||
db_session,
|
||||
request.num_hits,
|
||||
),
|
||||
)
|
||||
for query in all_executed_queries
|
||||
]
|
||||
|
||||
# Run all searches in parallel
|
||||
all_search_results: list[list[InferenceChunk]] = (
|
||||
run_functions_tuples_in_parallel(
|
||||
search_functions,
|
||||
allow_failures=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Separate original query results from keyword expansion results
|
||||
# Note that in rare cases, the original query may have failed and so we may be
|
||||
# just overweighting one set of keyword results, should be not a big deal though.
|
||||
original_result = all_search_results[0] if all_search_results else []
|
||||
keyword_results = all_search_results[1:] if len(all_search_results) > 1 else []
|
||||
|
||||
# Build valid results and weights
|
||||
# Original query (semantic): weight 2.0
|
||||
# Keyword expansions: weight 1.0 each
|
||||
valid_results: list[list[InferenceChunk]] = []
|
||||
weights: list[float] = []
|
||||
|
||||
if original_result:
|
||||
valid_results.append(original_result)
|
||||
weights.append(2.0)
|
||||
|
||||
for keyword_result in keyword_results:
|
||||
if keyword_result:
|
||||
valid_results.append(keyword_result)
|
||||
weights.append(1.0)
|
||||
|
||||
if not valid_results:
|
||||
logger.warning("All parallel searches returned empty results")
|
||||
chunks = []
|
||||
else:
|
||||
chunks = weighted_reciprocal_rank_fusion(
|
||||
ranked_results=valid_results,
|
||||
weights=weights,
|
||||
id_extractor=lambda chunk: f"{chunk.document_id}_{chunk.chunk_id}",
|
||||
)
|
||||
|
||||
# Merge chunks into sections
|
||||
sections = merge_individual_chunks(chunks)
|
||||
|
||||
# Truncate to the requested number of hits
|
||||
sections = sections[: request.num_hits]
|
||||
|
||||
# Apply LLM document selection if requested
|
||||
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
|
||||
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it
|
||||
# llm_selected_doc_ids will be:
|
||||
# - None if LLM selection was not requested or failed
|
||||
# - Empty list if LLM selection ran but selected nothing
|
||||
# - List of doc IDs if LLM selection succeeded
|
||||
run_llm_selection = (
|
||||
request.num_docs_fed_to_llm_selection is not None
|
||||
and request.num_docs_fed_to_llm_selection >= 1
|
||||
)
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
llm_selection_failed = False
|
||||
if run_llm_selection and sections:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
sections_to_evaluate = sections[: request.num_docs_fed_to_llm_selection]
|
||||
selected_sections, _ = select_sections_for_expansion(
|
||||
sections=sections_to_evaluate,
|
||||
user_query=original_query,
|
||||
llm=llm,
|
||||
max_sections=TARGET_NUM_SECTIONS_FOR_LLM_SELECTION,
|
||||
try_to_fill_to_max=True,
|
||||
)
|
||||
# Extract unique document IDs from selected sections (may be empty)
|
||||
llm_selected_doc_ids = list(
|
||||
dict.fromkeys(
|
||||
section.center_chunk.document_id for section in selected_sections
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"LLM document selection evaluated {len(sections_to_evaluate)} sections, "
|
||||
f"selected {len(selected_sections)} sections with doc IDs: {llm_selected_doc_ids}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Allowing a blanket exception here as this step is not critical and the rest of the results are still valid
|
||||
logger.warning(f"LLM document selection failed: {e}")
|
||||
llm_selection_failed = True
|
||||
elif run_llm_selection and not sections:
|
||||
# LLM selection requested but no sections to evaluate
|
||||
llm_selected_doc_ids = []
|
||||
|
||||
# Convert to SearchDocWithContent list, optionally including content
|
||||
search_docs = SearchDocWithContent.from_inference_sections(
|
||||
sections,
|
||||
include_content=request.include_content,
|
||||
is_internet=False,
|
||||
)
|
||||
|
||||
# Yield queries packet
|
||||
yield SearchQueriesPacket(all_executed_queries=all_executed_queries)
|
||||
|
||||
# Yield docs packet
|
||||
yield SearchDocsPacket(search_docs=search_docs)
|
||||
|
||||
# Yield LLM selected docs packet if LLM selection was requested
|
||||
# - llm_selected_doc_ids is None if selection failed
|
||||
# - llm_selected_doc_ids is empty list if no docs were selected
|
||||
# - llm_selected_doc_ids is list of IDs if docs were selected
|
||||
if run_llm_selection:
|
||||
yield LLMSelectedDocsPacket(
|
||||
llm_selected_doc_ids=None if llm_selection_failed else llm_selected_doc_ids
|
||||
)
|
||||
|
||||
|
||||
def gather_search_stream(
|
||||
packets: Generator[
|
||||
SearchQueriesPacket
|
||||
| SearchDocsPacket
|
||||
| LLMSelectedDocsPacket
|
||||
| SearchErrorPacket,
|
||||
None,
|
||||
None,
|
||||
],
|
||||
) -> SearchFullResponse:
|
||||
"""
|
||||
Aggregate all streaming packets into SearchFullResponse.
|
||||
"""
|
||||
all_executed_queries: list[str] = []
|
||||
search_docs: list[SearchDocWithContent] = []
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
error: str | None = None
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, SearchQueriesPacket):
|
||||
all_executed_queries = packet.all_executed_queries
|
||||
elif isinstance(packet, SearchDocsPacket):
|
||||
search_docs = packet.search_docs
|
||||
elif isinstance(packet, LLMSelectedDocsPacket):
|
||||
llm_selected_doc_ids = packet.llm_selected_doc_ids
|
||||
elif isinstance(packet, SearchErrorPacket):
|
||||
error = packet.error
|
||||
|
||||
return SearchFullResponse(
|
||||
all_executed_queries=all_executed_queries,
|
||||
search_docs=search_docs,
|
||||
doc_selection_reasoning=None,
|
||||
llm_selected_doc_ids=llm_selected_doc_ids,
|
||||
error=error,
|
||||
)
|
||||
0
backend/ee/onyx/secondary_llm_flows/__init__.py
Normal file
0
backend/ee/onyx/secondary_llm_flows/__init__.py
Normal file
92
backend/ee/onyx/secondary_llm_flows/query_expansion.py
Normal file
92
backend/ee/onyx/secondary_llm_flows/query_expansion.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import re
|
||||
|
||||
from ee.onyx.prompts.query_expansion import KEYWORD_EXPANSION_PROMPT
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Pattern to remove common LLM artifacts: brackets, quotes, list markers, etc.
|
||||
CLEANUP_PATTERN = re.compile(r'[\[\]"\'`]')
|
||||
|
||||
|
||||
def _clean_keyword_line(line: str) -> str:
|
||||
"""Clean a keyword line by removing common LLM artifacts.
|
||||
|
||||
Removes brackets, quotes, and other characters that LLMs may accidentally
|
||||
include in their output.
|
||||
"""
|
||||
# Remove common artifacts
|
||||
cleaned = CLEANUP_PATTERN.sub("", line)
|
||||
# Remove leading list markers like "1.", "2.", "-", "*"
|
||||
cleaned = re.sub(r"^\s*(?:\d+[\.\)]\s*|[-*]\s*)", "", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def expand_keywords(
|
||||
user_query: str,
|
||||
llm: LLM,
|
||||
) -> list[str]:
|
||||
"""Expand a user query into multiple keyword-only queries for BM25 search.
|
||||
|
||||
Uses an LLM to generate keyword-based search queries that capture different
|
||||
aspects of the user's search intent. Returns only the expanded queries,
|
||||
not the original query.
|
||||
|
||||
Args:
|
||||
user_query: The original search query from the user
|
||||
llm: Language model to use for keyword expansion
|
||||
|
||||
Returns:
|
||||
List of expanded keyword queries (excluding the original query).
|
||||
Returns empty list if expansion fails or produces no useful expansions.
|
||||
"""
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content=KEYWORD_EXPANSION_PROMPT.format(user_query=user_query))
|
||||
]
|
||||
|
||||
try:
|
||||
response = llm.invoke(
|
||||
prompt=messages,
|
||||
reasoning_effort=ReasoningEffort.OFF,
|
||||
# Limit output - we only expect a few short keyword queries
|
||||
max_tokens=150,
|
||||
)
|
||||
|
||||
content = llm_response_to_string(response).strip()
|
||||
|
||||
if not content:
|
||||
logger.warning("Keyword expansion returned empty response.")
|
||||
return []
|
||||
|
||||
# Parse response - each line is a separate keyword query
|
||||
# Clean each line to remove LLM artifacts and drop empty lines
|
||||
parsed_queries = []
|
||||
for line in content.strip().split("\n"):
|
||||
cleaned = _clean_keyword_line(line)
|
||||
if cleaned:
|
||||
parsed_queries.append(cleaned)
|
||||
|
||||
if not parsed_queries:
|
||||
logger.warning("Keyword expansion parsing returned no queries.")
|
||||
return []
|
||||
|
||||
# Filter out duplicates and queries that match the original
|
||||
expanded_queries: list[str] = []
|
||||
seen_lower: set[str] = {user_query.lower()}
|
||||
for query in parsed_queries:
|
||||
query_lower = query.lower()
|
||||
if query_lower not in seen_lower:
|
||||
seen_lower.add(query_lower)
|
||||
expanded_queries.append(query)
|
||||
|
||||
logger.debug(f"Keyword expansion generated {len(expanded_queries)} queries")
|
||||
return expanded_queries
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Keyword expansion failed: {e}")
|
||||
return []
|
||||
@@ -0,0 +1,50 @@
|
||||
from ee.onyx.prompts.search_flow_classification import CHAT_CLASS
|
||||
from ee.onyx.prompts.search_flow_classification import SEARCH_CHAT_PROMPT
|
||||
from ee.onyx.prompts.search_flow_classification import SEARCH_CLASS
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def classify_is_search_flow(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
) -> bool:
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content=SEARCH_CHAT_PROMPT.format(user_query=query))
|
||||
]
|
||||
response = llm.invoke(
|
||||
prompt=messages,
|
||||
reasoning_effort=ReasoningEffort.OFF,
|
||||
# Nothing can happen in the UI until this call finishes so we need to be aggressive with the timeout
|
||||
timeout_override=2,
|
||||
# Well more than necessary but just to ensure completion and in case it succeeds with classifying but
|
||||
# ends up rambling
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
content = llm_response_to_string(response).strip().lower()
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Search flow classification returned empty response; defaulting to chat flow."
|
||||
)
|
||||
return False
|
||||
|
||||
# Prefer chat if both appear.
|
||||
if CHAT_CLASS in content:
|
||||
return False
|
||||
if SEARCH_CLASS in content:
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
"Search flow classification returned unexpected response; defaulting to chat flow. Response=%r",
|
||||
content,
|
||||
)
|
||||
return False
|
||||
@@ -19,9 +19,9 @@ from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/logo", {"GET"}),
|
||||
("/enterprise-settings/logotype", {"GET"}),
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# Stripe publishable key is safe to expose publicly
|
||||
("/tenants/stripe-publishable-key", {"GET"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
102
backend/ee/onyx/server/middleware/license_enforcement.py
Normal file
102
backend/ee/onyx/server/middleware/license_enforcement.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Middleware to enforce license status application-wide."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.server.tenants.product_gating import is_tenant_gated
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /tenants/billing-* - Manage subscription to resolve gating
|
||||
ALLOWED_PATH_PREFIXES = {
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
}
|
||||
|
||||
|
||||
def _is_path_allowed(path: str) -> bool:
|
||||
"""Check if path is in allowlist (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
|
||||
|
||||
|
||||
def add_license_enforcement_middleware(
|
||||
app: FastAPI, logger: logging.LoggerAdapter
|
||||
) -> None:
|
||||
logger.info("License enforcement middleware registered")
|
||||
|
||||
@app.middleware("http")
|
||||
async def enforce_license(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""Block requests when license is expired/gated."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
if path.startswith("/api"):
|
||||
path = path[4:]
|
||||
|
||||
if _is_path_allowed(path):
|
||||
return await call_next(request)
|
||||
|
||||
is_gated = False
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
is_gated = is_tenant_gated(tenant_id)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check tenant gating status: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
else:
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata:
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
is_gated = True
|
||||
else:
|
||||
# No license metadata = gated for self-hosted EE
|
||||
is_gated = True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "license_expired",
|
||||
"message": "Your subscription has expired. Please update your billing.",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
@@ -1,214 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import OptionalSearchSetting
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
@router.post("/send-message-simple-api")
|
||||
def handle_simplified_chat_message(
|
||||
chat_message_req: BasicCreateChatMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatBasicResponse:
|
||||
"""This is a Non-Streaming version that only gives back a minimal set of information"""
|
||||
logger.notice(f"Received new simple api chat message: {chat_message_req.message}")
|
||||
|
||||
if not chat_message_req.message:
|
||||
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
|
||||
|
||||
# Handle chat session creation if chat_session_id is not provided
|
||||
if chat_message_req.chat_session_id is None:
|
||||
if chat_message_req.persona_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either chat_session_id or persona_id must be provided",
|
||||
)
|
||||
|
||||
# Create a new chat session with the provided persona_id
|
||||
try:
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # Leave empty for simple API
|
||||
user_id=user.id if user else None,
|
||||
persona_id=chat_message_req.persona_id,
|
||||
)
|
||||
chat_session_id = new_chat_session.id
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
|
||||
else:
|
||||
chat_session_id = chat_message_req.chat_session_id
|
||||
|
||||
try:
|
||||
parent_message = create_chat_history_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)[-1]
|
||||
except Exception:
|
||||
parent_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if (
|
||||
chat_message_req.retrieval_options is None
|
||||
and chat_message_req.search_doc_ids is None
|
||||
):
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
)
|
||||
else:
|
||||
retrieval_options = chat_message_req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message.id,
|
||||
message=chat_message_req.message,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=chat_message_req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=chat_message_req.query_override,
|
||||
# Currently only applies to search flow not chat
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
|
||||
|
||||
@router.post("/send-message-simple-with-history")
|
||||
def handle_send_message_simple_with_history(
|
||||
req: BasicCreateChatMessageWithHistoryRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatBasicResponse:
|
||||
"""This is a Non-Streaming version that only gives back a minimal set of information.
|
||||
takes in chat history maintained by the caller
|
||||
and does query rephrasing similar to answer-with-quote"""
|
||||
|
||||
if len(req.messages) == 0:
|
||||
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
|
||||
|
||||
# This is a sanity check to make sure the chat history is valid
|
||||
# It must start with a user message and alternate beteen user and assistant
|
||||
expected_role = MessageType.USER
|
||||
for msg in req.messages:
|
||||
if not msg.message:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="One or more chat messages were empty"
|
||||
)
|
||||
|
||||
if msg.role != expected_role:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Message roles must start and end with MessageType.USER and alternate in-between.",
|
||||
)
|
||||
if expected_role == MessageType.USER:
|
||||
expected_role = MessageType.ASSISTANT
|
||||
else:
|
||||
expected_role = MessageType.USER
|
||||
|
||||
query = req.messages[-1].message
|
||||
msg_history = req.messages[:-1]
|
||||
|
||||
logger.notice(f"Received new simple with history chat message: {query}")
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="handle_send_message_simple_with_history",
|
||||
user_id=user_id,
|
||||
persona_id=req.persona_id,
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
chat_message = root_message
|
||||
for msg in msg_history:
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=chat_message,
|
||||
message=msg.message,
|
||||
token_count=len(llm_tokenizer.encode(msg.message)),
|
||||
message_type=msg.role,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
if req.retrieval_options is None and req.search_doc_ids is None:
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
)
|
||||
else:
|
||||
retrieval_options = req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=chat_message.id,
|
||||
message=query,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
@@ -1,18 +1,12 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import BasicChunkRequest
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
@@ -25,119 +19,89 @@ class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentSearchRequest(BasicChunkRequest):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
class SearchFlowClassificationRequest(BaseModel):
|
||||
user_query: str
|
||||
|
||||
|
||||
class DocumentSearchResponse(BaseModel):
|
||||
top_documents: list[InferenceChunk]
|
||||
class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
class BasicCreateChatMessageRequest(ChunkContext):
|
||||
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
|
||||
Note, for simplicity this option only allows for a single linear chain of messages
|
||||
"""
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 50
|
||||
|
||||
chat_session_id: UUID | None = None
|
||||
# Optional persona_id to create a new chat session if chat_session_id is not provided
|
||||
persona_id: int | None = None
|
||||
# New message contents
|
||||
message: str
|
||||
# Defaults to using retrieval with no additional filters
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
# Allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
raise ValueError("Either chat_session_id or persona_id must be provided")
|
||||
return self
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
messages: list[ThreadMessage]
|
||||
persona_id: int
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
query_override: str | None = None
|
||||
skip_rerank: bool | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
class SearchDocWithContent(SearchDoc):
|
||||
# Allows None because this is determined by a flag but the object used in code
|
||||
# of the search path uses this type
|
||||
content: str | None
|
||||
|
||||
@classmethod
|
||||
def from_inference_sections(
|
||||
cls,
|
||||
sections: Sequence[InferenceSection],
|
||||
include_content: bool = False,
|
||||
is_internet: bool = False,
|
||||
) -> list["SearchDocWithContent"]:
|
||||
"""Convert InferenceSections to SearchDocWithContent objects.
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
blurb: str
|
||||
match_highlights: list[str]
|
||||
source_type: DocumentSource
|
||||
metadata: dict | None
|
||||
Args:
|
||||
sections: Sequence of InferenceSection objects
|
||||
include_content: If True, populate content field with combined_content
|
||||
is_internet: Whether these are internet search results
|
||||
|
||||
|
||||
class AgentSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(BaseModel):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(BaseModel):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"],
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
Returns:
|
||||
List of SearchDocWithContent with optional content
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
return [
|
||||
cls(
|
||||
document_id=(chunk := section.center_chunk).document_id,
|
||||
chunk_ind=chunk.chunk_id,
|
||||
semantic_identifier=chunk.semantic_identifier or "Unknown",
|
||||
link=chunk.source_links[0] if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
boost=chunk.boost,
|
||||
hidden=chunk.hidden,
|
||||
metadata=chunk.metadata,
|
||||
score=chunk.score,
|
||||
match_highlights=chunk.match_highlights,
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
is_internet=is_internet,
|
||||
content=section.combined_content if include_content else None,
|
||||
)
|
||||
for section in sections
|
||||
]
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
class SearchFullResponse(BaseModel):
|
||||
all_executed_queries: list[str]
|
||||
search_docs: list[SearchDocWithContent]
|
||||
# Reasoning tokens output by the LLM for the document selection
|
||||
doc_selection_reasoning: str | None = None
|
||||
# This a list of document ids that are in the search_docs list
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
# Error message if the search failed partway through
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class SearchQueryResponse(BaseModel):
|
||||
query: str
|
||||
query_expansions: list[str] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class SearchHistoryResponse(BaseModel):
|
||||
search_queries: list[SearchQueryResponse]
|
||||
|
||||
170
backend/ee/onyx/server/query_and_chat/search_backend.py
Normal file
170
backend/ee/onyx/server/query_and_chat/search_backend.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.search import fetch_search_queries_for_user
|
||||
from ee.onyx.search.process_search_query import gather_search_stream
|
||||
from ee.onyx.search.process_search_query import stream_search_query
|
||||
from ee.onyx.secondary_llm_flows.search_flow_classification import (
|
||||
classify_is_search_flow,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationRequest
|
||||
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/search")
|
||||
|
||||
|
||||
@router.post("/search-flow-classification")
|
||||
def search_flow_classification(
|
||||
request: SearchFlowClassificationRequest,
|
||||
# This is added just to ensure this endpoint isn't spammed by non-authorized users since there's an LLM call underneath it
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchFlowClassificationResponse:
|
||||
query = request.user_query
|
||||
# This is a heuristic that if the user is typing a lot of text, it's unlikely they're looking for some specific document
|
||||
# Most likely something needs to be done with the text included so we'll just classify it as a chat flow
|
||||
if len(query) > 200:
|
||||
return SearchFlowClassificationResponse(is_search_flow=False)
|
||||
|
||||
llm = get_default_llm()
|
||||
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=get_current_tenant_id(),
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
|
||||
try:
|
||||
is_search_flow = classify_is_search_flow(query=query, llm=llm)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Search flow classification failed; defaulting to chat flow",
|
||||
exc_info=e,
|
||||
)
|
||||
is_search_flow = False
|
||||
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
Execute a search query with optional streaming.
|
||||
|
||||
When stream=True: Returns StreamingResponse with SSE
|
||||
When stream=False: Returns SearchFullResponse
|
||||
"""
|
||||
logger.debug(f"Received search query: {request.search_query}")
|
||||
|
||||
# Non-streaming path
|
||||
if not request.stream:
|
||||
try:
|
||||
packets = stream_search_query(request, user, db_session)
|
||||
return gather_search_stream(packets)
|
||||
except NotImplementedError as e:
|
||||
return SearchFullResponse(
|
||||
all_executed_queries=[],
|
||||
search_docs=[],
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Streaming path
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as streaming_db_session:
|
||||
for packet in stream_search_query(request, user, streaming_db_session):
|
||||
yield get_json_line(packet.model_dump())
|
||||
except NotImplementedError as e:
|
||||
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in search streaming")
|
||||
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/search-history")
|
||||
def get_search_history(
|
||||
limit: int = 100,
|
||||
filter_days: int | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchHistoryResponse:
|
||||
"""
|
||||
Fetch past search queries for the authenticated user.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of queries to return (default 100)
|
||||
filter_days: Only return queries from the last N days (optional)
|
||||
|
||||
Returns:
|
||||
SearchHistoryResponse with list of search queries, ordered by most recent first.
|
||||
"""
|
||||
# Validate limit
|
||||
if limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be greater than 0",
|
||||
)
|
||||
if limit > 1000:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be at most 1000",
|
||||
)
|
||||
|
||||
# Validate filter_days
|
||||
if filter_days is not None and filter_days <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="filter_days must be greater than 0",
|
||||
)
|
||||
|
||||
# TODO(yuhong) remove this
|
||||
if user is None:
|
||||
# Return empty list for unauthenticated users
|
||||
return SearchHistoryResponse(search_queries=[])
|
||||
|
||||
search_queries = fetch_search_queries_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
filter_days=filter_days,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return SearchHistoryResponse(
|
||||
search_queries=[
|
||||
SearchQueryResponse(
|
||||
query=sq.query,
|
||||
query_expansions=sq.query_expansions,
|
||||
created_at=sq.created_at,
|
||||
)
|
||||
for sq in search_queries
|
||||
]
|
||||
)
|
||||
35
backend/ee/onyx/server/query_and_chat/streaming_models.py
Normal file
35
backend/ee/onyx/server/query_and_chat/streaming_models.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
|
||||
|
||||
|
||||
class SearchQueriesPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["search_queries"] = "search_queries"
|
||||
all_executed_queries: list[str]
|
||||
|
||||
|
||||
class SearchDocsPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["search_docs"] = "search_docs"
|
||||
search_docs: list[SearchDocWithContent]
|
||||
|
||||
|
||||
class SearchErrorPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["search_error"] = "search_error"
|
||||
error: str
|
||||
|
||||
|
||||
class LLMSelectedDocsPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["llm_selected_docs"] = "llm_selected_docs"
|
||||
# None if LLM selection failed, empty list if no docs selected, list of IDs otherwise
|
||||
llm_selected_doc_ids: list[str] | None
|
||||
@@ -32,6 +32,7 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
@@ -48,7 +49,6 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
0
backend/ee/onyx/server/settings/__init__.py
Normal file
0
backend/ee/onyx/server/settings/__init__.py
Normal file
54
backend/ee/onyx/server/settings/api.py
Normal file
54
backend/ee/onyx/server/settings/api.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Statuses that indicate a billing/license problem - propagate these to settings
|
||||
_GATED_STATUSES = frozenset(
|
||||
{
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
ApplicationStatus.PAYMENT_REMINDER,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license is missing or indicates a problem (expired, grace period, etc.).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
|
||||
If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
|
||||
allowing the product to function normally without license checks.
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return settings
|
||||
|
||||
if MULTI_TENANT:
|
||||
return settings
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status in _GATED_STATUSES:
|
||||
settings.application_status = metadata.status
|
||||
elif not metadata:
|
||||
# No license = gated access for self-hosted EE
|
||||
settings.application_status = ApplicationStatus.GATED_ACCESS
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
return settings
|
||||
@@ -1,10 +1,14 @@
|
||||
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.server.usage_limits import NO_LIMIT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -12,9 +16,12 @@ logger = setup_logger()
|
||||
|
||||
# In-memory storage for tenant overrides (populated at startup)
|
||||
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
|
||||
_last_fetch_time: float = 0.0
|
||||
_FETCH_INTERVAL = 60 * 60 * 24 # 24 hours
|
||||
_ERROR_FETCH_INTERVAL = 30 * 60 # 30 minutes (if the last fetch failed)
|
||||
|
||||
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None:
|
||||
"""
|
||||
Fetch tenant-specific usage limit overrides from the control plane.
|
||||
|
||||
@@ -45,33 +52,52 @@ def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
|
||||
)
|
||||
|
||||
return result
|
||||
return (
|
||||
result or None
|
||||
) # if empty dictionary, something went wrong and we shouldn't enforce limits
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
|
||||
return {}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing usage limit overrides: {e}")
|
||||
return {}
|
||||
return None
|
||||
|
||||
|
||||
def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
|
||||
def load_usage_limit_overrides() -> None:
|
||||
"""
|
||||
Load tenant usage limit overrides from the control plane.
|
||||
|
||||
Called at server startup to populate the in-memory cache.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
global _last_fetch_time
|
||||
|
||||
logger.info("Loading tenant usage limit overrides from control plane...")
|
||||
overrides = fetch_usage_limit_overrides()
|
||||
_tenant_usage_limit_overrides = overrides
|
||||
|
||||
_last_fetch_time = time.time()
|
||||
|
||||
# use the new result if it exists, otherwise use the old result
|
||||
# (prevents us from updating to a failed fetch result)
|
||||
_tenant_usage_limit_overrides = overrides or _tenant_usage_limit_overrides
|
||||
|
||||
if overrides:
|
||||
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
|
||||
else:
|
||||
logger.info("No tenant-specific usage limit overrides found")
|
||||
return overrides
|
||||
|
||||
|
||||
def unlimited(tenant_id: str) -> TenantUsageLimitOverrides:
|
||||
return TenantUsageLimitOverrides(
|
||||
tenant_id=tenant_id,
|
||||
llm_cost_cents_trial=NO_LIMIT,
|
||||
llm_cost_cents_paid=NO_LIMIT,
|
||||
chunks_indexed_trial=NO_LIMIT,
|
||||
chunks_indexed_paid=NO_LIMIT,
|
||||
api_calls_trial=NO_LIMIT,
|
||||
api_calls_paid=NO_LIMIT,
|
||||
non_streaming_calls_trial=NO_LIMIT,
|
||||
non_streaming_calls_paid=NO_LIMIT,
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_usage_limit_overrides(
|
||||
@@ -86,7 +112,22 @@ def get_tenant_usage_limit_overrides(
|
||||
Returns:
|
||||
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
|
||||
"""
|
||||
|
||||
if DEV_MODE: # in dev mode, we return unlimited limits for all tenants
|
||||
return unlimited(tenant_id)
|
||||
|
||||
global _tenant_usage_limit_overrides
|
||||
if _tenant_usage_limit_overrides is None:
|
||||
_tenant_usage_limit_overrides = load_usage_limit_overrides()
|
||||
time_since = time.time() - _last_fetch_time
|
||||
if (
|
||||
_tenant_usage_limit_overrides is None and time_since > _ERROR_FETCH_INTERVAL
|
||||
) or (time_since > _FETCH_INTERVAL):
|
||||
logger.debug(
|
||||
f"Last fetch time: {_last_fetch_time}, time since last fetch: {time_since}"
|
||||
)
|
||||
|
||||
load_usage_limit_overrides()
|
||||
|
||||
# If we have failed to fetch from the control plane or we're in dev mode, don't usage limit anyone.
|
||||
if _tenant_usage_limit_overrides is None or DEV_MODE:
|
||||
return unlimited(tenant_id)
|
||||
return _tenant_usage_limit_overrides.get(tenant_id)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -70,24 +76,46 @@ def fetch_billing_information(
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
|
||||
"""
|
||||
Fetch a Stripe customer portal session URL from the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
|
||||
payload = {"tenant_id": tenant_id}
|
||||
if return_url:
|
||||
payload["return_url"] = return_url
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["url"]
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -1,33 +1,41 @@
|
||||
import stripe
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import StripePublishableKeyResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
# Cache for Stripe publishable key to avoid hitting S3 on every request
|
||||
_stripe_publishable_key_cache: str | None = None
|
||||
_stripe_key_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
@@ -82,21 +90,17 @@ async def billing_information(
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Stripe customer portal session via the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
return_url = f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"url": portal_url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -104,15 +108,82 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
"""
|
||||
Fetch the Stripe publishable key.
|
||||
Priority: env var override (for testing) > S3 bucket (production).
|
||||
This endpoint is public (no auth required) since publishable keys are safe to expose.
|
||||
The key is cached in memory to avoid hitting S3 on every request.
|
||||
"""
|
||||
global _stripe_publishable_key_cache
|
||||
|
||||
# Fast path: return cached value without lock
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Use lock to prevent concurrent S3 requests
|
||||
async with _stripe_key_lock:
|
||||
# Double-check after acquiring lock (another request may have populated cache)
|
||||
if _stripe_publishable_key_cache:
|
||||
return StripePublishableKeyResponse(
|
||||
publishable_key=_stripe_publishable_key_cache
|
||||
)
|
||||
|
||||
# Check for env var override first (for local testing with pk_test_* keys)
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(STRIPE_PUBLISHABLE_KEY_URL)
|
||||
response.raise_for_status()
|
||||
key = response.text.strip()
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
@@ -98,3 +105,7 @@ class PendingUserSnapshot(BaseModel):
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class StripePublishableKeyResponse(BaseModel):
|
||||
publishable_key: str
|
||||
|
||||
@@ -65,3 +65,9 @@ def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
|
||||
|
||||
def is_tenant_gated(tenant_id: str) -> bool:
|
||||
"""Fast O(1) check if tenant is in gated set (multi-tenant only)."""
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))
|
||||
|
||||
@@ -9,6 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.token_limit import fetch_all_user_token_rate_limits
|
||||
@@ -16,7 +17,6 @@ from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
"""EE Usage limits - trial detection via billing information."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
@@ -31,13 +28,7 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
|
||||
return True
|
||||
|
||||
if isinstance(billing_info, BillingInformation):
|
||||
# Check if trial is active
|
||||
if billing_info.trial_end is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Trial active if trial_end is in the future
|
||||
# and subscription status indicates trialing
|
||||
if billing_info.trial_end > now and billing_info.status == "trialing":
|
||||
return True
|
||||
return billing_info.status == "trialing"
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -18,10 +18,10 @@ from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -105,6 +105,8 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
@@ -1456,6 +1457,9 @@ def get_default_admin_user_emails_() -> list[str]:
|
||||
|
||||
|
||||
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
|
||||
STATE_TOKEN_LIFETIME_SECONDS = 3600
|
||||
CSRF_TOKEN_KEY = "csrftoken"
|
||||
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
|
||||
|
||||
|
||||
class OAuth2AuthorizeResponse(BaseModel):
|
||||
@@ -1463,13 +1467,19 @@ class OAuth2AuthorizeResponse(BaseModel):
|
||||
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
|
||||
data: Dict[str, str],
|
||||
secret: SecretType,
|
||||
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
|
||||
return generate_jwt(data, secret, lifetime_seconds)
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
|
||||
def create_onyx_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
@@ -1498,6 +1508,13 @@ def get_oauth_router(
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
*,
|
||||
csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
|
||||
csrf_token_cookie_path: str = "/",
|
||||
csrf_token_cookie_domain: Optional[str] = None,
|
||||
csrf_token_cookie_secure: Optional[bool] = None,
|
||||
csrf_token_cookie_httponly: bool = True,
|
||||
csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
|
||||
) -> APIRouter:
|
||||
"""Generate a router with the OAuth routes."""
|
||||
router = APIRouter()
|
||||
@@ -1514,6 +1531,9 @@ def get_oauth_router(
|
||||
route_name=callback_route_name,
|
||||
)
|
||||
|
||||
if csrf_token_cookie_secure is None:
|
||||
csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")
|
||||
|
||||
@router.get(
|
||||
"/authorize",
|
||||
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
|
||||
@@ -1521,8 +1541,10 @@ def get_oauth_router(
|
||||
)
|
||||
async def authorize(
|
||||
request: Request,
|
||||
response: Response,
|
||||
redirect: bool = Query(False),
|
||||
scopes: List[str] = Query(None),
|
||||
) -> OAuth2AuthorizeResponse:
|
||||
) -> Response | OAuth2AuthorizeResponse:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
|
||||
if redirect_url is not None:
|
||||
@@ -1532,9 +1554,11 @@ def get_oauth_router(
|
||||
|
||||
next_url = request.query_params.get("next", "/")
|
||||
|
||||
csrf_token = generate_csrf_token()
|
||||
state_data: Dict[str, str] = {
|
||||
"next_url": next_url,
|
||||
"referral_source": referral_source or "default_referral",
|
||||
CSRF_TOKEN_KEY: csrf_token,
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
|
||||
@@ -1551,6 +1575,31 @@ def get_oauth_router(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
if redirect:
|
||||
redirect_response = RedirectResponse(authorization_url, status_code=302)
|
||||
redirect_response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
return redirect_response
|
||||
|
||||
response.set_cookie(
|
||||
key=csrf_token_cookie_name,
|
||||
value=csrf_token,
|
||||
max_age=STATE_TOKEN_LIFETIME_SECONDS,
|
||||
path=csrf_token_cookie_path,
|
||||
domain=csrf_token_cookie_domain,
|
||||
secure=csrf_token_cookie_secure,
|
||||
httponly=csrf_token_cookie_httponly,
|
||||
samesite=csrf_token_cookie_samesite,
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
@@ -1600,7 +1649,33 @@ def get_oauth_router(
|
||||
try:
|
||||
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
|
||||
except jwt.DecodeError:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode, "ACCESS_TOKEN_DECODE_ERROR", "ACCESS_TOKEN_DECODE_ERROR"
|
||||
),
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(
|
||||
ErrorCode,
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
"ACCESS_TOKEN_ALREADY_EXPIRED",
|
||||
),
|
||||
)
|
||||
|
||||
cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
|
||||
state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
|
||||
if (
|
||||
not cookie_csrf_token
|
||||
or not state_csrf_token
|
||||
or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
|
||||
)
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
|
||||
@@ -26,10 +26,13 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.opensearch.client import (
|
||||
wait_for_opensearch_with_timeout,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
@@ -516,14 +519,17 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if ENABLE_OPENSEARCH_FOR_ONYX:
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
if not wait_for_opensearch_with_timeout():
|
||||
msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
|
||||
# File for validating worker liveness
|
||||
class LivenessProbe(bootsteps.StartStopStep):
|
||||
|
||||
@@ -124,6 +124,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -174,7 +174,7 @@ if AUTO_LLM_CONFIG_URL:
|
||||
"schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
@@ -1436,7 +1436,7 @@ def _docprocessing_task(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
document_indices = get_all_document_indices(
|
||||
index_attempt.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
@@ -1473,7 +1473,7 @@ def _docprocessing_task(
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True, # Documents are already filtered during extraction
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@@ -5,6 +5,9 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -26,24 +29,9 @@ def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
# Fetch config from GitHub
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
|
||||
if not config:
|
||||
task_logger.warning("Failed to fetch GitHub config")
|
||||
return None
|
||||
|
||||
# Sync to database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
results = sync_llm_models_from_github(db_session, config)
|
||||
results = sync_llm_models_from_github(db_session)
|
||||
|
||||
if results:
|
||||
task_logger.info(f"Auto mode sync results: {results}")
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.relationships import delete_document_references_from_kg
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -97,13 +97,17 @@ def document_by_cc_pair_cleanup_task(
|
||||
action = "skip"
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
# This flow is for updates and deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
active_search_settings.primary,
|
||||
active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
@@ -113,11 +117,12 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
|
||||
|
||||
_ = retry_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
_ = retry_document_index.delete_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_document_references_from_kg(
|
||||
db_session=db_session,
|
||||
@@ -155,14 +160,18 @@ def document_by_cc_pair_cleanup_task(
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
# TODO(andrei): Previously there was a comment here saying
|
||||
# it was ok if a doc did not exist in the document index. I
|
||||
# don't agree with that claim, so keep an eye on this task
|
||||
# to see if this raises.
|
||||
retry_document_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -244,7 +244,8 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
search_settings=current_search_settings,
|
||||
)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
# This flow is for indexing so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
current_search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
@@ -258,7 +259,7 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
# real work happens here!
|
||||
index_pipeline_result = run_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
document_index=document_index,
|
||||
document_indices=document_indices,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
@@ -412,12 +413,16 @@ def process_single_user_file_delete(
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
# This flow is for deletion so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
index_name = active_search_settings.primary.index_name
|
||||
selection = f"{index_name}.document_id=='{user_file_id}'"
|
||||
|
||||
@@ -438,11 +443,12 @@ def process_single_user_file_delete(
|
||||
else:
|
||||
chunk_count = user_file.chunk_count
|
||||
|
||||
retry_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.delete_single(
|
||||
doc_id=user_file_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
@@ -564,12 +570,16 @@ def process_single_user_file_project_sync(
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
@@ -579,13 +589,14 @@ def process_single_user_file_project_sync(
|
||||
return None
|
||||
|
||||
project_ids = [project.id for project in user_file.projects]
|
||||
retry_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
retry_document_index.update_single(
|
||||
doc_id=str(user_file.id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=user_file.chunk_count,
|
||||
fields=None,
|
||||
user_fields=VespaDocumentUserFields(user_projects=project_ids),
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
|
||||
@@ -49,7 +49,7 @@ from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.factory import get_all_document_indices
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
@@ -70,6 +70,8 @@ logger = setup_logger()
|
||||
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
# TODO(andrei): Rename all these kinds of functions from *vespa* to a more
|
||||
# generic *document_index*.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
ignore_result=True,
|
||||
@@ -465,13 +467,17 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
# This flow is for updates so we get all indices.
|
||||
document_indices = get_all_document_indices(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
retry_document_indices: list[RetryDocumentIndex] = [
|
||||
RetryDocumentIndex(document_index)
|
||||
for document_index in document_indices
|
||||
]
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
@@ -500,14 +506,18 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
# aggregated_boost_factor=doc.aggregated_boost_factor,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
for retry_document_index in retry_document_indices:
|
||||
# TODO(andrei): Previously there was a comment here saying
|
||||
# it was ok if a doc did not exist in the document index. I
|
||||
# don't agree with that claim, so keep an eye on this task
|
||||
# to see if this raises.
|
||||
retry_document_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
user_fields=None,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
|
||||
57
backend/onyx/chat/chat_processing_checker.py
Normal file
57
backend/onyx/chat/chat_processing_checker.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
) -> None:
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
@@ -7,6 +7,7 @@ from typing import Any
|
||||
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
@@ -15,6 +16,11 @@ from onyx.tools.models import ToolCallInfo
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Type alias for search doc deduplication key
|
||||
# Simple key: just document_id (str)
|
||||
# Full key: (document_id, chunk_ind, match_highlights)
|
||||
SearchDocKey = str | tuple[str, int, tuple[str, ...]]
|
||||
|
||||
|
||||
class ChatStateContainer:
|
||||
"""Container for accumulating state during LLM loop execution.
|
||||
@@ -40,6 +46,10 @@ class ChatStateContainer:
|
||||
# True if this turn is a clarification question (deep research flow)
|
||||
self.is_clarification: bool = False
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
# Search doc collection - maps dedup key to SearchDoc for all docs from tool calls
|
||||
self._all_search_docs: dict[SearchDocKey, SearchDoc] = {}
|
||||
# Track which citation numbers were actually emitted during streaming
|
||||
self._emitted_citations: set[int] = set()
|
||||
|
||||
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
|
||||
"""Add a tool call to the accumulated state."""
|
||||
@@ -91,9 +101,58 @@ class ChatStateContainer:
|
||||
with self._lock:
|
||||
return self.is_clarification
|
||||
|
||||
@staticmethod
|
||||
def create_search_doc_key(
|
||||
search_doc: SearchDoc, use_simple_key: bool = True
|
||||
) -> SearchDocKey:
|
||||
"""Create a unique key for a SearchDoc for deduplication.
|
||||
|
||||
Args:
|
||||
search_doc: The SearchDoc to create a key for
|
||||
use_simple_key: If True (default), use only document_id for deduplication.
|
||||
If False, include chunk_ind and match_highlights so that the same
|
||||
document/chunk with different highlights are stored separately.
|
||||
"""
|
||||
if use_simple_key:
|
||||
return search_doc.document_id
|
||||
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
|
||||
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
|
||||
|
||||
def add_search_docs(
|
||||
self, search_docs: list[SearchDoc], use_simple_key: bool = True
|
||||
) -> None:
|
||||
"""Add search docs to the accumulated collection with deduplication.
|
||||
|
||||
Args:
|
||||
search_docs: List of SearchDoc objects to add
|
||||
use_simple_key: If True (default), deduplicate by document_id only.
|
||||
If False, deduplicate by document_id + chunk_ind + match_highlights.
|
||||
"""
|
||||
with self._lock:
|
||||
for doc in search_docs:
|
||||
key = self.create_search_doc_key(doc, use_simple_key)
|
||||
if key not in self._all_search_docs:
|
||||
self._all_search_docs[key] = doc
|
||||
|
||||
def get_all_search_docs(self) -> dict[SearchDocKey, SearchDoc]:
|
||||
"""Thread-safe getter for all accumulated search docs (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._all_search_docs.copy()
|
||||
|
||||
def add_emitted_citation(self, citation_num: int) -> None:
|
||||
"""Add a citation number that was actually emitted during streaming."""
|
||||
with self._lock:
|
||||
self._emitted_citations.add(citation_num)
|
||||
|
||||
def get_emitted_citations(self) -> set[int]:
|
||||
"""Thread-safe getter for emitted citations (returns a copy)."""
|
||||
with self._lock:
|
||||
return self._emitted_citations.copy()
|
||||
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
@@ -196,3 +255,12 @@ def run_chat_loop_with_state_containers(
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -18,12 +18,10 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
@@ -48,13 +46,10 @@ from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
)
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
|
||||
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
@@ -103,89 +98,6 @@ def create_chat_session_from_request(
|
||||
)
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
onyxbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_history_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -247,31 +159,6 @@ def create_chat_history_chain(
|
||||
return mainline_messages
|
||||
|
||||
|
||||
def combine_message_chain(
|
||||
messages: list[ChatMessage],
|
||||
token_limit: int,
|
||||
msg_limit: int | None = None,
|
||||
) -> str:
|
||||
"""Used for secondary LLM flows that require the chat history,"""
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
if msg_limit is not None:
|
||||
messages = messages[-msg_limit:]
|
||||
|
||||
for message in cast(list[ChatMessage], reversed(messages)):
|
||||
message_token_count = message.token_count
|
||||
|
||||
if total_token_count + message_token_count > token_limit:
|
||||
break
|
||||
|
||||
role = message.message_type.value.upper()
|
||||
message_strs.insert(0, f"{role}:\n{message.message}")
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def reorganize_citations(
|
||||
answer: str, citations: list[CitationInfo]
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
@@ -412,7 +299,7 @@ def create_temporary_persona(
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
@@ -582,6 +469,71 @@ def load_all_chat_files(
|
||||
return files
|
||||
|
||||
|
||||
def convert_chat_history_basic(
|
||||
chat_history: list[ChatMessage],
|
||||
token_counter: Callable[[str], int],
|
||||
max_individual_message_tokens: int | None = None,
|
||||
max_total_tokens: int | None = None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Convert ChatMessage history to ChatMessageSimple format with no tool calls or files included.
|
||||
|
||||
Args:
|
||||
chat_history: List of ChatMessage objects to convert
|
||||
token_counter: Function to count tokens in a message string
|
||||
max_individual_message_tokens: If set, messages exceeding this number of tokens are dropped.
|
||||
If None, no messages are dropped based on individual token count.
|
||||
max_total_tokens: If set, maximum number of tokens allowed for the entire history.
|
||||
If None, the history is not trimmed based on total token count.
|
||||
|
||||
Returns:
|
||||
List of ChatMessageSimple objects
|
||||
"""
|
||||
# Defensive: treat a non-positive total budget as "no history".
|
||||
if max_total_tokens is not None and max_total_tokens <= 0:
|
||||
return []
|
||||
|
||||
# Convert only the core USER/ASSISTANT messages; omit files and tool calls.
|
||||
converted: list[ChatMessageSimple] = []
|
||||
for chat_message in chat_history:
|
||||
if chat_message.message_type not in (MessageType.USER, MessageType.ASSISTANT):
|
||||
continue
|
||||
|
||||
message = chat_message.message or ""
|
||||
token_count = getattr(chat_message, "token_count", None)
|
||||
if token_count is None:
|
||||
token_count = token_counter(message)
|
||||
|
||||
# Drop any single message that would dominate the context window.
|
||||
if (
|
||||
max_individual_message_tokens is not None
|
||||
and token_count > max_individual_message_tokens
|
||||
):
|
||||
continue
|
||||
|
||||
converted.append(
|
||||
ChatMessageSimple(
|
||||
message=message,
|
||||
token_count=token_count,
|
||||
message_type=chat_message.message_type,
|
||||
image_files=None,
|
||||
)
|
||||
)
|
||||
|
||||
if max_total_tokens is None:
|
||||
return converted
|
||||
|
||||
# Enforce a max total budget by keeping a contiguous suffix of the conversation.
|
||||
trimmed_reversed: list[ChatMessageSimple] = []
|
||||
total_tokens = 0
|
||||
for msg in reversed(converted):
|
||||
if total_tokens + msg.token_count > max_total_tokens:
|
||||
break
|
||||
trimmed_reversed.append(msg)
|
||||
total_tokens += msg.token_count
|
||||
|
||||
return list(reversed(trimmed_reversed))
|
||||
|
||||
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
|
||||
@@ -4,14 +4,15 @@ Dynamic Citation Processor for LLM Responses
|
||||
This module provides a citation processor that can:
|
||||
- Accept citation number to SearchDoc mappings dynamically
|
||||
- Process token streams from LLMs to extract citations
|
||||
- Optionally replace citation markers with formatted markdown links
|
||||
- Emit CitationInfo objects for detected citations (when replacing)
|
||||
- Track all seen citations regardless of replacement mode
|
||||
- Handle citations in three modes: REMOVE, KEEP_MARKERS, or HYPERLINK
|
||||
- Emit CitationInfo objects for detected citations (in HYPERLINK mode)
|
||||
- Track all seen citations regardless of mode
|
||||
- Maintain a list of cited documents in order of first citation
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import TypeAlias
|
||||
|
||||
from onyx.configs.chat_configs import STOP_STREAM_PAT
|
||||
@@ -23,6 +24,29 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CitationMode(Enum):
|
||||
"""Defines how citations should be handled in the output.
|
||||
|
||||
REMOVE: Citations are completely removed from output text.
|
||||
No CitationInfo objects are emitted.
|
||||
Use case: When you need to remove citations from the output if they are not shared with the user
|
||||
(e.g. in discord bot, public slack bot).
|
||||
|
||||
KEEP_MARKERS: Original citation markers like [1], [2] are preserved unchanged.
|
||||
No CitationInfo objects are emitted.
|
||||
Use case: When you need to track citations in research agent and later process
|
||||
them with collapse_citations() to renumber.
|
||||
|
||||
HYPERLINK: Citations are replaced with markdown links like [[1]](url).
|
||||
CitationInfo objects are emitted for UI tracking.
|
||||
Use case: Final reports shown to users with clickable links.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
KEEP_MARKERS = "keep_markers"
|
||||
HYPERLINK = "hyperlink"
|
||||
|
||||
|
||||
CitationMapping: TypeAlias = dict[int, SearchDoc]
|
||||
|
||||
|
||||
@@ -48,29 +72,37 @@ class DynamicCitationProcessor:
|
||||
|
||||
This processor is designed for multi-turn conversations where the citation
|
||||
number to document mapping is provided externally. It processes streaming
|
||||
tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and based
|
||||
on the `replace_citation_tokens` setting:
|
||||
tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and handles
|
||||
them according to the configured CitationMode:
|
||||
|
||||
When replace_citation_tokens=True (default):
|
||||
CitationMode.HYPERLINK (default):
|
||||
1. Replaces citation markers with formatted markdown links (e.g., [[1]](url))
|
||||
2. Emits CitationInfo objects for tracking
|
||||
3. Maintains the order in which documents were first cited
|
||||
Use case: Final reports shown to users with clickable links.
|
||||
|
||||
When replace_citation_tokens=False:
|
||||
1. Preserves original citation markers in the output text
|
||||
CitationMode.KEEP_MARKERS:
|
||||
1. Preserves original citation markers like [1], [2] unchanged
|
||||
2. Does NOT emit CitationInfo objects
|
||||
3. Still tracks all seen citations via get_seen_citations()
|
||||
Use case: When citations need later processing (e.g., renumbering).
|
||||
|
||||
CitationMode.REMOVE:
|
||||
1. Removes citation markers entirely from the output text
|
||||
2. Does NOT emit CitationInfo objects
|
||||
3. Still tracks all seen citations via get_seen_citations()
|
||||
Use case: Research agent intermediate reports.
|
||||
|
||||
Features:
|
||||
- Accepts citation number → SearchDoc mapping via update_citation_mapping()
|
||||
- Configurable citation replacement behavior at initialization
|
||||
- Always tracks seen citations regardless of replacement mode
|
||||
- Configurable citation mode at initialization
|
||||
- Always tracks seen citations regardless of mode
|
||||
- Holds back tokens that might be partial citations
|
||||
- Maintains list of cited SearchDocs in order of first citation
|
||||
- Handles unicode bracket variants (【】, [])
|
||||
- Skips citation processing inside code blocks
|
||||
|
||||
Example (with citation replacement - default):
|
||||
Example (HYPERLINK mode - default):
|
||||
processor = DynamicCitationProcessor()
|
||||
|
||||
# Set up citation mapping
|
||||
@@ -87,8 +119,8 @@ class DynamicCitationProcessor:
|
||||
# Get cited documents at the end
|
||||
cited_docs = processor.get_cited_documents()
|
||||
|
||||
Example (without citation replacement):
|
||||
processor = DynamicCitationProcessor(replace_citation_tokens=False)
|
||||
Example (KEEP_MARKERS mode):
|
||||
processor = DynamicCitationProcessor(citation_mode=CitationMode.KEEP_MARKERS)
|
||||
processor.update_citation_mapping({1: search_doc1, 2: search_doc2})
|
||||
|
||||
# Process tokens from LLM
|
||||
@@ -99,26 +131,42 @@ class DynamicCitationProcessor:
|
||||
|
||||
# Get all seen citations after processing
|
||||
seen_citations = processor.get_seen_citations() # {1: search_doc1, ...}
|
||||
|
||||
Example (REMOVE mode):
|
||||
processor = DynamicCitationProcessor(citation_mode=CitationMode.REMOVE)
|
||||
processor.update_citation_mapping({1: search_doc1, 2: search_doc2})
|
||||
|
||||
# Process tokens - citations are removed but tracked
|
||||
for token in llm_stream:
|
||||
for result in processor.process_token(token):
|
||||
print(result) # Text without any citation markers
|
||||
|
||||
# Citations are still tracked
|
||||
seen_citations = processor.get_seen_citations()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
replace_citation_tokens: bool = True,
|
||||
citation_mode: CitationMode = CitationMode.HYPERLINK,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
"""
|
||||
Initialize the citation processor.
|
||||
|
||||
Args:
|
||||
replace_citation_tokens: If True (default), citations like [1] are replaced
|
||||
with formatted markdown links like [[1]](url) and CitationInfo objects
|
||||
are emitted. If False, original citation text is preserved in output
|
||||
and no CitationInfo objects are emitted. Regardless of this setting,
|
||||
all seen citations are tracked and available via get_seen_citations().
|
||||
citation_mode: How to handle citations in the output. One of:
|
||||
- CitationMode.HYPERLINK (default): Replace [1] with [[1]](url)
|
||||
and emit CitationInfo objects.
|
||||
- CitationMode.KEEP_MARKERS: Keep original [1] markers unchanged,
|
||||
no CitationInfo objects emitted.
|
||||
- CitationMode.REMOVE: Remove citations entirely from output,
|
||||
no CitationInfo objects emitted.
|
||||
All modes track seen citations via get_seen_citations().
|
||||
stop_stream: Optional stop token pattern to halt processing early.
|
||||
When this pattern is detected in the token stream, processing stops.
|
||||
Defaults to STOP_STREAM_PAT from chat configs.
|
||||
"""
|
||||
|
||||
# Citation mapping from citation number to SearchDoc
|
||||
self.citation_to_doc: CitationMapping = {}
|
||||
self.seen_citations: CitationMapping = {} # citation num -> SearchDoc
|
||||
@@ -128,7 +176,7 @@ class DynamicCitationProcessor:
|
||||
self.curr_segment = "" # tokens held for citation processing
|
||||
self.hold = "" # tokens held for stop token processing
|
||||
self.stop_stream = stop_stream
|
||||
self.replace_citation_tokens = replace_citation_tokens
|
||||
self.citation_mode = citation_mode
|
||||
|
||||
# Citation tracking
|
||||
self.cited_documents_in_order: list[SearchDoc] = (
|
||||
@@ -199,19 +247,21 @@ class DynamicCitationProcessor:
|
||||
5. Handles stop tokens
|
||||
6. Always tracks seen citations in self.seen_citations
|
||||
|
||||
Behavior depends on the `replace_citation_tokens` setting from __init__:
|
||||
- If True: Citations are replaced with [[n]](url) format and CitationInfo
|
||||
Behavior depends on the `citation_mode` setting from __init__:
|
||||
- HYPERLINK: Citations are replaced with [[n]](url) format and CitationInfo
|
||||
objects are yielded before each formatted citation
|
||||
- If False: Original citation text (e.g., [1]) is preserved in output
|
||||
and no CitationInfo objects are yielded
|
||||
- KEEP_MARKERS: Original citation markers like [1] are preserved unchanged,
|
||||
no CitationInfo objects are yielded
|
||||
- REMOVE: Citations are removed entirely from output,
|
||||
no CitationInfo objects are yielded
|
||||
|
||||
Args:
|
||||
token: The next token from the LLM stream, or None to signal end of stream.
|
||||
Pass None to flush any remaining buffered text at end of stream.
|
||||
|
||||
Yields:
|
||||
str: Text chunks to display. Citation format depends on replace_citation_tokens.
|
||||
CitationInfo: Citation metadata (only when replace_citation_tokens=True)
|
||||
str: Text chunks to display. Citation format depends on citation_mode.
|
||||
CitationInfo: Citation metadata (only when citation_mode=HYPERLINK)
|
||||
"""
|
||||
# None -> end of stream, flush remaining segment
|
||||
if token is None:
|
||||
@@ -299,17 +349,17 @@ class DynamicCitationProcessor:
|
||||
if self.non_citation_count > 5:
|
||||
self.recent_cited_documents.clear()
|
||||
|
||||
# Yield text before citation FIRST (preserve order)
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
|
||||
# Process the citation (returns formatted citation text and CitationInfo objects)
|
||||
# Always tracks seen citations regardless of strip_citations flag
|
||||
# Always tracks seen citations regardless of citation_mode
|
||||
citation_text, citation_info_list = self._process_citation(
|
||||
match, has_leading_space, self.replace_citation_tokens
|
||||
match, has_leading_space
|
||||
)
|
||||
|
||||
if self.replace_citation_tokens:
|
||||
if self.citation_mode == CitationMode.HYPERLINK:
|
||||
# HYPERLINK mode: Replace citations with markdown links [[n]](url)
|
||||
# Yield text before citation FIRST (preserve order)
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
# Yield CitationInfo objects BEFORE the citation text
|
||||
# This allows the frontend to receive citation metadata before the token
|
||||
# that contains [[n]](link), enabling immediate rendering
|
||||
@@ -318,10 +368,34 @@ class DynamicCitationProcessor:
|
||||
# Then yield the formatted citation text
|
||||
if citation_text:
|
||||
yield citation_text
|
||||
else:
|
||||
# When not stripping, yield the original citation text unchanged
|
||||
|
||||
elif self.citation_mode == CitationMode.KEEP_MARKERS:
|
||||
# KEEP_MARKERS mode: Preserve original citation markers unchanged
|
||||
# Yield text before citation
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
# Yield the original citation marker as-is
|
||||
yield match.group()
|
||||
|
||||
else: # CitationMode.REMOVE
|
||||
# REMOVE mode: Remove citations entirely from output
|
||||
# This strips citation markers like [1], [2], 【1】 from the output text
|
||||
# When removing citations, we need to handle spacing to avoid issues like:
|
||||
# - "text [1] more" -> "text more" (double space)
|
||||
# - "text [1]." -> "text ." (space before punctuation)
|
||||
if intermatch_str:
|
||||
remaining_text = self.curr_segment[match_span[1] :]
|
||||
# Strip trailing space from intermatch if:
|
||||
# 1. Remaining text starts with space (avoids double space)
|
||||
# 2. Remaining text starts with punctuation (avoids space before punctuation)
|
||||
if intermatch_str[-1].isspace() and remaining_text:
|
||||
first_char = remaining_text[0]
|
||||
# Check if next char is space or common punctuation
|
||||
if first_char.isspace() or first_char in ".,;:!?)]}":
|
||||
intermatch_str = intermatch_str.rstrip()
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
|
||||
self.non_citation_count = 0
|
||||
|
||||
# Leftover text could be part of next citation
|
||||
@@ -338,7 +412,7 @@ class DynamicCitationProcessor:
|
||||
yield result
|
||||
|
||||
def _process_citation(
|
||||
self, match: re.Match, has_leading_space: bool, replace_tokens: bool = True
|
||||
self, match: re.Match, has_leading_space: bool
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
"""
|
||||
Process a single citation match and return formatted citation text and citation info objects.
|
||||
@@ -349,31 +423,28 @@ class DynamicCitationProcessor:
|
||||
This method always:
|
||||
1. Extracts citation numbers from the match
|
||||
2. Looks up the corresponding SearchDoc from the mapping
|
||||
3. Tracks seen citations in self.seen_citations (regardless of replace_tokens)
|
||||
3. Tracks seen citations in self.seen_citations (regardless of citation_mode)
|
||||
|
||||
When replace_tokens=True (controlled by self.replace_citation_tokens):
|
||||
When citation_mode is HYPERLINK:
|
||||
4. Creates formatted citation text as [[n]](url)
|
||||
5. Creates CitationInfo objects for new citations
|
||||
6. Handles deduplication of recently cited documents
|
||||
|
||||
When replace_tokens=False:
|
||||
4. Returns empty string and empty list (caller yields original match text)
|
||||
When citation_mode is REMOVE or KEEP_MARKERS:
|
||||
4. Returns empty string and empty list (caller handles output based on mode)
|
||||
|
||||
Args:
|
||||
match: Regex match object containing the citation pattern
|
||||
has_leading_space: Whether the text immediately before this citation
|
||||
ends with whitespace. Used to determine if a leading space should
|
||||
be added to the formatted output.
|
||||
replace_tokens: If True, return formatted text and CitationInfo objects.
|
||||
If False, only track seen citations and return empty results.
|
||||
This is passed from self.replace_citation_tokens by the caller.
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_citation_text, citation_info_list):
|
||||
- formatted_citation_text: Markdown-formatted citation text like
|
||||
"[[1]](https://example.com)" or empty string if replace_tokens=False
|
||||
"[[1]](https://example.com)" or empty string if not in HYPERLINK mode
|
||||
- citation_info_list: List of CitationInfo objects for newly cited
|
||||
documents, or empty list if replace_tokens=False
|
||||
documents, or empty list if not in HYPERLINK mode
|
||||
"""
|
||||
citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】'
|
||||
formatted = (
|
||||
@@ -411,11 +482,11 @@ class DynamicCitationProcessor:
|
||||
doc_id = search_doc.document_id
|
||||
link = search_doc.link or ""
|
||||
|
||||
# Always track seen citations regardless of replace_tokens setting
|
||||
# Always track seen citations regardless of citation_mode setting
|
||||
self.seen_citations[num] = search_doc
|
||||
|
||||
# When not replacing citation tokens, skip the rest of the processing
|
||||
if not replace_tokens:
|
||||
# Only generate formatted citations and CitationInfo in HYPERLINK mode
|
||||
if self.citation_mode != CitationMode.HYPERLINK:
|
||||
continue
|
||||
|
||||
# Format the citation text as [[n]](link)
|
||||
@@ -450,14 +521,14 @@ class DynamicCitationProcessor:
|
||||
"""
|
||||
Get the list of cited SearchDoc objects in the order they were first cited.
|
||||
|
||||
Note: This list is only populated when `replace_citation_tokens=True`.
|
||||
When `replace_citation_tokens=False`, this will return an empty list.
|
||||
Note: This list is only populated when `citation_mode=HYPERLINK`.
|
||||
When using REMOVE or KEEP_MARKERS mode, this will return an empty list.
|
||||
Use get_seen_citations() instead if you need to track citations without
|
||||
replacing them.
|
||||
emitting CitationInfo objects.
|
||||
|
||||
Returns:
|
||||
List of SearchDoc objects in the order they were first cited.
|
||||
Empty list if replace_citation_tokens=False.
|
||||
Empty list if citation_mode is not HYPERLINK.
|
||||
"""
|
||||
return self.cited_documents_in_order
|
||||
|
||||
@@ -465,14 +536,14 @@ class DynamicCitationProcessor:
|
||||
"""
|
||||
Get the list of cited document IDs in the order they were first cited.
|
||||
|
||||
Note: This list is only populated when `replace_citation_tokens=True`.
|
||||
When `replace_citation_tokens=False`, this will return an empty list.
|
||||
Note: This list is only populated when `citation_mode=HYPERLINK`.
|
||||
When using REMOVE or KEEP_MARKERS mode, this will return an empty list.
|
||||
Use get_seen_citations() instead if you need to track citations without
|
||||
replacing them.
|
||||
emitting CitationInfo objects.
|
||||
|
||||
Returns:
|
||||
List of document IDs (strings) in the order they were first cited.
|
||||
Empty list if replace_citation_tokens=False.
|
||||
Empty list if citation_mode is not HYPERLINK.
|
||||
"""
|
||||
return [doc.document_id for doc in self.cited_documents_in_order]
|
||||
|
||||
@@ -481,12 +552,12 @@ class DynamicCitationProcessor:
|
||||
Get all seen citations as a mapping from citation number to SearchDoc.
|
||||
|
||||
This returns all citations that have been encountered during processing,
|
||||
regardless of the `replace_citation_tokens` setting. Citations are tracked
|
||||
regardless of the `citation_mode` setting. Citations are tracked
|
||||
whenever they are parsed, making this useful for cases where you need to
|
||||
know which citations appeared in the text without replacing them.
|
||||
know which citations appeared in the text without emitting CitationInfo objects.
|
||||
|
||||
This is particularly useful when `replace_citation_tokens=False`, as
|
||||
get_cited_documents() will be empty in that case, but get_seen_citations()
|
||||
This is particularly useful when using REMOVE or KEEP_MARKERS mode, as
|
||||
get_cited_documents() will be empty in those cases, but get_seen_citations()
|
||||
will still contain all the citations that were found.
|
||||
|
||||
Returns:
|
||||
@@ -501,13 +572,13 @@ class DynamicCitationProcessor:
|
||||
"""
|
||||
Get the number of unique documents that have been cited.
|
||||
|
||||
Note: This count is only updated when `replace_citation_tokens=True`.
|
||||
When `replace_citation_tokens=False`, this will always return 0.
|
||||
Note: This count is only updated when `citation_mode=HYPERLINK`.
|
||||
When using REMOVE or KEEP_MARKERS mode, this will always return 0.
|
||||
Use len(get_seen_citations()) instead if you need to count citations
|
||||
without replacing them.
|
||||
without emitting CitationInfo objects.
|
||||
|
||||
Returns:
|
||||
Number of unique documents cited. 0 if replace_citation_tokens=False.
|
||||
Number of unique documents cited. 0 if citation_mode is not HYPERLINK.
|
||||
"""
|
||||
return len(self.cited_document_ids)
|
||||
|
||||
@@ -519,9 +590,9 @@ class DynamicCitationProcessor:
|
||||
CitationInfo objects for the same document when it's cited multiple times
|
||||
in close succession. This method clears that tracker.
|
||||
|
||||
This is primarily useful when `replace_citation_tokens=True` to allow
|
||||
This is primarily useful when `citation_mode=HYPERLINK` to allow
|
||||
previously cited documents to emit CitationInfo objects again. Has no
|
||||
effect when `replace_citation_tokens=False`.
|
||||
effect when using REMOVE or KEEP_MARKERS mode.
|
||||
|
||||
The recent citation tracker is also automatically cleared when more than
|
||||
5 non-citation characters are processed between citations.
|
||||
|
||||
@@ -53,6 +53,50 @@ def update_citation_processor_from_tool_response(
|
||||
citation_processor.update_citation_mapping(citation_to_doc)
|
||||
|
||||
|
||||
def extract_citation_order_from_text(text: str) -> list[int]:
|
||||
"""Extract citation numbers from text in order of first appearance.
|
||||
|
||||
Parses citation patterns like [1], [1, 2], [[1]], 【1】 etc. and returns
|
||||
the citation numbers in the order they first appear in the text.
|
||||
|
||||
Args:
|
||||
text: The text containing citations
|
||||
|
||||
Returns:
|
||||
List of citation numbers in order of first appearance (no duplicates)
|
||||
"""
|
||||
# Same pattern used in collapse_citations and DynamicCitationProcessor
|
||||
# Group 2 captures the number in double bracket format: [[1]], 【【1】】
|
||||
# Group 4 captures the numbers in single bracket format: [1], [1, 2]
|
||||
citation_pattern = re.compile(
|
||||
r"([\[【[]{2}(\d+)[\]】]]{2})|([\[【[]([\d]+(?: *, *\d+)*)[\]】]])"
|
||||
)
|
||||
seen: set[int] = set()
|
||||
order: list[int] = []
|
||||
|
||||
for match in citation_pattern.finditer(text):
|
||||
# Group 2 is for double bracket single number, group 4 is for single bracket
|
||||
if match.group(2):
|
||||
nums_str = match.group(2)
|
||||
elif match.group(4):
|
||||
nums_str = match.group(4)
|
||||
else:
|
||||
continue
|
||||
|
||||
for num_str in nums_str.split(","):
|
||||
num_str = num_str.strip()
|
||||
if num_str:
|
||||
try:
|
||||
num = int(num_str)
|
||||
if num not in seen:
|
||||
seen.add(num)
|
||||
order.append(num)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return order
|
||||
|
||||
|
||||
def collapse_citations(
|
||||
answer_text: str,
|
||||
existing_citation_mapping: CitationMapping,
|
||||
|
||||
@@ -5,9 +5,11 @@ from sqlalchemy.orm import Session
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import CitationMode
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
@@ -37,11 +39,13 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.tools.tool_runner import run_tool_calls
|
||||
from onyx.tracing.framework.create import trace
|
||||
@@ -50,6 +54,78 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
fallback_extraction_attempted: bool,
|
||||
tool_defs: list[dict],
|
||||
turn_index: int,
|
||||
) -> tuple[LlmStepResult, bool]:
|
||||
"""Attempt to extract tool calls from response text as a fallback.
|
||||
|
||||
This is a last resort fallback for low quality LLMs or those that don't have
|
||||
tool calling from the serving layer. Also triggers if there's reasoning but
|
||||
no answer and no tool calls.
|
||||
|
||||
Args:
|
||||
llm_step_result: The result from the LLM step
|
||||
tool_choice: The tool choice option used for this step
|
||||
fallback_extraction_attempted: Whether fallback extraction was already attempted
|
||||
tool_defs: List of tool definitions
|
||||
turn_index: The current turn index for placement
|
||||
|
||||
Returns:
|
||||
Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call)
|
||||
"""
|
||||
if fallback_extraction_attempted:
|
||||
return llm_step_result, False
|
||||
|
||||
no_tool_calls = (
|
||||
not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0
|
||||
)
|
||||
reasoning_but_no_answer_or_tools = (
|
||||
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
|
||||
)
|
||||
should_try_fallback = (
|
||||
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
|
||||
) or reasoning_but_no_answer_or_tools
|
||||
|
||||
if not should_try_fallback:
|
||||
return llm_step_result, False
|
||||
|
||||
# Try to extract from answer first, then fall back to reasoning
|
||||
extracted_tool_calls: list[ToolCallKickoff] = []
|
||||
if llm_step_result.answer:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if not extracted_tool_calls and llm_step_result.reasoning:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.reasoning,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=llm_step_result.reasoning,
|
||||
answer=llm_step_result.answer,
|
||||
tool_calls=extracted_tool_calls,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
return llm_step_result, True
|
||||
|
||||
|
||||
# Hardcoded oppinionated value, might breaks down to something like:
|
||||
# Cycle 1: Calls web_search for something
|
||||
# Cycle 2: Calls open_url for some results
|
||||
@@ -297,6 +373,7 @@ def run_llm_loop(
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
include_citations: bool = True,
|
||||
) -> None:
|
||||
with trace(
|
||||
"run_llm_loop",
|
||||
@@ -314,7 +391,13 @@ def run_llm_loop(
|
||||
initialize_litellm()
|
||||
|
||||
# Initialize citation processor for handling citations dynamically
|
||||
citation_processor = DynamicCitationProcessor()
|
||||
# When include_citations is True, use HYPERLINK mode to format citations as [[1]](url)
|
||||
# When include_citations is False, use REMOVE mode to strip citations from output
|
||||
citation_processor = DynamicCitationProcessor(
|
||||
citation_mode=(
|
||||
CitationMode.HYPERLINK if include_citations else CitationMode.REMOVE
|
||||
)
|
||||
)
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
@@ -344,6 +427,7 @@ def run_llm_loop(
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
@@ -370,12 +454,16 @@ def run_llm_loop(
|
||||
|
||||
# The section below calculates the available tokens for history a bit more accurately
|
||||
# now that project files are loaded in.
|
||||
if persona and persona.replace_base_system_prompt and persona.system_prompt:
|
||||
if persona and persona.replace_base_system_prompt:
|
||||
# Handles the case where user has checked off the "Replace base system prompt" checkbox
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=persona.system_prompt,
|
||||
token_count=token_counter(persona.system_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
system_prompt = (
|
||||
ChatMessageSimple(
|
||||
message=persona.system_prompt,
|
||||
token_count=token_counter(persona.system_prompt),
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
if persona.system_prompt
|
||||
else None
|
||||
)
|
||||
custom_agent_prompt_msg = None
|
||||
else:
|
||||
@@ -462,10 +550,11 @@ def run_llm_loop(
|
||||
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
# It also pre-processes the tool calls in preparation for running them
|
||||
tool_defs = [tool.tool_definition() for tool in final_tools]
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[tool.tool_definition() for tool in final_tools],
|
||||
tool_definitions=tool_defs,
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
@@ -480,6 +569,19 @@ def run_llm_loop(
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
# Fallback extraction for LLMs that don't support tool calling natively or are lower quality
|
||||
# and might incorrectly output tool calls in other channels
|
||||
llm_step_result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=tool_choice,
|
||||
fallback_extraction_attempted=fallback_extraction_attempted,
|
||||
tool_defs=tool_defs,
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
)
|
||||
if attempted:
|
||||
# To prevent the case of excessive looping with bad models, we only allow one fallback attempt
|
||||
fallback_extraction_attempted = True
|
||||
|
||||
# Save citation mapping after each LLM step for incremental state updates
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
@@ -505,7 +607,7 @@ def run_llm_loop(
|
||||
# in-flight citations
|
||||
# It can be cleaned up but not super trivial or worthwhile right now
|
||||
just_ran_web_search = False
|
||||
tool_responses, citation_mapping = run_tool_calls(
|
||||
parallel_tool_call_results = run_tool_calls(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
@@ -515,7 +617,10 @@ def run_llm_loop(
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
|
||||
# Failure case, give something reasonable to the LLM to try again
|
||||
if tool_calls and not tool_responses:
|
||||
@@ -551,8 +656,15 @@ def run_llm_loop(
|
||||
|
||||
# Extract search_docs if this is a search tool response
|
||||
search_docs = None
|
||||
displayed_docs = None
|
||||
if isinstance(tool_response.rich_response, SearchDocsResponse):
|
||||
search_docs = tool_response.rich_response.search_docs
|
||||
displayed_docs = tool_response.rich_response.displayed_docs
|
||||
|
||||
# Add ALL search docs to state container for DB persistence
|
||||
if search_docs:
|
||||
state_container.add_search_docs(search_docs)
|
||||
|
||||
if gathered_documents:
|
||||
gathered_documents.extend(search_docs)
|
||||
else:
|
||||
@@ -570,6 +682,12 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
saved_response = (
|
||||
tool_response.rich_response
|
||||
if isinstance(tool_response.rich_response, str)
|
||||
else tool_response.llm_facing_response
|
||||
)
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
@@ -579,8 +697,8 @@ def run_llm_loop(
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
search_docs=search_docs,
|
||||
tool_call_response=saved_response,
|
||||
search_docs=displayed_docs or search_docs,
|
||||
generated_images=generated_images,
|
||||
)
|
||||
# Add to state container for partial save support
|
||||
@@ -635,7 +753,12 @@ def run_llm_loop(
|
||||
should_cite_documents = True
|
||||
|
||||
if not llm_step_result or not llm_step_result.answer:
|
||||
raise RuntimeError("LLM did not return an answer.")
|
||||
raise RuntimeError(
|
||||
"The LLM did not return an answer. "
|
||||
"Typically this is an issue with LLMs that do not support tool calling natively, "
|
||||
"or the model serving API is not configured correctly. "
|
||||
"This may also happen with models that are lower quality outputting invalid tool calls."
|
||||
)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
@@ -13,6 +14,7 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
@@ -48,6 +50,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -136,12 +139,11 @@ def _format_message_history_for_logging(
|
||||
|
||||
separator = "================================================"
|
||||
|
||||
# Handle string input
|
||||
if isinstance(message_history, str):
|
||||
formatted_lines.append("Message [string]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{message_history}")
|
||||
return "\n".join(formatted_lines)
|
||||
# Handle single ChatCompletionMessage - wrap in list for uniform processing
|
||||
if isinstance(
|
||||
message_history, (SystemMessage, UserMessage, AssistantMessage, ToolMessage)
|
||||
):
|
||||
message_history = [message_history]
|
||||
|
||||
# Handle sequence of messages
|
||||
for i, msg in enumerate(message_history):
|
||||
@@ -211,7 +213,8 @@ def _update_tool_call_with_delta(
|
||||
|
||||
if index not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": None,
|
||||
# Fallback ID in case the provider never sends one via deltas.
|
||||
"id": f"fallback_{uuid.uuid4().hex}",
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
@@ -277,6 +280,144 @@ def _extract_tool_call_kickoffs(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def extract_tool_calls_from_response_text(
|
||||
response_text: str | None,
|
||||
tool_definitions: list[dict],
|
||||
placement: Placement,
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
tool_definitions: List of tool definitions to match against
|
||||
placement: Placement information for the tool calls
|
||||
|
||||
Returns:
|
||||
List of ToolCallKickoff objects for any matched tool calls
|
||||
"""
|
||||
if not response_text or not tool_definitions:
|
||||
return []
|
||||
|
||||
# Build a map of tool names to their definitions
|
||||
tool_name_to_def: dict[str, dict] = {}
|
||||
for tool_def in tool_definitions:
|
||||
if tool_def.get("type") == "function" and "function" in tool_def:
|
||||
func_def = tool_def["function"]
|
||||
tool_name = func_def.get("name")
|
||||
if tool_name:
|
||||
tool_name_to_def[tool_name] = func_def
|
||||
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index = 0
|
||||
|
||||
for json_obj in json_objects:
|
||||
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
|
||||
if matched_tool_call:
|
||||
tool_name, tool_args = matched_tool_call
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
tab_index += 1
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
"""Try to match a JSON object to a tool definition.
|
||||
|
||||
Supports several formats:
|
||||
1. Direct tool call format: {"name": "tool_name", "arguments": {...}}
|
||||
2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}}
|
||||
3. Tool name as key: {"tool_name": {...arguments...}}
|
||||
4. Arguments matching a tool's parameter schema
|
||||
|
||||
Args:
|
||||
json_obj: The JSON object to match
|
||||
tool_name_to_def: Map of tool names to their function definitions
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_name, tool_args) if matched, None otherwise
|
||||
"""
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
if "function" in json_obj and isinstance(json_obj["function"], dict):
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
for tool_name in tool_name_to_def:
|
||||
if tool_name in json_obj:
|
||||
arguments = json_obj[tool_name]
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 4: Check if the JSON object matches a tool's parameter schema
|
||||
for tool_name, func_def in tool_name_to_def.items():
|
||||
params = func_def.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
required = params.get("required", [])
|
||||
|
||||
if not properties:
|
||||
continue
|
||||
|
||||
# Check if all required parameters are present (empty required = all optional)
|
||||
if all(req in json_obj for req in required):
|
||||
# Check if any of the tool's properties are in the JSON object
|
||||
matching_props = [prop for prop in properties if prop in json_obj]
|
||||
if matching_props:
|
||||
# Filter to only include known properties
|
||||
filtered_args = {k: v for k, v in json_obj.items() if k in properties}
|
||||
return (tool_name, filtered_args)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
@@ -292,7 +433,7 @@ def translate_history_to_llm_format(
|
||||
|
||||
for idx, msg in enumerate(history):
|
||||
# if the message is being added to the history
|
||||
if msg.message_type in [
|
||||
if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.ASSISTANT,
|
||||
@@ -581,6 +722,18 @@ def run_llm_step_pkt_generator(
|
||||
}
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
delta = packet.choice.delta
|
||||
|
||||
# Weird behavior from some model providers, just log and ignore for now
|
||||
if (
|
||||
delta.content is None
|
||||
and delta.reasoning_content is None
|
||||
and delta.tool_calls is None
|
||||
):
|
||||
logger.warning(
|
||||
f"LLM packet is empty (no contents, reasoning or tool calls). Skipping: {packet}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not first_action_recorded and _delta_has_action(delta):
|
||||
span_generation.span_data.time_to_first_action_seconds = (
|
||||
time.monotonic() - stream_start_time
|
||||
@@ -707,6 +860,11 @@ def run_llm_step_pkt_generator(
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
@@ -833,6 +991,9 @@ def run_llm_step_pkt_generator(
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
|
||||
# Note: Content (AgentResponseDelta) doesn't need an explicit end packet - OverallStop handles it
|
||||
# Tool calls are handled by tool execution code and emit their own packets (e.g., SectionEnd)
|
||||
@@ -840,14 +1001,14 @@ def run_llm_step_pkt_generator(
|
||||
logger.debug(f"Accumulated reasoning: {accumulated_reasoning}")
|
||||
logger.debug(f"Accumulated answer: {accumulated_answer}")
|
||||
|
||||
if tool_calls:
|
||||
tool_calls_str = "\n".join(
|
||||
f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
|
||||
for tc in tool_calls
|
||||
)
|
||||
logger.debug(f"Tool calls:\n{tool_calls_str}")
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
if tool_calls:
|
||||
tool_calls_str = "\n".join(
|
||||
f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
|
||||
for tc in tool_calls
|
||||
)
|
||||
logger.debug(f"Tool calls:\n{tool_calls_str}")
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
|
||||
return (
|
||||
LlmStepResult(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
@@ -8,10 +7,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
@@ -24,25 +20,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(BaseModel):
|
||||
top_documents: list[SearchDoc]
|
||||
rephrased_query: str | None = None
|
||||
predicted_flow: QueryFlow | None
|
||||
predicted_search: SearchType | None
|
||||
applied_source_filters: list[DocumentSource] | None
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
@@ -70,22 +47,11 @@ class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
@@ -116,12 +82,6 @@ class OnyxAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
@@ -158,7 +118,6 @@ class PersonaOverrideConfig(BaseModel):
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
|
||||
@@ -5,10 +5,13 @@ An overview can be found in the README.md file in this directory.
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
@@ -35,16 +38,19 @@ from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
@@ -62,6 +68,7 @@ from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import OptionalSearchSetting
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -78,18 +85,30 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
def _should_enable_slack_search(
|
||||
persona: Persona,
|
||||
filters: BaseFilters | None,
|
||||
) -> bool:
|
||||
"""Determine if Slack search should be enabled.
|
||||
|
||||
def __init__(self, message: str, tool_name: str | None = None):
|
||||
super().__init__(message)
|
||||
self.tool_name = tool_name
|
||||
Returns True if:
|
||||
- Source type filter exists and includes Slack, OR
|
||||
- Default persona with no source type filter
|
||||
"""
|
||||
source_types = filters.source_type if filters else None
|
||||
return (source_types is not None and DocumentSource.SLACK in source_types) or (
|
||||
persona.id == DEFAULT_PERSONA_ID and source_types is None
|
||||
)
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
@@ -280,6 +299,7 @@ def handle_stream_message_objects(
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
@@ -294,6 +314,8 @@ def handle_stream_message_objects(
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
@@ -339,6 +361,24 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
@@ -380,7 +420,10 @@ def handle_stream_message_objects(
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
# Auto-place after the latest message in the chain
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif new_msg_req.parent_message_id is None:
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
# None = regeneration from root
|
||||
parent_message = root_message
|
||||
# Truncate history since we're starting from root
|
||||
@@ -480,11 +523,15 @@ def handle_stream_message_objects(
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
persona, new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session.id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
@@ -536,10 +583,27 @@ def handle_stream_message_objects(
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Use external state container if provided, otherwise create internal one
|
||||
# External container allows non-streaming callers to access accumulated state
|
||||
state_container = external_state_container or ChatStateContainer()
|
||||
|
||||
def llm_loop_completion_callback(
|
||||
state_container: ChatStateContainer,
|
||||
) -> None:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_container,
|
||||
db_session=db_session,
|
||||
chat_session_id=str(chat_session.id),
|
||||
is_connected=check_is_connected,
|
||||
assistant_message=assistant_response,
|
||||
)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
@@ -555,6 +619,7 @@ def handle_stream_message_objects(
|
||||
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -571,6 +636,7 @@ def handle_stream_message_objects(
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -586,53 +652,9 @@ def handle_stream_message_objects(
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
if not completed_normally:
|
||||
logger.debug(f"Chat session {chat_session.id} stopped by user")
|
||||
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... The generation was stopped by the user here."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -650,15 +672,7 @@ def handle_stream_message_objects(
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="TOOL_CALL_FAILED",
|
||||
is_retryable=True,
|
||||
details={"tool_name": e.tool_name} if e.tool_name else None,
|
||||
)
|
||||
elif llm:
|
||||
if llm:
|
||||
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
|
||||
e, llm
|
||||
)
|
||||
@@ -690,7 +704,56 @@ def handle_stream_message_objects(
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
finally:
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
db_session: Session,
|
||||
chat_session_id: str,
|
||||
assistant_message: ChatMessage,
|
||||
) -> None:
|
||||
# Determine if stopped by user
|
||||
completed_normally = is_connected()
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
logger.debug(f"Chat session {chat_session_id} stopped by user")
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... \n\nGeneration was stopped by the user."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_to_doc=state_container.citation_to_doc,
|
||||
tool_calls=state_container.tool_calls,
|
||||
all_search_docs=state_container.get_all_search_docs(),
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_message,
|
||||
is_clarification=state_container.is_clarification,
|
||||
emitted_citations=state_container.get_emitted_citations(),
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
@@ -739,6 +802,8 @@ def stream_chat_message_objects(
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -28,6 +29,7 @@ from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@@ -178,8 +180,9 @@ def build_system_prompt(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
return system_prompt
|
||||
|
||||
@@ -193,6 +196,7 @@ def build_system_prompt(
|
||||
has_generate_image = any(
|
||||
isinstance(tool, ImageGenerationTool) for tool in tools
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -222,4 +226,7 @@ def build_system_prompt(
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -2,8 +2,9 @@ import json
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import SearchDocKey
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import add_search_docs_to_chat_message
|
||||
from onyx.db.chat import add_search_docs_to_tool_call
|
||||
@@ -19,22 +20,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_search_doc_key(search_doc: SearchDoc) -> tuple[str, int, tuple[str, ...]]:
|
||||
"""
|
||||
Create a unique key for a SearchDoc that accounts for different versions of the same
|
||||
document/chunk with different match_highlights.
|
||||
|
||||
Args:
|
||||
search_doc: The SearchDoc pydantic model to create a key for
|
||||
|
||||
Returns:
|
||||
A tuple of (document_id, chunk_ind, sorted match_highlights) that uniquely identifies
|
||||
this specific version of the document
|
||||
"""
|
||||
match_highlights_tuple = tuple(sorted(search_doc.match_highlights or []))
|
||||
return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple)
|
||||
|
||||
|
||||
def _create_and_link_tool_calls(
|
||||
tool_calls: list[ToolCallInfo],
|
||||
assistant_message: ChatMessage,
|
||||
@@ -154,38 +139,36 @@ def save_chat_turn(
|
||||
message_text: str,
|
||||
reasoning_tokens: str | None,
|
||||
tool_calls: list[ToolCallInfo],
|
||||
citation_docs_info: list[CitationDocInfo],
|
||||
citation_to_doc: dict[int, SearchDoc],
|
||||
all_search_docs: dict[SearchDocKey, SearchDoc],
|
||||
db_session: Session,
|
||||
assistant_message: ChatMessage,
|
||||
is_clarification: bool = False,
|
||||
emitted_citations: set[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save a chat turn by populating the assistant_message and creating related entities.
|
||||
|
||||
This function:
|
||||
1. Updates the ChatMessage with text, reasoning tokens, and token count
|
||||
2. Creates SearchDoc entries from ToolCall search_docs (for tool calls that returned documents)
|
||||
3. Collects all unique SearchDocs from all tool calls and links them to ChatMessage
|
||||
4. Builds citation mapping from citation_docs_info
|
||||
5. Links all unique SearchDocs from tool calls to the ChatMessage
|
||||
2. Creates DB SearchDoc entries from pre-deduplicated all_search_docs
|
||||
3. Builds tool_call -> search_doc mapping for displayed docs
|
||||
4. Builds citation mapping from citation_to_doc
|
||||
5. Links all unique SearchDocs to the ChatMessage
|
||||
6. Creates ToolCall entries and links SearchDocs to them
|
||||
7. Builds the citations mapping for the ChatMessage
|
||||
|
||||
Deduplication Logic:
|
||||
- SearchDocs are deduplicated using (document_id, chunk_ind, match_highlights) as the key
|
||||
- This ensures that the same document/chunk with different match_highlights (from different
|
||||
queries) are stored as separate SearchDoc entries
|
||||
- Each ToolCall and ChatMessage will map to the correct version of the SearchDoc that
|
||||
matches its specific query highlights
|
||||
|
||||
Args:
|
||||
message_text: The message content to save
|
||||
reasoning_tokens: Optional reasoning tokens for the message
|
||||
tool_calls: List of tool call information to create ToolCall entries (may include search_docs)
|
||||
citation_docs_info: List of citation document information for building citations mapping
|
||||
citation_to_doc: Mapping from citation number to SearchDoc for building citations
|
||||
all_search_docs: Pre-deduplicated search docs from ChatStateContainer
|
||||
db_session: Database session for persistence
|
||||
assistant_message: The ChatMessage object to populate (should already exist in DB)
|
||||
is_clarification: Whether this assistant message is a clarification question (deep research flow)
|
||||
emitted_citations: Set of citation numbers that were actually emitted during streaming.
|
||||
If provided, only citations in this set will be saved; others are filtered out.
|
||||
"""
|
||||
# 1. Update ChatMessage with message content, reasoning tokens, and token count
|
||||
assistant_message.message = message_text
|
||||
@@ -200,53 +183,53 @@ def save_chat_turn(
|
||||
else:
|
||||
assistant_message.token_count = 0
|
||||
|
||||
# 2. Create SearchDoc entries from tool_calls
|
||||
# Build mapping from SearchDoc to DB SearchDoc ID
|
||||
# Use (document_id, chunk_ind, match_highlights) as key to avoid duplicates
|
||||
# while ensuring different versions with different highlights are stored separately
|
||||
search_doc_key_to_id: dict[tuple[str, int, tuple[str, ...]], int] = {}
|
||||
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
|
||||
# 2. Create DB SearchDoc entries from pre-deduplicated all_search_docs
|
||||
search_doc_key_to_id: dict[SearchDocKey, int] = {}
|
||||
for key, search_doc_py in all_search_docs.items():
|
||||
db_search_doc = create_db_search_doc(
|
||||
server_search_doc=search_doc_py,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
search_doc_key_to_id[key] = db_search_doc.id
|
||||
|
||||
# Process tool calls and their search docs
|
||||
# 3. Build tool_call -> search_doc mapping (for displayed docs in each tool call)
|
||||
tool_call_to_search_doc_ids: dict[str, list[int]] = {}
|
||||
for tool_call_info in tool_calls:
|
||||
if tool_call_info.search_docs:
|
||||
search_doc_ids_for_tool: list[int] = []
|
||||
for search_doc_py in tool_call_info.search_docs:
|
||||
# Create a unique key for this SearchDoc version
|
||||
search_doc_key = _create_search_doc_key(search_doc_py)
|
||||
|
||||
# Check if we've already created this exact SearchDoc version
|
||||
if search_doc_key in search_doc_key_to_id:
|
||||
search_doc_ids_for_tool.append(search_doc_key_to_id[search_doc_key])
|
||||
key = ChatStateContainer.create_search_doc_key(search_doc_py)
|
||||
if key in search_doc_key_to_id:
|
||||
search_doc_ids_for_tool.append(search_doc_key_to_id[key])
|
||||
else:
|
||||
# Create new DB SearchDoc entry
|
||||
# Displayed doc not in all_search_docs - create it
|
||||
# This can happen if displayed_docs contains docs not in search_docs
|
||||
db_search_doc = create_db_search_doc(
|
||||
server_search_doc=search_doc_py,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
search_doc_key_to_id[search_doc_key] = db_search_doc.id
|
||||
search_doc_key_to_id[key] = db_search_doc.id
|
||||
search_doc_ids_for_tool.append(db_search_doc.id)
|
||||
|
||||
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
|
||||
set(search_doc_ids_for_tool)
|
||||
)
|
||||
|
||||
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage
|
||||
# Use a set to deduplicate by ID (since we've already deduplicated by key above)
|
||||
all_search_doc_ids_set: set[int] = set()
|
||||
for search_doc_ids in tool_call_to_search_doc_ids.values():
|
||||
all_search_doc_ids_set.update(search_doc_ids)
|
||||
# Collect all search doc IDs for ChatMessage linking
|
||||
all_search_doc_ids_set: set[int] = set(search_doc_key_to_id.values())
|
||||
|
||||
# 4. Build citation mapping from citation_docs_info
|
||||
# 4. Build a citation mapping from the citation number to the saved DB SearchDoc ID
|
||||
# Only include citations that were actually emitted during streaming
|
||||
citation_number_to_search_doc_id: dict[int, int] = {}
|
||||
|
||||
for citation_doc_info in citation_docs_info:
|
||||
# Extract SearchDoc pydantic model
|
||||
search_doc_py = citation_doc_info.search_doc
|
||||
for citation_num, search_doc_py in citation_to_doc.items():
|
||||
# Skip citations that weren't actually emitted (if emitted_citations is provided)
|
||||
if emitted_citations is not None and citation_num not in emitted_citations:
|
||||
continue
|
||||
|
||||
# Create the unique key for this SearchDoc version
|
||||
search_doc_key = _create_search_doc_key(search_doc_py)
|
||||
search_doc_key = ChatStateContainer.create_search_doc_key(search_doc_py)
|
||||
|
||||
# Get the search doc ID (should already exist from processing tool_calls)
|
||||
if search_doc_key in search_doc_key_to_id:
|
||||
@@ -283,10 +266,7 @@ def save_chat_turn(
|
||||
all_search_doc_ids_set.add(db_search_doc_id)
|
||||
|
||||
# Build mapping from citation number to search doc ID
|
||||
if citation_doc_info.citation_number is not None:
|
||||
citation_number_to_search_doc_id[citation_doc_info.citation_number] = (
|
||||
db_search_doc_id
|
||||
)
|
||||
citation_number_to_search_doc_id[citation_num] = db_search_doc_id
|
||||
|
||||
# 5. Link all unique SearchDocs (from both tool calls and citations) to ChatMessage
|
||||
final_search_doc_ids: list[int] = list(all_search_doc_ids_set)
|
||||
@@ -306,23 +286,10 @@ def save_chat_turn(
|
||||
tool_call_to_search_doc_ids=tool_call_to_search_doc_ids,
|
||||
)
|
||||
|
||||
# 7. Build citations mapping from citation_docs_info
|
||||
# Any citation_doc_info with a citation_number appeared in the text and should be mapped
|
||||
citations: dict[int, int] = {}
|
||||
for citation_doc_info in citation_docs_info:
|
||||
if citation_doc_info.citation_number is not None:
|
||||
search_doc_id = citation_number_to_search_doc_id.get(
|
||||
citation_doc_info.citation_number
|
||||
)
|
||||
if search_doc_id is not None:
|
||||
citations[citation_doc_info.citation_number] = search_doc_id
|
||||
else:
|
||||
logger.warning(
|
||||
f"Citation number {citation_doc_info.citation_number} found in citation_docs_info "
|
||||
f"but no matching search doc ID in mapping"
|
||||
)
|
||||
|
||||
assistant_message.citations = citations if citations else None
|
||||
# 7. Build citations mapping - use the mapping we already built in step 4
|
||||
assistant_message.citations = (
|
||||
citation_number_to_search_doc_id if citation_number_to_search_doc_id else None
|
||||
)
|
||||
|
||||
# Finally save the messages, tool calls, and docs
|
||||
db_session.commit()
|
||||
|
||||
@@ -22,6 +22,14 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
# Certain services need to make HTTP requests to the API server, such as the MCP server and Discord bot
|
||||
API_SERVER_PROTOCOL = os.environ.get("API_SERVER_PROTOCOL", "http")
|
||||
API_SERVER_HOST = os.environ.get("API_SERVER_HOST", "127.0.0.1")
|
||||
# This override allows self-hosting the MCP server with Onyx Cloud backend.
|
||||
API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS = os.environ.get(
|
||||
"API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS"
|
||||
)
|
||||
|
||||
# Whether to send user metadata (user_id/email and session_id) to the LLM provider.
|
||||
# Disabled by default.
|
||||
SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
@@ -200,8 +208,19 @@ OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 920
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
|
||||
ENABLE_OPENSEARCH_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
# to stress test the new codepaths. Only enable this if there is some instance
|
||||
# of OpenSearch running for the relevant Onyx instance.
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
# Given that the "base" config above is true, this enables whether we want to
|
||||
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
|
||||
# in the event we see issues with OpenSearch retrieval in our dev environments.
|
||||
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
|
||||
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
@@ -568,6 +587,7 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SLIM_PAGE_SIZE = int(os.environ.get("JIRA_SLIM_PAGE_SIZE", 500))
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -729,6 +749,10 @@ JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
|
||||
LOG_ONYX_MODEL_INTERACTIONS = (
|
||||
os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
|
||||
)
|
||||
|
||||
PROMPT_CACHE_CHAT_HISTORY = (
|
||||
os.environ.get("PROMPT_CACHE_CHAT_HISTORY", "").lower() == "true"
|
||||
)
|
||||
# If set to `true` will enable additional logs about Vespa query performance
|
||||
# (time spent on finding the right docs + time spent fetching summaries from disk)
|
||||
LOG_VESPA_TIMING_INFORMATION = (
|
||||
@@ -849,6 +873,7 @@ AZURE_IMAGE_DEPLOYMENT_NAME = os.environ.get(
|
||||
|
||||
# configurable image model
|
||||
IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")
|
||||
IMAGE_MODEL_PROVIDER = os.environ.get("IMAGE_MODEL_PROVIDER", "openai")
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
@@ -995,3 +1020,25 @@ COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
|
||||
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
|
||||
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")
|
||||
|
||||
INSTANCE_TYPE = (
|
||||
"managed"
|
||||
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
|
||||
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
|
||||
)
|
||||
|
||||
|
||||
## Discord Bot Configuration
|
||||
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
|
||||
## Stripe Configuration
|
||||
# URL to fetch the Stripe publishable key from a public S3 bucket.
|
||||
# Publishable keys are safe to expose publicly - they can only initialize
|
||||
# Stripe.js and tokenize payment info, not make charges or access data.
|
||||
STRIPE_PUBLISHABLE_KEY_URL = (
|
||||
"https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt"
|
||||
)
|
||||
# Override for local testing with Stripe test keys (pk_test_*)
|
||||
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
|
||||
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
|
||||
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
|
||||
PERSONAS_YAML = "./onyx/seeding/personas.yaml"
|
||||
NUM_RETURNED_HITS = 50
|
||||
@@ -12,9 +11,6 @@ NUM_POSTPROCESSED_RESULTS = 20
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = int(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 25)
|
||||
|
||||
# Maximum percentage of the context window to fill with selected sections
|
||||
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
|
||||
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
@@ -27,11 +23,6 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently only applies to search flow not chat
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
DISABLE_LLM_QUERY_REPHRASE = (
|
||||
os.environ.get("DISABLE_LLM_QUERY_REPHRASE", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
@@ -46,34 +37,6 @@ TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
|
||||
)
|
||||
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
# TODO these are not used, should probably reintroduce these
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
LANGUAGE_HINT = "\n" + (
|
||||
os.environ.get("LANGUAGE_HINT")
|
||||
or "IMPORTANT: Respond in the same language as my query!"
|
||||
)
|
||||
LANGUAGE_CHAT_NAMING_HINT = (
|
||||
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_DOC_RELEVANCE = (
|
||||
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
@@ -86,9 +49,6 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
# Whether or not to use the semantic & keyword search expansions for Basic Search
|
||||
@@ -96,5 +56,3 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"
|
||||
|
||||
@@ -7,6 +7,7 @@ from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
|
||||
ONYX_UTM_SOURCE = "onyx_app"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -22,6 +23,9 @@ PUBLIC_DOC_PAT = "PUBLIC"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
|
||||
# Tag for endpoints that should be included in the public API documentation
|
||||
PUBLIC_API_TAGS: list[str | Enum] = ["public"]
|
||||
|
||||
# Cookies
|
||||
FASTAPI_USERS_AUTH_COOKIE_NAME = (
|
||||
"fastapiusersauth" # Currently a constant, but logic allows for configuration
|
||||
@@ -89,6 +93,7 @@ SSL_CERT_FILE = "bundle.pem"
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
@@ -235,6 +240,7 @@ class NotificationType(str, Enum):
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -422,6 +428,9 @@ class OnyxRedisLocks:
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
|
||||
@@ -4,8 +4,6 @@ import os
|
||||
# Onyx Slack Bot Configs
|
||||
#####
|
||||
ONYX_BOT_NUM_RETRIES = int(os.environ.get("ONYX_BOT_NUM_RETRIES", "5"))
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
ONYX_BOT_NUM_DOCS_TO_DISPLAY = int(os.environ.get("ONYX_BOT_NUM_DOCS_TO_DISPLAY", "5"))
|
||||
# If the LLM fails to answer, Onyx can still show the "Reference Documents"
|
||||
@@ -47,10 +45,6 @@ ONYX_BOT_MAX_WAIT_TIME = int(os.environ.get("ONYX_BOT_MAX_WAIT_TIME") or 180)
|
||||
# Time (in minutes) after which a Slack message is sent to the user to remind him to give feedback.
|
||||
# Set to 0 to disable it (default)
|
||||
ONYX_BOT_FEEDBACK_REMINDER = int(os.environ.get("ONYX_BOT_FEEDBACK_REMINDER") or 0)
|
||||
# Set to True to rephrase the Slack users messages
|
||||
ONYX_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("ONYX_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# responses OnyxBot can send in a given time period.
|
||||
|
||||
@@ -93,7 +93,7 @@ if __name__ == "__main__":
|
||||
#### Docs Changes
|
||||
|
||||
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
|
||||
connector in Onyx. Then create a Pull Request in https://github.com/onyx-dot-app/onyx-docs.
|
||||
connector in Onyx. Then create a Pull Request in [https://github.com/onyx-dot-app/documentation](https://github.com/onyx-dot-app/documentation).
|
||||
|
||||
### Before opening PR
|
||||
|
||||
|
||||
@@ -901,13 +901,16 @@ class OnyxConfluence:
|
||||
space_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
This is a confluence server specific method that can be used to
|
||||
This is a confluence server/data center specific method that can be used to
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
TODO: Make this call these endpoints for newer confluence versions:
|
||||
- /rest/api/space/{spaceKey}/permissions
|
||||
- /rest/api/space/{spaceKey}/permissions/anonymous
|
||||
|
||||
NOTE: This uses the JSON-RPC API which is the ONLY way to get space permissions
|
||||
on Confluence Server/Data Center. The REST API equivalent (expand=permissions)
|
||||
is Cloud-only and not available on Data Center as of version 8.9.x.
|
||||
|
||||
If this fails with 401 Unauthorized, the customer needs to enable JSON-RPC:
|
||||
Confluence Admin -> General Configuration -> Further Configuration
|
||||
-> Enable "Remote API (XML-RPC & SOAP)"
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
@@ -916,7 +919,18 @@ class OnyxConfluence:
|
||||
"id": 7,
|
||||
"params": [space_key],
|
||||
}
|
||||
response = self.post(url, data=data)
|
||||
try:
|
||||
response = self.post(url, data=data)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 401:
|
||||
raise HTTPError(
|
||||
"Unauthorized (401) when calling JSON-RPC API for space permissions. "
|
||||
"This is likely because the Remote API is disabled. "
|
||||
"To fix: Confluence Admin -> General Configuration -> Further Configuration "
|
||||
"-> Enable 'Remote API (XML-RPC & SOAP)'",
|
||||
response=e.response,
|
||||
) from e
|
||||
raise
|
||||
logger.debug(f"jsonrpc response: {response}")
|
||||
if not response.get("result"):
|
||||
logger.warning(
|
||||
|
||||
@@ -97,10 +97,17 @@ def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
"""Gets string representations of experts supplied.
|
||||
|
||||
If an expert cannot be represented as a string, it is omitted from the
|
||||
result.
|
||||
"""
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
reps: list[str | None] = [
|
||||
basic_expert_info_representation(owner) for owner in experts
|
||||
]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing_extensions import override
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from onyx.configs.app_configs import JIRA_SLIM_PAGE_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
is_atlassian_date_error,
|
||||
@@ -57,7 +58,6 @@ logger = setup_logger()
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
# Constants for Jira field names
|
||||
@@ -683,7 +683,7 @@ class JiraConnector(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
max_results=JIRA_SLIM_PAGE_SIZE,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
nextPageToken=checkpoint.cursor,
|
||||
@@ -703,11 +703,11 @@ class JiraConnector(
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
self.update_checkpoint_for_next_run(
|
||||
checkpoint, current_offset, prev_offset, _JIRA_SLIM_PAGE_SIZE
|
||||
checkpoint, current_offset, prev_offset, JIRA_SLIM_PAGE_SIZE
|
||||
)
|
||||
prev_offset = current_offset
|
||||
|
||||
|
||||
@@ -161,6 +161,8 @@ class DocumentBase(BaseModel):
|
||||
sections: list[TextSection | ImageSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
# UTC time
|
||||
@@ -202,13 +204,7 @@ class DocumentBase(BaseModel):
|
||||
if not self.metadata:
|
||||
return None
|
||||
# Combined string for the key/value for easy filtering
|
||||
attributes: list[str] = []
|
||||
for k, v in self.metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
return convert_metadata_dict_to_list_of_strings(self.metadata)
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
size = sys.getsizeof(self.id)
|
||||
@@ -240,6 +236,66 @@ class DocumentBase(BaseModel):
|
||||
return " ".join([section.text for section in self.sections if section.text])
|
||||
|
||||
|
||||
def convert_metadata_dict_to_list_of_strings(
|
||||
metadata: dict[str, str | list[str]],
|
||||
) -> list[str]:
|
||||
"""Converts a metadata dict to a list of strings.
|
||||
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
Returns:
|
||||
A list of strings where each string is a key-value pair separated by the
|
||||
INDEX_SEPARATOR.
|
||||
"""
|
||||
attributes: list[str] = []
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
|
||||
|
||||
def convert_metadata_list_of_strings_to_dict(
|
||||
metadata_list: list[str],
|
||||
) -> dict[str, str | list[str]]:
|
||||
"""
|
||||
Converts a list of strings to a metadata dict. The inverse of
|
||||
convert_metadata_dict_to_list_of_strings.
|
||||
|
||||
Assumes the input strings are formatted as in the output of
|
||||
convert_metadata_dict_to_list_of_strings.
|
||||
|
||||
The schema of the output metadata dict is suboptimal yet bound to legacy
|
||||
code. Ideally each key would just point to a list of strings, where each
|
||||
list might contain just one element.
|
||||
|
||||
Args:
|
||||
metadata_list: The list of strings to convert to a metadata dict.
|
||||
|
||||
Returns:
|
||||
A metadata dict where values can be either a string or a list of
|
||||
strings.
|
||||
"""
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
for item in metadata_list:
|
||||
key, value = item.split(INDEX_SEPARATOR, 1)
|
||||
if key in metadata:
|
||||
# We have already seen this key therefore it must point to a list.
|
||||
if isinstance(metadata[key], list):
|
||||
cast(list[str], metadata[key]).append(value)
|
||||
else:
|
||||
metadata[key] = [cast(str, metadata[key]), value]
|
||||
else:
|
||||
metadata[key] = value
|
||||
return metadata
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
|
||||
|
||||
@@ -13,13 +13,6 @@ class RecencyBiasSetting(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class OptionalSearchSetting(str, Enum):
|
||||
ALWAYS = "always"
|
||||
NEVER = "never"
|
||||
# Determine whether to run search based on history and latest query
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""
|
||||
The type of first-pass query to use for hybrid search.
|
||||
@@ -36,15 +29,3 @@ class SearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
INTERNET = "internet"
|
||||
|
||||
|
||||
class LLMEvaluationType(str, Enum):
|
||||
AGENTIC = "agentic" # applies agentic evaluation
|
||||
BASIC = "basic" # applies boolean evaluation
|
||||
SKIP = "skip" # skips evaluation
|
||||
UNSPECIFIED = "unspecified" # reverts to default
|
||||
|
||||
|
||||
class QueryFlow(str, Enum):
|
||||
SEARCH = "search"
|
||||
QUESTION_ANSWER = "question-answer"
|
||||
|
||||
@@ -31,7 +31,6 @@ from onyx.context.search.federated.slack_search_utils import is_recency_query
|
||||
from onyx.context.search.federated.slack_search_utils import should_include_message
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -425,7 +424,6 @@ class SlackQueryResult(BaseModel):
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
original_query: SearchQuery,
|
||||
access_token: str,
|
||||
limit: int | None = None,
|
||||
allowed_private_channel: str | None = None,
|
||||
@@ -456,7 +454,7 @@ def query_slack(
|
||||
logger.info(f"Final query to slack: {final_query}")
|
||||
|
||||
# Detect if query asks for most recent results
|
||||
sort_by_time = is_recency_query(original_query.query)
|
||||
sort_by_time = is_recency_query(query_string)
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
@@ -536,8 +534,7 @@ def query_slack(
|
||||
)
|
||||
document_id = f"{channel_id}_{message_id}"
|
||||
|
||||
# compute recency bias (parallels vespa calculation) and metadata
|
||||
decay_factor = DOC_TIME_DECAY * original_query.recency_bias_multiplier
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_time = datetime.fromtimestamp(float(message_id))
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (
|
||||
365 * 24 * 60 * 60
|
||||
@@ -1002,7 +999,6 @@ def slack_retrieval(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
query_limit,
|
||||
allowed_private_channel,
|
||||
@@ -1045,7 +1041,6 @@ def slack_retrieval(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
query_limit,
|
||||
allowed_private_channel,
|
||||
@@ -1225,7 +1220,6 @@ def slack_retrieval(
|
||||
source_type=DocumentSource.SLACK,
|
||||
title=chunk.title_prefix,
|
||||
boost=0,
|
||||
recency_bias=docid_to_message[document_id].recency_bias,
|
||||
score=convert_slack_score(docid_to_message[document_id].slack_score),
|
||||
hidden=False,
|
||||
is_relevant=None,
|
||||
|
||||
@@ -13,7 +13,9 @@ from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.natural_language_processing.english_stopwords import ENGLISH_STOPWORDS_SET
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT
|
||||
@@ -112,7 +114,7 @@ def is_recency_query(query: str) -> bool:
|
||||
if not has_recency_keyword:
|
||||
return False
|
||||
|
||||
# Get combined stop words (NLTK + Slack-specific)
|
||||
# Get combined stop words (English + Slack-specific)
|
||||
all_stop_words = _get_combined_stop_words()
|
||||
|
||||
# Extract content words (excluding stop words)
|
||||
@@ -190,7 +192,7 @@ def extract_date_range_from_query(
|
||||
|
||||
try:
|
||||
prompt = SLACK_DATE_EXTRACTION_PROMPT.format(query=query)
|
||||
response = llm_response_to_string(llm.invoke(prompt))
|
||||
response = llm_response_to_string(llm.invoke(UserMessage(content=prompt)))
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -487,7 +489,7 @@ def build_channel_override_query(channel_references: set[str], time_filter: str)
|
||||
return f"__CHANNEL_OVERRIDE__ {channel_filter}{time_filter}"
|
||||
|
||||
|
||||
# Slack-specific stop words (in addition to standard NLTK stop words)
|
||||
# Slack-specific stop words (in addition to standard English stop words)
|
||||
# These include Slack-specific terms and temporal/recency keywords
|
||||
SLACK_SPECIFIC_STOP_WORDS = frozenset(
|
||||
RECENCY_KEYWORDS
|
||||
@@ -507,27 +509,16 @@ SLACK_SPECIFIC_STOP_WORDS = frozenset(
|
||||
)
|
||||
|
||||
|
||||
def _get_combined_stop_words() -> set[str]:
|
||||
"""Get combined NLTK + Slack-specific stop words.
|
||||
def _get_combined_stop_words() -> frozenset[str]:
|
||||
"""Get combined English + Slack-specific stop words.
|
||||
|
||||
Returns a set of stop words for filtering content words.
|
||||
Falls back to just Slack-specific stop words if NLTK is unavailable.
|
||||
Returns a frozenset of stop words for filtering content words.
|
||||
|
||||
Note: Currently only supports English stop words. Non-English queries
|
||||
may have suboptimal content word extraction. Future enhancement could
|
||||
detect query language and load appropriate stop words.
|
||||
"""
|
||||
try:
|
||||
from nltk.corpus import stopwords # type: ignore
|
||||
|
||||
# TODO: Support multiple languages - currently hardcoded to English
|
||||
# Could detect language or allow configuration
|
||||
nltk_stop_words = set(stopwords.words("english"))
|
||||
except Exception:
|
||||
# Fallback if NLTK not available
|
||||
nltk_stop_words = set()
|
||||
|
||||
return nltk_stop_words | SLACK_SPECIFIC_STOP_WORDS
|
||||
return ENGLISH_STOPWORDS_SET | SLACK_SPECIFIC_STOP_WORDS
|
||||
|
||||
|
||||
def extract_content_words_from_recency_query(
|
||||
@@ -535,7 +526,7 @@ def extract_content_words_from_recency_query(
|
||||
) -> list[str]:
|
||||
"""Extract meaningful content words from a recency query.
|
||||
|
||||
Filters out NLTK stop words, Slack-specific terms, channel references, and proper nouns.
|
||||
Filters out English stop words, Slack-specific terms, channel references, and proper nouns.
|
||||
|
||||
Args:
|
||||
query_text: The user's query text
|
||||
@@ -544,7 +535,7 @@ def extract_content_words_from_recency_query(
|
||||
Returns:
|
||||
List of content words (up to MAX_CONTENT_WORDS)
|
||||
"""
|
||||
# Get combined stop words (NLTK + Slack-specific)
|
||||
# Get combined stop words (English + Slack-specific)
|
||||
all_stop_words = _get_combined_stop_words()
|
||||
|
||||
words = query_text.split()
|
||||
@@ -566,6 +557,23 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -576,8 +584,10 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
Returns:
|
||||
List of rephrased query strings (up to MAX_SLACK_QUERY_EXPANSIONS)
|
||||
"""
|
||||
prompt = SLACK_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
|
||||
prompt = UserMessage(
|
||||
content=SLACK_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -586,10 +596,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
rephrased_queries = [
|
||||
raw_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
@@ -5,27 +5,15 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.models import BaseChunk
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.tools.tool_implementations.web_search.models import WEB_SEARCH_PREFIX
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
)
|
||||
|
||||
|
||||
class QueryExpansions(BaseModel):
|
||||
@@ -38,6 +26,7 @@ class QueryExpansionType(Enum):
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
# TODO clean up this stuff, reranking is no longer used
|
||||
class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
@@ -131,13 +120,6 @@ class IndexFilters(BaseFilters, UserFileFilters):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class ChunkMetric(BaseModel):
|
||||
document_id: str
|
||||
chunk_content_start: str
|
||||
first_link: str | None
|
||||
score: float
|
||||
|
||||
|
||||
class ChunkContext(BaseModel):
|
||||
# If not specified (None), picked up from Persona settings if there is space
|
||||
# if specified (even if 0), it always uses the specified number of chunks above and below
|
||||
@@ -162,10 +144,6 @@ class BasicChunkRequest(BaseModel):
|
||||
# In case some queries favor recency more than other queries.
|
||||
recency_bias_multiplier: float = 1.0
|
||||
|
||||
# Sometimes we may want to extract specific keywords from a more semantic query for
|
||||
# a better keyword search.
|
||||
query_keywords: list[str] | None = None # Not used currently
|
||||
|
||||
limit: int | None = None
|
||||
offset: int | None = None # This one is not set currently
|
||||
|
||||
@@ -184,6 +162,8 @@ class ChunkIndexRequest(BasicChunkRequest):
|
||||
# Calculated final filters
|
||||
filters: IndexFilters
|
||||
|
||||
query_keywords: list[str] | None = None
|
||||
|
||||
|
||||
class ContextExpansionType(str, Enum):
|
||||
NOT_RELEVANT = "not_relevant"
|
||||
@@ -192,94 +172,18 @@ class ContextExpansionType(str, Enum):
|
||||
FULL_DOCUMENT = "full_document"
|
||||
|
||||
|
||||
class SearchRequest(ChunkContext):
|
||||
query: str
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
original_query: str | None = None
|
||||
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
|
||||
human_selected_filters: BaseFilters | None = None
|
||||
user_file_filters: UserFileFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
persona: Persona | None = None
|
||||
|
||||
# if None, no offset / limit
|
||||
offset: int | None = None
|
||||
limit: int | None = None
|
||||
|
||||
multilingual_expansion: list[str] | None = None
|
||||
recency_bias_multiplier: float = 1.0
|
||||
hybrid_alpha: float | None = None
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
precomputed_is_keyword: bool | None = None
|
||||
precomputed_keywords: list[str] | None = None
|
||||
|
||||
|
||||
class SearchQuery(ChunkContext):
|
||||
query: str
|
||||
processed_keywords: list[str]
|
||||
search_type: SearchType
|
||||
evaluation_type: LLMEvaluationType
|
||||
filters: IndexFilters
|
||||
|
||||
# by this point, the chunks_above and chunks_below must be set
|
||||
chunks_above: int
|
||||
chunks_below: int
|
||||
|
||||
rerank_settings: RerankingDetails | None
|
||||
hybrid_alpha: float
|
||||
recency_bias_multiplier: float
|
||||
|
||||
# Only used if LLM evaluation type is not skip, None to use default settings
|
||||
max_llm_filter_sections: int
|
||||
|
||||
num_hits: int = NUM_RETURNED_HITS
|
||||
offset: int = 0
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
original_query: str | None
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
# If the Persona is configured to not run search (0 chunks), this is bypassed
|
||||
# If no Prompt is configured, the only search results are shown, this is bypassed
|
||||
run_search: OptionalSearchSetting = OptionalSearchSetting.AUTO
|
||||
# Is this a real-time/streaming call or a question where Onyx can take more time?
|
||||
# Used to determine reranking flow
|
||||
real_time: bool = True
|
||||
# The following have defaults in the Persona settings which can be overridden via
|
||||
# the query, if None, then use Persona settings
|
||||
filters: BaseFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
# if None, no offset / limit
|
||||
offset: int | None = None
|
||||
limit: int | None = None
|
||||
|
||||
# If this is set, only the highest matching chunk (or merged chunks) is returned
|
||||
dedupe_docs: bool = False
|
||||
|
||||
|
||||
class InferenceChunk(BaseChunk):
|
||||
document_id: str
|
||||
source_type: DocumentSource
|
||||
semantic_identifier: str
|
||||
title: str | None # Separate from Semantic Identifier though often same
|
||||
boost: int
|
||||
recency_bias: float
|
||||
score: float | None
|
||||
hidden: bool
|
||||
is_relevant: bool | None = None
|
||||
relevance_explanation: str | None = None
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
# Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
|
||||
# to specify that a set of words should be highlighted. For example:
|
||||
@@ -466,6 +370,10 @@ class SearchDocsResponse(BaseModel):
|
||||
# document id is the most staightforward way.
|
||||
citation_mapping: dict[int, str]
|
||||
|
||||
# For cases where the frontend only needs to display a subset of the search docs
|
||||
# The whole list is typically still needed for later steps but this set should be saved separately
|
||||
displayed_docs: list[SearchDoc] | None = None
|
||||
|
||||
|
||||
class SavedSearchDoc(SearchDoc):
|
||||
db_doc_id: int
|
||||
@@ -524,25 +432,8 @@ class SavedSearchDoc(SearchDoc):
|
||||
return self_score < other_score
|
||||
|
||||
|
||||
class CitationDocInfo(BaseModel):
|
||||
search_doc: SearchDoc
|
||||
citation_number: int | None
|
||||
|
||||
|
||||
class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
"""Used for endpoints that need to return the actual contents of the retrieved
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class RetrievalMetricsContainer(BaseModel):
|
||||
search_type: SearchType
|
||||
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
|
||||
|
||||
|
||||
class RerankMetricsContainer(BaseModel):
|
||||
"""The score held by this is the un-boosted, averaged score of the ensemble cross-encoders"""
|
||||
|
||||
metrics: list[ChunkMetric]
|
||||
raw_similarity_scores: list[float]
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -278,12 +279,16 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
|
||||
query_request = ChunkIndexRequest(
|
||||
query=chunk_search_request.query,
|
||||
hybrid_alpha=chunk_search_request.hybrid_alpha,
|
||||
recency_bias_multiplier=chunk_search_request.recency_bias_multiplier,
|
||||
query_keywords=chunk_search_request.query_keywords,
|
||||
query_keywords=query_keywords,
|
||||
filters=filters,
|
||||
limit=chunk_search_request.limit,
|
||||
offset=chunk_search_request.offset,
|
||||
)
|
||||
|
||||
retrieved_chunks = search_chunks(
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import BASE_RECENCY_DECAY
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from onyx.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.utils import (
|
||||
remove_stop_words_and_punctuation,
|
||||
)
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.search_nlp_models import QueryAnalysisModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def query_analysis(query: str) -> tuple[bool, list[str]]:
|
||||
analysis_model = QueryAnalysisModel()
|
||||
return analysis_model.predict(query)
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
bypass_acl: bool = False,
|
||||
) -> SearchQuery:
|
||||
"""Logic is as follows:
|
||||
Any global disables apply first
|
||||
Then any filters or settings as part of the query are used
|
||||
Then defaults to Persona settings if not specified by the query
|
||||
"""
|
||||
query = search_request.query
|
||||
limit = search_request.limit
|
||||
offset = search_request.offset
|
||||
persona = search_request.persona
|
||||
|
||||
preset_filters = search_request.human_selected_filters or BaseFilters()
|
||||
if persona and persona.document_sets and preset_filters.document_set is None:
|
||||
preset_filters.document_set = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
|
||||
time_filter = preset_filters.time_cutoff
|
||||
if time_filter is None and persona:
|
||||
time_filter = persona.search_start_date
|
||||
|
||||
source_filter = preset_filters.source_type
|
||||
|
||||
auto_detect_time_filter = True
|
||||
auto_detect_source_filter = True
|
||||
if not search_request.enable_auto_detect_filters:
|
||||
logger.debug("Retrieval details disables auto detect filters")
|
||||
auto_detect_time_filter = False
|
||||
auto_detect_source_filter = False
|
||||
elif persona and persona.llm_filter_extraction is False:
|
||||
logger.debug("Persona disables auto detect filters")
|
||||
auto_detect_time_filter = False
|
||||
auto_detect_source_filter = False
|
||||
else:
|
||||
logger.debug("Auto detect filters enabled")
|
||||
|
||||
if (
|
||||
time_filter is not None
|
||||
and persona
|
||||
and persona.recency_bias != RecencyBiasSetting.AUTO
|
||||
):
|
||||
auto_detect_time_filter = False
|
||||
logger.debug("Not extract time filter - already provided")
|
||||
if source_filter is not None:
|
||||
logger.debug("Not extract source filter - already provided")
|
||||
auto_detect_source_filter = False
|
||||
|
||||
# Based on the query figure out if we should apply any hard time filters /
|
||||
# if we should bias more recent docs even more strongly
|
||||
run_time_filters = (
|
||||
FunctionCall(extract_time_filter, (query, llm), {})
|
||||
if auto_detect_time_filter
|
||||
else None
|
||||
)
|
||||
|
||||
# Based on the query, figure out if we should apply any source filters
|
||||
run_source_filters = (
|
||||
FunctionCall(extract_source_filter, (query, llm, db_session), {})
|
||||
if auto_detect_source_filter
|
||||
else None
|
||||
)
|
||||
|
||||
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
|
||||
# latency, and in that case we don't need to run the model again
|
||||
run_query_analysis = (
|
||||
None
|
||||
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
|
||||
else FunctionCall(query_analysis, (query,), {})
|
||||
)
|
||||
|
||||
functions_to_run = [
|
||||
filter_fn
|
||||
for filter_fn in [
|
||||
run_time_filters,
|
||||
run_source_filters,
|
||||
run_query_analysis,
|
||||
]
|
||||
if filter_fn
|
||||
]
|
||||
parallel_results = run_functions_in_parallel(functions_to_run)
|
||||
|
||||
predicted_time_cutoff, predicted_favor_recent = (
|
||||
parallel_results[run_time_filters.result_id]
|
||||
if run_time_filters
|
||||
else (None, None)
|
||||
)
|
||||
predicted_source_filters = (
|
||||
parallel_results[run_source_filters.result_id] if run_source_filters else None
|
||||
)
|
||||
|
||||
# The extracted keywords right now are not very reliable, not using for now
|
||||
# Can maybe use for highlighting
|
||||
is_keyword, _extracted_keywords = False, None
|
||||
if search_request.precomputed_is_keyword is not None:
|
||||
is_keyword = search_request.precomputed_is_keyword
|
||||
_extracted_keywords = search_request.precomputed_keywords
|
||||
elif run_query_analysis:
|
||||
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
|
||||
|
||||
all_query_terms = query.split()
|
||||
processed_keywords = (
|
||||
remove_stop_words_and_punctuation(all_query_terms)
|
||||
# If the user is using a different language, don't edit the query or remove english stopwords
|
||||
if not search_request.multilingual_expansion
|
||||
else all_query_terms
|
||||
)
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
user_file_filters = search_request.user_file_filters
|
||||
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
|
||||
if persona and persona.user_files:
|
||||
user_file_ids = list(
|
||||
set(user_file_ids) | set([file.id for file in persona.user_files])
|
||||
)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=user_file_filters.project_id if user_file_filters else None,
|
||||
source_type=preset_filters.source_type or predicted_source_filters,
|
||||
document_set=preset_filters.document_set,
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
# kg_entities=preset_filters.kg_entities,
|
||||
# kg_relationships=preset_filters.kg_relationships,
|
||||
# kg_terms=preset_filters.kg_terms,
|
||||
# kg_sources=preset_filters.kg_sources,
|
||||
# kg_chunk_id_zero_only=preset_filters.kg_chunk_id_zero_only,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
if search_request.evaluation_type is not LLMEvaluationType.UNSPECIFIED:
|
||||
llm_evaluation_type = search_request.evaluation_type
|
||||
|
||||
elif persona:
|
||||
llm_evaluation_type = (
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
)
|
||||
|
||||
if DISABLE_LLM_DOC_RELEVANCE:
|
||||
if llm_evaluation_type:
|
||||
logger.info(
|
||||
"LLM chunk filtering would have run but has been globally disabled"
|
||||
)
|
||||
llm_evaluation_type = LLMEvaluationType.SKIP
|
||||
|
||||
rerank_settings = search_request.rerank_settings
|
||||
# If not explicitly specified by the query, use the current settings
|
||||
if rerank_settings is None:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# For non-streaming flows, the rerank settings are applied at the search_request level
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Decays at 1 / (1 + (multiplier * num years))
|
||||
if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
|
||||
recency_bias_multiplier = 0.0
|
||||
elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
|
||||
recency_bias_multiplier = base_recency_decay
|
||||
elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
|
||||
recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
|
||||
else:
|
||||
if predicted_favor_recent:
|
||||
recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
|
||||
else:
|
||||
recency_bias_multiplier = base_recency_decay
|
||||
|
||||
hybrid_alpha = HYBRID_ALPHA_KEYWORD if is_keyword else HYBRID_ALPHA
|
||||
if search_request.hybrid_alpha:
|
||||
hybrid_alpha = search_request.hybrid_alpha
|
||||
|
||||
# Search request overrides anything else as it's explicitly set by the request
|
||||
# If not explicitly specified, use the persona settings if they exist
|
||||
# Otherwise, use the global defaults
|
||||
chunks_above = (
|
||||
search_request.chunks_above
|
||||
if search_request.chunks_above is not None
|
||||
else (persona.chunks_above if persona else CONTEXT_CHUNKS_ABOVE)
|
||||
)
|
||||
chunks_below = (
|
||||
search_request.chunks_below
|
||||
if search_request.chunks_below is not None
|
||||
else (persona.chunks_below if persona else CONTEXT_CHUNKS_BELOW)
|
||||
)
|
||||
|
||||
return SearchQuery(
|
||||
query=query,
|
||||
original_query=search_request.original_query,
|
||||
processed_keywords=processed_keywords,
|
||||
search_type=SearchType.KEYWORD if is_keyword else SearchType.SEMANTIC,
|
||||
evaluation_type=llm_evaluation_type,
|
||||
filters=final_filters,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
recency_bias_multiplier=recency_bias_multiplier,
|
||||
num_hits=limit if limit is not None else NUM_RETURNED_HITS,
|
||||
offset=offset or 0,
|
||||
rerank_settings=rerank_settings,
|
||||
# Should match the LLM filtering to the same as the reranked, it's understood as this is the number of results
|
||||
# the user wants to do heavier processing on, so do the same for the LLM if reranking is on
|
||||
# if no reranking settings are set, then use the global default
|
||||
max_llm_filter_sections=(
|
||||
rerank_settings.num_rerank if rerank_settings else NUM_POSTPROCESSED_RESULTS
|
||||
),
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
full_doc=search_request.full_doc,
|
||||
precomputed_query_embedding=search_request.precomputed_query_embedding,
|
||||
expanded_queries=search_request.expanded_queries,
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user