Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-26 01:52:45 +00:00)

Compare commits: 147 commits, release/v2 ... v2.9.9
SHA1:
c4534739cc f1c30974f5 81bf07fb15 b565bf8291 b4da99cbdd f910feea0f e3af8c6c8a
d6e46ed792 4ce1f4ecdd a4678884d7 c861ba68f1 b1d0e0bb0b 0d78bf52e3 bd743282e6
d44d1d92b3 4cedcfee59 90a721a76e 3ccd99e931 9076bf603f 8c6e0a70c3 bebe9555d4
c530722c9f 68380b4ddb b3380746ab 56be114c87 54f467da5c 8726b112fe 92181d07b2
3a73f7fab2 7dabaca7cd dec4748825 072836cd86 2705b5fb0e 37dcde4226 a765b5f622
5e093368d1 f945ab6b05 11b7a22404 8e34f944cc 32606dc752 1f6c4b40bf 1943f1c745
82460729a6 c445e6a8c0 8d30a03d7f 277428f579 9f8c0d4237 9ccbb6a04b 58a943f782
9021c607f2 c03b0d80fd fcf0b316a4 157f672b4b 51b9484b96 0c8f55c049 c7be2571d1
4948b6cca9 638ea5f316 6e3268ca75 d8921df60c 693d9f5f69 02e17871cc 209cfd00b0
cd36baa484 c78fe275af c935c4808f 4ebcfef541 e320ef9d9c 9e02438af5 177e097ddb
9ecd47ec31 83f3d29b10 12e668cc0f afe8376d5e 082ef3e096 cb2951a1c0 eda5598af5
0bbb4b6988 4768aadb20 e05e85e782 6408f61307 5a5cd51e4f 7c047c47a0 22138bbb33
7cff1064a8 deeb6fdcd2 3e7f4e0aa5 ac73671e35 3c20d132e0 0e3e7eb4a2 c85aebe8ab
a47e6a3146 1e61737e03 c7fc1cd5ae e2b60bf67c f4d4d14286 1c24bc6ea2 cacbd18dcd
8527b83b15 33e37a1846 d454d8a878 00ad65a6a8 dac60d403c 6256b2854d 8acb8e191d
8c4cbddc43 f6cd006bd6 0033934319 ff87b79d14 ebf18af7c9 cf67ae962c 7a9a132739
33bad8c37b 9241ff7a75 0a25bc30ec e359732f4c be47866a4d 8a20540559 e6e1f2860a
fc3f433df7 016caf453b a9de25053f 8ef8dfdeb7 0643b626d9 64a0eb52e0 b82ffc82cf
b3014b9911 439707c395 65351aa8bd b44ee07eaf 065d391c08 14fe3b375f bb1b96dded
9f949ae2d9 975c0e8009 3dfb38c460 a1512a0485 8ea3bacd38 6b560b8162 3b750939ed
bd4cb17a48 485cd9a311 2108c72353 98f43fb6ab e112ebb371 f88cbcfe27 0df0b10d3a
408  .github/workflows/deployment.yml  vendored
@@ -8,7 +8,9 @@ on:

# Set restrictive default permissions for all jobs. Jobs that need more permissions
# should explicitly declare them.
permissions: {}
permissions:
# Required for OIDC authentication with AWS
id-token: write # zizmor: ignore[excessive-permissions]

env:
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
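The switch from `permissions: {}` to an explicit `id-token: write` grant is what lets the jobs below authenticate to AWS via OIDC instead of stored access keys: the runner requests a short-lived GitHub OIDC token and exchanges it for temporary AWS credentials. A minimal sketch of that exchange, assuming boto3; the role ARN and token are placeholders, not values from this workflow:

# Hedged sketch of the OIDC exchange that aws-actions/configure-aws-credentials
# performs; the role ARN and token below are placeholders, not workflow values.
import boto3

def assume_role_with_oidc(role_arn: str, web_identity_token: str) -> dict:
    sts = boto3.client("sts")
    resp = sts.assume_role_with_web_identity(
        RoleArn=role_arn,
        RoleSessionName="github-actions-deploy",  # arbitrary session label
        WebIdentityToken=web_identity_token,
    )
    # Short-lived credentials; nothing long-lived is stored in GitHub secrets.
    return resp["Credentials"]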
@@ -148,19 +150,32 @@ jobs:
needs:
- check-version-tag
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
environment: release
runs-on: ubuntu-slim
timeout-minutes: 10
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true

- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: "• check-version-tag"
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
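Note the pattern this hunk introduces and which repeats throughout the file: credentials are no longer read from GitHub `secrets.*` but from `env.*`, because the preceding "Get AWS Secrets" step (aws-actions/aws-secretsmanager-get-secrets with parse-json-secrets: true) fetches them from AWS Secrets Manager and exposes them to later steps as environment variables. Roughly what that step does, as a hedged sketch; the exact environment-variable naming and prefixing rules belong to the action and are simplified assumptions here:

# Rough sketch of the "Get AWS Secrets" step: fetch a secret and export it to
# the environment for later steps. Naming here is a simplifying assumption.
import json
import os
import boto3

def export_secret(env_name: str, secret_id: str, region: str = "us-east-2") -> None:
    client = boto3.client("secretsmanager", region_name=region)
    value = client.get_secret_value(SecretId=secret_id)["SecretString"]
    try:
        parsed = json.loads(value)  # mirrors parse-json-secrets: true
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        for key, val in parsed.items():
            os.environ[f"{env_name}_{key.upper()}"] = str(val)
    else:
        os.environ[env_name] = value

if __name__ == "__main__":
    export_secret("MONITOR_DEPLOYMENTS_WEBHOOK", "deploy/monitor-deployments-webhook")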
@@ -168,8 +183,8 @@ jobs:
build-desktop:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
environment: release
permissions:
id-token: write
contents: write
actions: read
strategy:

@@ -187,12 +202,33 @@ jobs:
runs-on: ${{ matrix.platform }}
timeout-minutes: 90
environment: release
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]

- name: Configure AWS credentials
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
APPLE_ID, deploy/apple-id
APPLE_PASSWORD, deploy/apple-password
APPLE_CERTIFICATE, deploy/apple-certificate
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
KEYCHAIN_PASSWORD, deploy/keychain-password
APPLE_TEAM_ID, deploy/apple-team-id
parse-json-secrets: true

- name: install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-')
run: |

@@ -287,27 +323,52 @@ jobs:
Write-Host "Versions set to: $VERSION"

- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
- name: Import Apple Developer Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security set-keychain-settings -t 3600 -u build.keychain
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
security find-identity -v -p codesigning build.keychain

- name: Verify Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
echo "Certificate imported."

- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
APPLE_ID: ${{ env.APPLE_ID }}
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
with:
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: "See the assets to download this version and install."
releaseDraft: true
prerelease: false
assetNamePattern: "[name]_[arch][ext]"
args: ${{ matrix.args }}

build-web-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
environment: release
runs-on:
- runs-on
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -320,6 +381,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -334,8 +409,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push AMD64
id: build

@@ -360,13 +435,13 @@ jobs:
build-web-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
environment: release
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -379,6 +454,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -393,8 +482,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push ARM64
id: build
@@ -421,26 +510,40 @@ jobs:
- determine-builds
- build-web-amd64
- build-web-arm64
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Docker meta
id: meta

@@ -475,8 +578,8 @@ jobs:
- runner=4cpu-linux-x64
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
environment: release
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -489,6 +592,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -503,8 +620,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push AMD64
id: build

@@ -537,13 +654,13 @@ jobs:
build-web-cloud-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web-cloud == 'true'
environment: release
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -556,6 +673,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -570,8 +701,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push ARM64
id: build

@@ -606,26 +737,40 @@ jobs:
- determine-builds
- build-web-cloud-amd64
- build-web-cloud-arm64
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Docker meta
id: meta
@@ -652,13 +797,13 @@ jobs:
build-backend-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend == 'true'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -671,6 +816,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -685,8 +844,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push AMD64
id: build

@@ -710,13 +869,13 @@ jobs:
build-backend-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-backend == 'true'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -729,6 +888,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -743,8 +916,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push ARM64
id: build

@@ -770,26 +943,40 @@ jobs:
- determine-builds
- build-backend-amd64
- build-backend-arm64
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Docker meta
id: meta
@@ -819,7 +1006,6 @@ jobs:
build-model-server-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-x64

@@ -827,6 +1013,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -839,6 +1026,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -855,8 +1056,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push AMD64
id: build

@@ -884,7 +1085,6 @@ jobs:
build-model-server-arm64:
needs: determine-builds
if: needs.determine-builds.outputs.build-model-server == 'true'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-arm64

@@ -892,6 +1092,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:

@@ -904,6 +1105,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -920,8 +1135,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Build and push ARM64
id: build

@@ -951,26 +1166,40 @@ jobs:
- determine-builds
- build-model-server-amd64
- build-model-server-arm64
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

- name: Docker meta
id: meta
@@ -1002,18 +1231,32 @@ jobs:
- determine-builds
- merge-web
if: needs.merge-web.result == 'success'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:

@@ -1029,8 +1272,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \

@@ -1043,18 +1286,32 @@ jobs:
- determine-builds
- merge-web-cloud
if: needs.merge-web-cloud.result == 'success'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:

@@ -1070,8 +1327,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \

@@ -1084,13 +1341,13 @@ jobs:
- determine-builds
- merge-backend
if: needs.merge-backend.result == 'success'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:

@@ -1101,6 +1358,20 @@ jobs:
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:

@@ -1117,8 +1388,8 @@ jobs:
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \

@@ -1132,18 +1403,32 @@ jobs:
- determine-builds
- merge-model-server
if: needs.merge-model-server.result == 'success'
environment: release
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true

- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:

@@ -1159,8 +1444,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1185,16 +1470,29 @@ jobs:
- build-model-server-arm64
- merge-model-server
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && needs.determine-builds.outputs.is-test-run != 'true'
environment: release
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false

- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2

- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true

- name: Determine failed jobs
id: failed-jobs
shell: bash

@@ -1260,7 +1558,7 @@ jobs:
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
title: "🚨 Deployment Workflow Failed"
ref-name: ${{ github.ref_name }}
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

@@ -172,7 +172,7 @@ jobs:
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/
7  .github/workflows/pr-integration-tests.yml  vendored
@@ -310,8 +310,9 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF

- name: Start Docker containers

@@ -438,7 +439,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log

@@ -567,7 +568,7 @@ jobs:
- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log
2  .github/workflows/pr-jest-tests.yml  vendored
@@ -44,7 +44,7 @@ jobs:
- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage

@@ -301,7 +301,7 @@ jobs:
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF

- name: Start Docker containers

@@ -424,7 +424,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
4  .github/workflows/pr-playwright-tests.yml  vendored
@@ -435,7 +435,7 @@ jobs:
fi
npx playwright test --project ${PROJECT}

- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
# Includes test results and trace.zip files

@@ -455,7 +455,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log
3  .github/workflows/pr-python-checks.yml  vendored
@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-

- name: Run MyPy
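The mypy cache key now folds in the PR's base branch (falling back to 'main'), so caches built against different target branches no longer collide, while the restore-keys still allow cross-branch reuse as a fallback. A hedged sketch of how such a composite key can be built; the globbing and hash truncation below are simplified assumptions, not the behavior of hashFiles:

# Sketch of a branch-aware cache key: OS + base branch + content hash of the
# Python sources. Globbing and truncation are simplifications for illustration.
import hashlib
from pathlib import Path

def mypy_cache_key(runner_os: str, base_ref: str, root: str = "backend") -> str:
    digest = hashlib.sha256()
    for path in sorted(Path(root).rglob("*.py")):
        digest.update(path.read_bytes())
    return f"mypy-{runner_os}-{base_ref or 'main'}-{digest.hexdigest()[:16]}"

if __name__ == "__main__":
    print(mypy_cache_key("Linux", "release/v2"))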
2  .github/workflows/pr-python-model-tests.yml  vendored
@@ -144,7 +144,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
1  .gitignore  vendored
@@ -21,6 +21,7 @@ backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
backend/onyx/evals/one_off/*.json
*.log
*.csv

# secret files
.env

@@ -11,7 +11,6 @@ repos:
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args:

@@ -42,7 +42,9 @@ RUN apt-get update && \
pkg-config \
gcc \
nano \
vim && \
vim \
libjemalloc2 \
&& \
rm -rf /var/lib/apt/lists/* && \
apt-get clean

@@ -130,6 +132,13 @@ ENV PYTHONPATH=/app
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}

# Use jemalloc instead of glibc malloc to reduce memory fragmentation
# in long-running Python processes (API server, Celery workers).
# The soname is architecture-independent; the dynamic linker resolves
# the correct path from standard library directories.
# Placed after all RUN steps so build-time processes are unaffected.
ENV LD_PRELOAD=libjemalloc.so.2

# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]
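Since LD_PRELOAD is set by soname rather than by absolute path, a quick way to confirm the preload actually resolved inside the running container is to look for the library in the process's memory maps. A minimal check, assuming a Linux environment:

# Minimal runtime check that LD_PRELOAD=libjemalloc.so.2 took effect: the
# shared object should show up in this process's memory mappings on Linux.
def jemalloc_loaded() -> bool:
    with open("/proc/self/maps") as maps:
        return any("libjemalloc" in line for line in maps)

if __name__ == "__main__":
    print("jemalloc preloaded:", jemalloc_loaded())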
@@ -225,7 +225,6 @@ def do_run_migrations(
) -> None:
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))

connection.execute(text(f'SET search_path TO "{schema_name}"'))

@@ -309,6 +308,7 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:

@@ -346,6 +346,7 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
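Both run_async_migrations hunks add an explicit await connection.commit() after each schema is migrated, and the do_run_migrations hunk drops the ad-hoc COMMIT that followed CREATE SCHEMA, so each tenant schema is committed individually. A simplified, self-contained sketch of that loop shape; the helper and logger below are stand-ins, not the project's real functions:

# Simplified sketch of the per-schema migration loop after this change.
# apply_schema_migration is a stand-in for the real do_run_migrations dispatch.
import logging

logger = logging.getLogger("alembic.env.sketch")

async def apply_schema_migration(connection, schema_name: str, create_schema: bool) -> None:
    pass  # placeholder for running Alembic migrations against one schema

async def migrate_all_schemas(connection, schemas, create_schema, continue_on_error) -> None:
    for schema in schemas:
        try:
            await apply_schema_migration(connection, schema, create_schema)
            await connection.commit()  # the newly added explicit per-schema commit
        except Exception as e:
            logger.error(f"Error migrating schema {schema}: {e}")
            if not continue_on_error:
                raise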
@@ -85,103 +85,122 @@ class UserRow(NamedTuple):
def upgrade() -> None:
conn = op.get_bind()

# Start transaction
conn.execute(sa.text("BEGIN"))
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()

try:
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()

if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)

# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
)
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)

# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
)
)

if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()

image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)

if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()

# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)

# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()

# Add tools to the unified assistant
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))

# Add tools to the unified assistant
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
)

conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": image_gen_tool[0]},
)

if web_search_tool:
conn.execute(
sa.text(
"""
@@ -190,191 +209,148 @@ def upgrade() -> None:
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
{"tool_id": web_search_tool[0]},
)

conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
),
{"tool_id": image_gen_tool[0]},
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)

if web_search_tool:
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]

# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()

for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}

# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)

# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)

# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)

# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)

# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)

# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)

# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]

# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()

for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}

# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)

# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)

# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)

# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)

# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)

# Commit transaction
conn.execute(sa.text("COMMIT"))

except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
)
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
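For readers of this excerpt: UserRow is defined earlier in the migration file, outside the hunks shown here. A plausible minimal sketch, inferred only from the SELECT column order above and hedged accordingly:

# Assumed shape only - the real definition lives in the part of the migration not shown in this diff.
from typing import NamedTuple
from uuid import UUID

class UserRow(NamedTuple):
    id: UUID
    chosen_assistants: list[int] | None
    visible_assistants: list[int] | None
    hidden_assistants: list[int] | None
    pinned_assistants: list[int] | None
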
@@ -0,0 +1,35 @@
"""backend driven notification details

Revision ID: 5c3dca366b35
Revises: 9087b548dd69
Create Date: 2026-01-06 16:03:11.413724

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5c3dca366b35"
down_revision = "9087b548dd69"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "notification",
        sa.Column(
            "title", sa.String(), nullable=False, server_default="New Notification"
        ),
    )
    op.add_column(
        "notification",
        sa.Column("description", sa.String(), nullable=True, server_default=""),
    )


def downgrade() -> None:
    op.drop_column("notification", "title")
    op.drop_column("notification", "description")
@@ -0,0 +1,49 @@
"""notifications constraint, sort index, and cleanup old notifications

Revision ID: 8405ca81cc83
Revises: a3c1a7904cd0
Create Date: 2026-01-07 16:43:44.855156

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "8405ca81cc83"
down_revision = "a3c1a7904cd0"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create unique index for notification deduplication.
    # This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
    #
    # Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
    # in unique constraints, but we want NULL == NULL for deduplication).
    # The '{}' represents an empty JSONB object as the NULL replacement.

    # Clean up legacy notifications first
    op.execute("DELETE FROM notification WHERE title = 'New Notification'")

    op.execute(
        """
        CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
        ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
        """
    )

    # Create index for efficient notification sorting by user
    # Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
    op.execute(
        """
        CREATE INDEX IF NOT EXISTS ix_notification_user_sort
        ON notification (user_id, dismissed, first_shown DESC)
        """
    )


def downgrade() -> None:
    op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
    op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")
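To make the intent of the two indexes concrete, here is a hedged sketch of the insert and the read path they serve. The column list is an assumption and this is not the real batch_create_notifications; it only illustrates how an ON CONFLICT target can name the same COALESCE expression as the unique index.

# Illustrative only: relies on the expression index created above.
from sqlalchemy import text

DEDUP_INSERT = text(
    """
    INSERT INTO notification (user_id, notif_type, additional_data, title)
    VALUES (:user_id, :notif_type, :additional_data, :title)
    ON CONFLICT (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
    DO NOTHING
    """
)

# The sort index serves per-user listings shaped like:
#   SELECT * FROM notification
#   WHERE user_id = :user_id
#   ORDER BY dismissed, first_shown DESC
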
@@ -42,20 +42,13 @@ TOOL_DESCRIPTIONS = {

def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(sa.text("BEGIN"))

    try:
        for tool_id, description in TOOL_DESCRIPTIONS.items():
            conn.execute(
                sa.text(
                    "UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
                ),
                {"description": description, "tool_id": tool_id},
            )
        conn.execute(sa.text("COMMIT"))
    except Exception as e:
        conn.execute(sa.text("ROLLBACK"))
        raise e
    for tool_id, description in TOOL_DESCRIPTIONS.items():
        conn.execute(
            sa.text(
                "UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
            ),
            {"description": description, "tool_id": tool_id},
        )


def downgrade() -> None:
@@ -0,0 +1,39 @@
"""remove userfile related deprecated fields

Revision ID: a3c1a7904cd0
Revises: 5c3dca366b35
Create Date: 2026-01-06 13:00:30.634396

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "a3c1a7904cd0"
down_revision = "5c3dca366b35"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user_file", "document_id")
    op.drop_column("user_file", "document_id_migrated")
    op.drop_column("connector_credential_pair", "is_user_file")


def downgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "user_file",
        sa.Column("document_id", sa.String(), nullable=True),
    )
    op.add_column(
        "user_file",
        sa.Column(
            "document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
        ),
    )
@@ -7,7 +7,6 @@ Create Date: 2025-12-18 16:00:00.000000
"""

from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa


@@ -19,7 +18,7 @@ depends_on = None


DEEP_RESEARCH_TOOL = {
    "name": RESEARCH_AGENT_DB_NAME,
    "name": "ResearchAgent",
    "display_name": "Research Agent",
    "description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
    "in_code_tool_id": "ResearchAgent",
@@ -70,80 +70,66 @@ BUILT_IN_TOOLS = [
def upgrade() -> None:
    conn = op.get_bind()

    # Start transaction
    conn.execute(sa.text("BEGIN"))

    try:
    # Get existing tools to check what already exists
    existing_tools = conn.execute(
        sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
    ).fetchall()
    existing_tool_ids = {row[0] for row in existing_tools}

    # Insert or update built-in tools
    for tool in BUILT_IN_TOOLS:
        in_code_id = tool["in_code_tool_id"]

        # Handle historical rename: InternetSearchTool -> WebSearchTool
        if (
            in_code_id == "WebSearchTool"
            and "WebSearchTool" not in existing_tool_ids
            and "InternetSearchTool" in existing_tool_ids
        ):
            # Rename the existing InternetSearchTool row in place and update fields
            conn.execute(
                sa.text(
                    """
                    UPDATE tool
                    SET name = :name,
                        display_name = :display_name,
                        description = :description,
                        in_code_tool_id = :in_code_tool_id
                    WHERE in_code_tool_id = 'InternetSearchTool'
                    """
                ),
                tool,
            )
            # Keep the local view of existing ids in sync to avoid duplicate insert
            existing_tool_ids.discard("InternetSearchTool")
            existing_tool_ids.add("WebSearchTool")
            continue

        if in_code_id in existing_tool_ids:
            # Update existing tool
            conn.execute(
                sa.text(
                    """
                    UPDATE tool
                    SET name = :name,
                        display_name = :display_name,
                        description = :description
                    WHERE in_code_tool_id = :in_code_tool_id
                    """
                ),
                tool,
            )
        else:
            # Insert new tool
            conn.execute(
                sa.text(
                    """
                    INSERT INTO tool (name, display_name, description, in_code_tool_id)
                    VALUES (:name, :display_name, :description, :in_code_tool_id)
                    """
                ),
                tool,
            )

    # Commit transaction
    conn.execute(sa.text("COMMIT"))

    except Exception as e:
        # Rollback on error
        conn.execute(sa.text("ROLLBACK"))
        raise e


def downgrade() -> None:
@@ -0,0 +1,64 @@
|
||||
"""sync_exa_api_key_to_content_provider
|
||||
|
||||
Revision ID: d1b637d7050a
|
||||
Revises: d25168c2beee
|
||||
Create Date: 2026-01-09 15:54:15.646249
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d1b637d7050a"
|
||||
down_revision = "d25168c2beee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Exa uses a shared API key between search and content providers.
|
||||
# For existing Exa search providers with API keys, create the corresponding
|
||||
# content provider if it doesn't exist yet.
|
||||
connection = op.get_bind()
|
||||
|
||||
# Check if Exa search provider exists with an API key
|
||||
result = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT api_key FROM internet_search_provider
|
||||
WHERE provider_type = 'exa' AND api_key IS NOT NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
api_key = row[0]
|
||||
# Create Exa content provider with the shared key
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO internet_content_provider
|
||||
(name, provider_type, api_key, is_active)
|
||||
VALUES ('Exa', 'exa', :api_key, false)
|
||||
ON CONFLICT (name) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"api_key": api_key},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the Exa content provider that was created by this migration
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM internet_content_provider
|
||||
WHERE provider_type = 'exa'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""tool_name_consistency
|
||||
|
||||
Revision ID: d25168c2beee
|
||||
Revises: 8405ca81cc83
|
||||
Create Date: 2026-01-11 17:54:40.135777
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d25168c2beee"
|
||||
down_revision = "8405ca81cc83"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Currently the seeded tools have the in_code_tool_id == name
|
||||
CURRENT_TOOL_NAME_MAPPING = [
|
||||
"SearchTool",
|
||||
"WebSearchTool",
|
||||
"ImageGenerationTool",
|
||||
"PythonTool",
|
||||
"OpenURLTool",
|
||||
"KnowledgeGraphTool",
|
||||
"ResearchAgent",
|
||||
]
|
||||
|
||||
# Mapping of in_code_tool_id -> name
|
||||
# These are the expected names that we want in the database
|
||||
EXPECTED_TOOL_NAME_MAPPING = {
|
||||
"SearchTool": "internal_search",
|
||||
"WebSearchTool": "web_search",
|
||||
"ImageGenerationTool": "generate_image",
|
||||
"PythonTool": "python",
|
||||
"OpenURLTool": "open_url",
|
||||
"KnowledgeGraphTool": "run_kg_search",
|
||||
"ResearchAgent": "research_agent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Mapping of in_code_tool_id to the NAME constant from each tool class
|
||||
# These match the .name property of each tool implementation
|
||||
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
|
||||
|
||||
# Update the name column for each tool based on its in_code_tool_id
|
||||
for in_code_tool_id, expected_name in tool_name_mapping.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :expected_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"expected_name": expected_name,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Reverse the migration by setting name back to in_code_tool_id
|
||||
# This matches the original pattern where name was the class name
|
||||
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :current_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"current_name": in_code_tool_id,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
@@ -3,30 +3,42 @@ from uuid import UUID
from sqlalchemy.orm import Session

from onyx.configs.constants import NotificationType
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData


def make_persona_private(
def update_persona_access(
    persona_id: int,
    creator_user_id: UUID | None,
    user_ids: list[UUID] | None,
    group_ids: list[int] | None,
    db_session: Session,
    is_public: bool | None = None,
    user_ids: list[UUID] | None = None,
    group_ids: list[int] | None = None,
) -> None:
    """NOTE(rkuo): This function batches all updates into a single commit. If we don't
    dedupe the inputs, the commit will exception."""
    db_session.query(Persona__User).filter(
        Persona__User.persona_id == persona_id
    ).delete(synchronize_session="fetch")
    db_session.query(Persona__UserGroup).filter(
        Persona__UserGroup.persona_id == persona_id
    ).delete(synchronize_session="fetch")
    """Updates the access settings for a persona including public status, user shares,
    and group shares.

    NOTE: This function batches all updates. If we don't dedupe the inputs,
    the commit will exception.

    NOTE: Callers are responsible for committing."""

    if is_public is not None:
        persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
        if persona:
            persona.is_public = is_public

    # NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
    # and a non-empty list means "replace with these shares".

    if user_ids is not None:
        db_session.query(Persona__User).filter(
            Persona__User.persona_id == persona_id
        ).delete(synchronize_session="fetch")

    if user_ids:
        user_ids_set = set(user_ids)
        for user_id in user_ids_set:
            db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
@@ -34,17 +46,20 @@ def make_persona_private(
            create_notification(
                user_id=user_id,
                notif_type=NotificationType.PERSONA_SHARED,
                title="A new agent was shared with you!",
                db_session=db_session,
                additional_data=PersonaSharedNotificationData(
                    persona_id=persona_id,
                ).model_dump(),
            )

    if group_ids:
    if group_ids is not None:
        db_session.query(Persona__UserGroup).filter(
            Persona__UserGroup.persona_id == persona_id
        ).delete(synchronize_session="fetch")

        group_ids_set = set(group_ids)
        for group_id in group_ids_set:
            db_session.add(
                Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
            )

    db_session.commit()
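A brief usage sketch of the new access semantics described above. The persona id, user ids, and session variable are illustrative only, and note that callers now commit themselves:

# Hypothetical caller - illustrates the None / [] / non-empty list behavior.
update_persona_access(
    persona_id=42,                        # example id
    db_session=db_session,
    is_public=False,                      # flip the persona to private
    user_ids=[user_a_id, user_b_id],      # replace user shares with exactly these users
    group_ids=None,                       # leave group shares unchanged
)
db_session.commit()                       # the function no longer commits on its own

# Passing an empty list clears shares instead of leaving them alone:
update_persona_access(persona_id=42, db_session=db_session, user_ids=[], group_ids=[])
db_session.commit()
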
@@ -21,8 +21,9 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/analytics")
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
_DEFAULT_LOOKBACK_DAYS = 30
|
||||
|
||||
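For context, adding tags to a router only changes how its endpoints are grouped and advertised in the generated OpenAPI schema; request handling is unchanged. A minimal sketch, assuming PUBLIC_API_TAGS is a list of tag strings exported by onyx.server.utils (its exact value is not shown in this diff):

# Minimal illustration, not the actual onyx code.
from fastapi import APIRouter, FastAPI

PUBLIC_API_TAGS = ["public_api"]  # assumed shape of the constant

router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)

@router.get("/ping")
def ping() -> dict[str, str]:
    # Every route on this router inherits the "public_api" tag in /openapi.json.
    return {"status": "ok"}

app = FastAPI()
app.include_router(router)
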
@@ -23,6 +23,7 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -100,6 +101,7 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
origin=MessageOrigin.API,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
@@ -203,6 +205,7 @@ def handle_send_message_simple_with_history(
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
origin=MessageOrigin.API,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -48,6 +48,7 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.query_and_chat.models import ChatSessionDetails
|
||||
from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.threadpool_concurrency import parallel_yield
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -294,7 +295,7 @@ def list_all_query_history_exports(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/query-history/start-export")
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -340,7 +341,7 @@ def start_query_history_export(
|
||||
return {"request_id": task_id}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/export-status")
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
@@ -374,7 +375,7 @@ def get_query_history_export_status(
|
||||
return {"status": TaskStatus.SUCCESS}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/download")
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -72,22 +78,24 @@ def fetch_billing_information(
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -10,6 +10,7 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
@@ -104,15 +105,18 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
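A short illustration of the new request model; the values are examples, not defaults used anywhere in the codebase:

# Illustrative only: billing_period is constrained to the two allowed literals.
req = CreateSubscriptionSessionRequest(billing_period="annual")
print(req.model_dump())  # {'billing_period': 'annual'}

# CreateSubscriptionSessionRequest() defaults billing_period to "monthly";
# any other value (e.g. "weekly") fails pydantic validation.
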
@@ -16,8 +16,9 @@ from onyx.db.token_limit import insert_user_token_rate_limit
|
||||
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
"""EE Usage limits - trial detection via billing information."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
@@ -31,13 +28,7 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
|
||||
return True
|
||||
|
||||
if isinstance(billing_info, BillingInformation):
|
||||
# Check if trial is active
|
||||
if billing_info.trial_end is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Trial active if trial_end is in the future
|
||||
# and subscription status indicates trialing
|
||||
if billing_info.trial_end > now and billing_info.status == "trialing":
|
||||
return True
|
||||
return billing_info.status == "trialing"
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -21,11 +21,12 @@ from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
|
||||
@@ -105,6 +105,8 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
|
||||
@@ -124,6 +124,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -98,8 +98,5 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Ensure the user files indexing worker registers the doc_id migration task
|
||||
# TODO(subash): remove this once the doc_id migration is complete
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2,9 +2,12 @@ import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -54,16 +57,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "user-file-docid-migration",
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
@@ -181,7 +174,26 @@ if AUTO_LLM_CONFIG_URL:
|
||||
"schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add scheduled eval task if datasets are configured
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "scheduled-eval-pipeline",
|
||||
"task": OnyxCeleryTask.SCHEDULED_EVAL_TASK,
|
||||
# run every Sunday at midnight UTC
|
||||
"schedule": crontab(
|
||||
hour=0,
|
||||
minute=0,
|
||||
day_of_week=0,
|
||||
),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -72,15 +72,6 @@ def try_creating_docfetching_task(
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
has_successful_attempt = cc_pair.last_successful_index_time is not None
|
||||
@@ -99,7 +90,7 @@ def try_creating_docfetching_task(
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
task_id=custom_task_id,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
@@ -40,9 +41,11 @@ from onyx.background.indexing.checkpointing_utils import (
|
||||
)
|
||||
from onyx.background.indexing.index_attempt_utils import cleanup_index_attempts
|
||||
from onyx.background.indexing.index_attempt_utils import get_old_index_attempts
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
@@ -59,11 +62,9 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_standard_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -86,7 +87,6 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.db.usage import UsageLimitExceededError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
@@ -540,12 +540,7 @@ def check_indexing_completion(
|
||||
]:
|
||||
# User file connectors must be paused on success
|
||||
# NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model
|
||||
# TODO: figure out why this doesn't pause connectors during swap
|
||||
cc_pair.status = (
|
||||
ConnectorCredentialPairStatus.PAUSED
|
||||
if cc_pair.is_user_file
|
||||
else ConnectorCredentialPairStatus.ACTIVE
|
||||
)
|
||||
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
|
||||
db_session.commit()
|
||||
|
||||
mt_cloud_telemetry(
|
||||
@@ -811,13 +806,8 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
db_session, active_cc_pairs_only=True
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session, search_settings_id=current_search_settings.id
|
||||
)
|
||||
)
|
||||
|
||||
primary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids
|
||||
primary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Get CC pairs for secondary search settings
|
||||
secondary_cc_pair_ids: list[int] = []
|
||||
@@ -833,30 +823,47 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
db_session, active_cc_pairs_only=not include_paused
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session, search_settings_id=secondary_search_settings.id
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Flag CC pairs in repeated error state for primary/current search settings
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for cc_pair_id in primary_cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
if is_in_repeated_error_state(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=current_search_settings.id,
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# if already in repeated error state, don't do anything
|
||||
# this is important so that we don't keep pausing the connector
|
||||
# immediately upon a user un-pausing it to manually re-trigger and
|
||||
# recover.
|
||||
if (
|
||||
cc_pair
|
||||
and not cc_pair.in_repeated_error_state
|
||||
and is_in_repeated_error_state(
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=current_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
):
|
||||
set_cc_pair_repeated_error_state(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
# models. Also, they are more prone to repeated failures -> eventual success.
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
|
||||
# NOTE: At this point, we haven't done heavy checks on whether or not the CC pairs should actually be indexed
|
||||
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
|
||||
@@ -1289,19 +1296,14 @@ def _check_chunk_usage_limit(tenant_id: str) -> None:
|
||||
if not USAGE_LIMITS_ENABLED:
|
||||
return
|
||||
|
||||
from onyx.db.usage import check_usage_limit
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.server.usage_limits import get_limit_for_usage_type
|
||||
from onyx.server.usage_limits import is_tenant_on_trial
|
||||
|
||||
is_trial = is_tenant_on_trial(tenant_id)
|
||||
limit = get_limit_for_usage_type(UsageType.CHUNKS_INDEXED, is_trial)
|
||||
from onyx.server.usage_limits import check_usage_and_raise
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
check_usage_limit(
|
||||
check_usage_and_raise(
|
||||
db_session=db_session,
|
||||
usage_type=UsageType.CHUNKS_INDEXED,
|
||||
limit=limit,
|
||||
tenant_id=tenant_id,
|
||||
pending_amount=0, # Just check current usage
|
||||
)
|
||||
|
||||
@@ -1321,7 +1323,7 @@ def _docprocessing_task(
|
||||
if USAGE_LIMITS_ENABLED:
|
||||
try:
|
||||
_check_chunk_usage_limit(tenant_id)
|
||||
except UsageLimitExceededError as e:
|
||||
except HTTPException as e:
|
||||
# Log the error and fail the indexing attempt
|
||||
task_logger.error(
|
||||
f"Chunk indexing usage limit exceeded for tenant {tenant_id}: {e}"
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
@@ -126,18 +125,9 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def is_in_repeated_error_state(
|
||||
cc_pair_id: int, search_settings_id: int, db_session: Session
|
||||
cc_pair: ConnectorCredentialPair, search_settings_id: int, db_session: Session
|
||||
) -> bool:
|
||||
"""Checks if the cc pair / search setting combination is in a repeated error state."""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise RuntimeError(
|
||||
f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
|
||||
)
|
||||
|
||||
# if the connector doesn't have a refresh_freq, a single failed attempt is enough
|
||||
number_of_failed_attempts_in_a_row_needed = (
|
||||
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
|
||||
@@ -146,7 +136,7 @@ def is_in_repeated_error_state(
|
||||
)
|
||||
|
||||
most_recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings_id,
|
||||
limit=number_of_failed_attempts_in_a_row_needed,
|
||||
db_session=db_session,
|
||||
@@ -180,7 +170,7 @@ def should_index(
|
||||
db_session=db_session,
|
||||
)
|
||||
all_recent_errored = is_in_repeated_error_state(
|
||||
cc_pair_id=cc_pair.id,
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=search_settings_instance.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
@@ -33,3 +39,109 @@ def eval_run_task(
|
||||
except Exception:
|
||||
logger.error("Failed to run eval task")
|
||||
raise
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.SCHEDULED_EVAL_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT * 5, # Allow more time for multiple datasets
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def scheduled_eval_task(self: Task, **kwargs: Any) -> None:
|
||||
"""
|
||||
Scheduled task to run evaluations on configured datasets.
|
||||
Runs weekly on Sunday at midnight UTC.
|
||||
|
||||
Configure via environment variables (with defaults):
|
||||
- SCHEDULED_EVAL_DATASET_NAMES: Comma-separated list of Braintrust dataset names
|
||||
- SCHEDULED_EVAL_PERMISSIONS_EMAIL: Email for search permissions (default: roshan@onyx.app)
|
||||
- SCHEDULED_EVAL_PROJECT: Braintrust project name
|
||||
"""
|
||||
if not BRAINTRUST_API_KEY:
|
||||
logger.error("BRAINTRUST_API_KEY is not configured, cannot run scheduled evals")
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_PROJECT:
|
||||
logger.error(
|
||||
"SCHEDULED_EVAL_PROJECT is not configured, cannot run scheduled evals"
|
||||
)
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_DATASET_NAMES:
|
||||
logger.info("No scheduled eval datasets configured, skipping")
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_PERMISSIONS_EMAIL:
|
||||
logger.error("SCHEDULED_EVAL_PERMISSIONS_EMAIL not configured")
|
||||
return
|
||||
|
||||
project_name = SCHEDULED_EVAL_PROJECT
|
||||
dataset_names = SCHEDULED_EVAL_DATASET_NAMES
|
||||
permissions_email = SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
|
||||
# Create a timestamp for the scheduled run
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
logger.info(
|
||||
f"Starting scheduled eval pipeline for project '{project_name}' "
|
||||
f"with {len(dataset_names)} dataset(s): {dataset_names}"
|
||||
)
|
||||
|
||||
pipeline_start = datetime.now(timezone.utc)
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for dataset_name in dataset_names:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
error_message: str | None = None
|
||||
success = False
|
||||
|
||||
# Create informative experiment name for scheduled runs
|
||||
experiment_name = f"{dataset_name} - {run_timestamp}"
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Running scheduled eval for dataset: {dataset_name} "
|
||||
f"(project: {project_name})"
|
||||
)
|
||||
|
||||
configuration = EvalConfigurationOptions(
|
||||
search_permissions_email=permissions_email,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=project_name,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
|
||||
result = run_eval(
|
||||
configuration=configuration,
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
success = result.success
|
||||
logger.info(f"Completed eval for {dataset_name}: success={success}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run scheduled eval for {dataset_name}")
|
||||
error_message = str(e)
|
||||
success = False
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"dataset_name": dataset_name,
|
||||
"success": success,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"error_message": error_message,
|
||||
}
|
||||
)
|
||||
|
||||
pipeline_end = datetime.now(timezone.utc)
|
||||
total_duration = (pipeline_end - pipeline_start).total_seconds()
|
||||
|
||||
passed_count = sum(1 for r in results if r["success"])
|
||||
logger.info(
|
||||
f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed "
|
||||
f"in {total_duration:.1f}s"
|
||||
)
|
||||
|
||||
@@ -5,6 +5,9 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -26,24 +29,9 @@ def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
# Fetch config from GitHub
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
|
||||
if not config:
|
||||
task_logger.warning("Failed to fetch GitHub config")
|
||||
return None
|
||||
|
||||
# Sync to database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
results = sync_llm_models_from_github(db_session, config)
|
||||
results = sync_llm_models_from_github(db_session)
|
||||
|
||||
if results:
|
||||
task_logger.info(f"Auto mode sync results: {results}")
|
||||
|
||||
@@ -886,9 +886,7 @@ def monitor_celery_queues_helper(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
)
|
||||
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
|
||||
n_user_files_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
|
||||
)
|
||||
|
||||
n_user_file_processing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -924,7 +922,6 @@ def monitor_celery_queues_helper(
|
||||
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
|
||||
f"docprocessing={n_docprocessing} "
|
||||
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
|
||||
f"user_files_indexing={n_user_files_indexing} "
|
||||
f"user_file_processing={n_user_file_processing} "
|
||||
f"user_file_project_sync={n_user_file_project_sync} "
|
||||
f"user_file_delete={n_user_file_delete} "
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -13,39 +12,33 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
@@ -63,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -126,7 +130,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
    """Scan for user files with PROCESSING status and enqueue per-file tasks.

    Uses direct Redis locks to avoid overlapping runs.
    Three mechanisms prevent queue runaway:

    1. **Queue depth backpressure** – if the broker queue already has more than
       USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
       entirely. Workers are clearly behind; adding more tasks would only make
       the backlog worse.

    2. **Per-file queued guard** – before enqueuing a task we set a short-lived
       Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
       already exists, the file already has a live task in the queue, so we skip
       it. The worker deletes the key the moment it picks up the task so the
       next beat cycle can re-enqueue if the file is still PROCESSING.

    3. **Task expiry** – every enqueued task carries an `expires` value equal to
       CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
       the queue after that deadline, Celery discards it without touching the DB.
       This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
       Redis restart), stale tasks evict themselves rather than piling up forever.
    """
    task_logger.info("check_user_file_processing - Starting")
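Taken together, the three protections described in the docstring above form a small enqueue-side pattern. The following is a minimal standalone sketch of that pattern, not the repository's code: `redis_client`, `celery_app`, `queue_len`, and the constant values are assumptions for illustration.

# Minimal sketch of the enqueue guard described above. `redis_client`,
# `celery_app`, and `queue_len` are assumed to be provided by the caller;
# the constants mirror the values defined later in this diff.
TASK_EXPIRES = 60  # seconds
MAX_QUEUE_DEPTH = 500


def maybe_enqueue_user_file(
    user_file_id: str, queue_len: int, redis_client, celery_app
) -> bool:
    # Protection 1: back off entirely when the broker queue is already deep.
    if queue_len > MAX_QUEUE_DEPTH:
        return False

    # Protection 2: SET NX EX acts as a short-lived "already queued" marker.
    queued_key = f"user_file_queued:{user_file_id}"
    if not redis_client.set(queued_key, 1, ex=TASK_EXPIRES, nx=True):
        return False

    try:
        # Protection 3: the task discards itself if it waits longer than TASK_EXPIRES.
        celery_app.send_task(
            "process_single_user_file",
            kwargs={"user_file_id": user_file_id},
            expires=TASK_EXPIRES,
        )
    except Exception:
        # Enqueue failed: clear the guard so the next beat cycle can retry.
        redis_client.delete(queued_key)
        raise
    return True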
@@ -141,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -154,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -167,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -182,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
@@ -618,315 +683,3 @@ def process_single_user_file_project_sync(
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
|
||||
# Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
|
||||
user_prefix = "USER_FILE_CONNECTOR__"
|
||||
file_prefix = "FILE_CONNECTOR__"
|
||||
if old_id.startswith(user_prefix):
|
||||
remainder = old_id[len(user_prefix) :]
|
||||
return file_prefix + remainder
|
||||
return old_id
|
||||
|
||||
|
||||
def update_legacy_plaintext_file_records() -> None:
|
||||
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
|
||||
keys. Copies each S3 object to its expected UUID key and updates DB.
|
||||
|
||||
Examples:
|
||||
- Old key: bucket/schema/plaintext_<int>
|
||||
- New key: bucket/schema/plaintext_<uuid>
|
||||
"""
|
||||
|
||||
task_logger.info("update_legacy_plaintext_file_records - Starting")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
store = get_default_file_store()
|
||||
|
||||
if not isinstance(store, S3BackedFileStore):
|
||||
task_logger.info(
|
||||
"update_legacy_plaintext_file_records - Skipping non-S3 store"
|
||||
)
|
||||
return
|
||||
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
|
||||
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
|
||||
if fr.bucket_name is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
|
||||
continue
|
||||
|
||||
if fr.object_key is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Object key is None")
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
ignore_result=True,
|
||||
bind=True,
|
||||
)
|
||||
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
|
||||
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
updated_count = 0
|
||||
try:
|
||||
update_legacy_plaintext_file_records()
|
||||
# Track lock renewal
|
||||
last_lock_time = time.monotonic()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
search_settings=active_settings.primary,
|
||||
secondary_search_settings=active_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
|
||||
# Select user files with a legacy doc id that have not been migrated
|
||||
user_files = (
|
||||
db_session.execute(
|
||||
sa.select(UserFile).where(
|
||||
sa.and_(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
|
||||
)
|
||||
|
||||
# Query all SearchDocs that need updating
|
||||
search_docs = (
|
||||
db_session.execute(
|
||||
sa.select(SearchDoc).where(
|
||||
SearchDoc.document_id.like("%FILE_CONNECTOR__%")
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
|
||||
)
|
||||
|
||||
# Build a map of normalized doc IDs to SearchDocs
|
||||
search_doc_map: dict[str, list[SearchDoc]] = {}
|
||||
for sd in search_docs:
|
||||
doc_id = sd.document_id
|
||||
if search_doc_map.get(doc_id) is None:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - search_doc_map total items: "
|
||||
f"{sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
# Periodically renew the Redis lock to prevent expiry mid-run
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
|
||||
):
|
||||
renewed = False
|
||||
try:
|
||||
# extend lock ttl to full timeout window
|
||||
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
|
||||
renewed = True
|
||||
except Exception:
|
||||
# if extend fails, best-effort reacquire as a fallback
|
||||
try:
|
||||
lock.reacquire()
|
||||
renewed = True
|
||||
except Exception:
|
||||
renewed = False
|
||||
last_lock_time = current_time
|
||||
if not renewed or not lock.owned():
|
||||
task_logger.error(
|
||||
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(
|
||||
user_file.document_id
|
||||
)
|
||||
normalized_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
clean_old_doc_id
|
||||
)
|
||||
user_project_ids = [project.id for project in user_file.projects]
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
|
||||
)
|
||||
|
||||
index_name = active_settings.primary.index_name
|
||||
|
||||
# First find the chunks count using direct Vespa query
|
||||
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
|
||||
|
||||
# Count all chunks for this document
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Found {chunk_count} chunks for document {normalized_doc_id}"
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
# WARNING: In the future this will error; we no longer want
|
||||
# to support changing document ID.
|
||||
# TODO(andrei): Delete soon.
|
||||
retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=VespaDocumentFields(document_id=str(user_file.id)),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = chunk_count
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
actual_doc_id
|
||||
)
|
||||
if (
|
||||
normalized_doc_id in search_doc_map
|
||||
or normalized_actual_doc_id in search_doc_map
|
||||
):
|
||||
to_update = (
|
||||
search_doc_map[normalized_doc_id]
|
||||
if normalized_doc_id in search_doc_map
|
||||
else search_doc_map[normalized_actual_doc_id]
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
|
||||
)
|
||||
for search_doc in to_update:
|
||||
search_doc.document_id = str(user_file.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
user_file.document_id_migrated = True
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
updated_count += 1
|
||||
except Exception as per_file_exc:
|
||||
# Rollback the current transaction and continue with the next file
|
||||
db_session.rollback()
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
|
||||
f"{per_file_exc.__class__.__name__}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Updated {updated_count} user files"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
|
||||
f"(updated={updated_count}) exception={e.__class__.__name__}"
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
backend/onyx/chat/chat_processing_checker.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from uuid import UUID

from redis.client import Redis

# Redis key prefixes for chat message processing
PREFIX = "chatprocessing"
FENCE_PREFIX = f"{PREFIX}_fence"
FENCE_TTL = 30 * 60  # 30 minutes


def _get_fence_key(chat_session_id: UUID) -> str:
    """
    Generate the Redis key for a chat session processing a message.

    Args:
        chat_session_id: The UUID of the chat session

    Returns:
        The fence key string (tenant_id is automatically added by the Redis client)
    """
    return f"{FENCE_PREFIX}_{chat_session_id}"


def set_processing_status(
    chat_session_id: UUID, redis_client: Redis, value: bool
) -> None:
    """
    Set or clear the fence for a chat session processing a message.

    If the key exists, we are processing a message. If the key does not exist, we are not processing a message.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: The Redis client to use
        value: True to set the fence, False to clear it
    """
    fence_key = _get_fence_key(chat_session_id)

    if value:
        redis_client.set(fence_key, 0, ex=FENCE_TTL)
    else:
        redis_client.delete(fence_key)


def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
    """
    Check if the chat session is processing a message.

    Args:
        chat_session_id: The UUID of the chat session
        redis_client: The Redis client to use

    Returns:
        True if the chat session is processing a message, False otherwise
    """
    fence_key = _get_fence_key(chat_session_id)
    return bool(redis_client.exists(fence_key))
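As a usage note, here is a hedged sketch of how a caller might use the fence helpers above; the surrounding handler code (`example_handle_message`) is illustrative only and not taken from this diff.

from uuid import UUID

from redis.client import Redis

from onyx.chat.chat_processing_checker import is_chat_session_processing
from onyx.chat.chat_processing_checker import set_processing_status


def example_handle_message(chat_session_id: UUID, redis_client: Redis) -> None:
    # Reject concurrent sends for the same session while a message is in flight.
    if is_chat_session_processing(chat_session_id, redis_client):
        raise RuntimeError("This chat session is already processing a message")

    set_processing_status(chat_session_id, redis_client, value=True)
    try:
        ...  # run the LLM loop / stream the response
    finally:
        # Always clear the fence, mirroring the finally block added later in this diff.
        set_processing_status(chat_session_id, redis_client, value=False)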
@@ -94,6 +94,7 @@ class ChatStateContainer:
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
@@ -196,3 +197,12 @@ def run_chat_loop_with_state_containers(
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
@@ -117,6 +118,7 @@ def prepare_chat_message_request(
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
origin: MessageOrigin | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
@@ -144,6 +146,7 @@ def prepare_chat_message_request(
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
origin=origin or MessageOrigin.UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -505,7 +505,7 @@ def run_llm_loop(
|
||||
# in-flight citations
|
||||
# It can be cleaned up but not super trivial or worthwhile right now
|
||||
just_ran_web_search = False
|
||||
tool_responses, citation_mapping = run_tool_calls(
|
||||
parallel_tool_call_results = run_tool_calls(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
@@ -513,8 +513,11 @@ def run_llm_loop(
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
|
||||
# Failure case, give something reasonable to the LLM to try again
|
||||
if tool_calls and not tool_responses:
|
||||
|
||||
@@ -5,10 +5,13 @@ An overview can be found in the README.md file in this directory.
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
@@ -45,6 +48,8 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import User
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
@@ -78,20 +83,16 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
def __init__(self, message: str, tool_name: str | None = None):
|
||||
super().__init__(message)
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
@@ -294,6 +295,8 @@ def handle_stream_message_objects(
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
@@ -339,6 +342,24 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
@@ -380,7 +401,10 @@ def handle_stream_message_objects(
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
# Auto-place after the latest message in the chain
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif new_msg_req.parent_message_id is None:
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
# None = regeneration from root
|
||||
parent_message = root_message
|
||||
# Truncate history since we're starting from root
|
||||
@@ -536,10 +560,27 @@ def handle_stream_message_objects(
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Use external state container if provided, otherwise create internal one
|
||||
# External container allows non-streaming callers to access accumulated state
|
||||
state_container = external_state_container or ChatStateContainer()
|
||||
|
||||
def llm_loop_completion_callback(
|
||||
state_container: ChatStateContainer,
|
||||
) -> None:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_container,
|
||||
db_session=db_session,
|
||||
chat_session_id=str(chat_session.id),
|
||||
is_connected=check_is_connected,
|
||||
assistant_message=assistant_response,
|
||||
)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
@@ -555,6 +596,7 @@ def handle_stream_message_objects(
|
||||
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -571,6 +613,7 @@ def handle_stream_message_objects(
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -588,51 +631,6 @@ def handle_stream_message_objects(
|
||||
chat_session_id=str(chat_session.id),
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
if not completed_normally:
|
||||
logger.debug(f"Chat session {chat_session.id} stopped by user")
|
||||
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... The generation was stopped by the user here."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -650,15 +648,7 @@ def handle_stream_message_objects(
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="TOOL_CALL_FAILED",
|
||||
is_retryable=True,
|
||||
details={"tool_name": e.tool_name} if e.tool_name else None,
|
||||
)
|
||||
elif llm:
|
||||
if llm:
|
||||
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
|
||||
e, llm
|
||||
)
|
||||
@@ -690,7 +680,67 @@ def handle_stream_message_objects(
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
finally:
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
db_session: Session,
|
||||
chat_session_id: str,
|
||||
assistant_message: ChatMessage,
|
||||
) -> None:
|
||||
# Determine if stopped by user
|
||||
completed_normally = is_connected()
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
logger.debug(f"Chat session {chat_session_id} stopped by user")
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... \n\nGeneration was stopped by the user."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_message,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
@@ -739,6 +789,7 @@ def stream_chat_message_objects(
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -173,7 +174,9 @@ def build_system_prompt(
|
||||
TOOL_SECTION_HEADER
|
||||
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
+ INTERNAL_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
@@ -199,7 +202,16 @@ def build_system_prompt(
|
||||
system_prompt += INTERNAL_SEARCH_GUIDANCE
|
||||
|
||||
if has_web_search or include_all_guidance:
|
||||
system_prompt += WEB_SEARCH_GUIDANCE
|
||||
site_disabled_guidance = ""
|
||||
if has_web_search:
|
||||
web_search_tool = next(
|
||||
(t for t in tools if isinstance(t, WebSearchTool)), None
|
||||
)
|
||||
if web_search_tool and not web_search_tool.supports_site_filter:
|
||||
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
system_prompt += WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=site_disabled_guidance
|
||||
)
|
||||
|
||||
if has_open_urls or include_all_guidance:
|
||||
system_prompt += OPEN_URLS_GUIDANCE
|
||||
|
||||
@@ -568,6 +568,7 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SLIM_PAGE_SIZE = int(os.environ.get("JIRA_SLIM_PAGE_SIZE", 500))
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -679,10 +680,6 @@ INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
|
||||
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 8
|
||||
)
|
||||
|
||||
# Maximum number of user file connector credential pairs to index in a single batch
|
||||
# Setting this number too high may overload the indexing process
|
||||
USER_FILE_INDEXING_LIMIT = int(os.environ.get("USER_FILE_INDEXING_LIMIT") or 100)
|
||||
|
||||
# Maximum file size in a document to be indexed
|
||||
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
|
||||
MAX_FILE_SIZE_BYTES = int(
|
||||
@@ -754,7 +751,27 @@ BRAINTRUST_PROJECT = os.environ.get("BRAINTRUST_PROJECT", "Onyx")
|
||||
# Braintrust API key - if provided, Braintrust tracing will be enabled
|
||||
BRAINTRUST_API_KEY = os.environ.get("BRAINTRUST_API_KEY") or ""
|
||||
# Maximum concurrency for Braintrust evaluations
|
||||
BRAINTRUST_MAX_CONCURRENCY = int(os.environ.get("BRAINTRUST_MAX_CONCURRENCY") or 5)
|
||||
# None means unlimited concurrency, otherwise specify a number
|
||||
_braintrust_concurrency = os.environ.get("BRAINTRUST_MAX_CONCURRENCY")
|
||||
BRAINTRUST_MAX_CONCURRENCY = (
|
||||
int(_braintrust_concurrency) if _braintrust_concurrency else None
|
||||
)
|
||||
|
||||
#####
|
||||
# Scheduled Evals Configuration
|
||||
#####
|
||||
# Comma-separated list of Braintrust dataset names to run on schedule
|
||||
SCHEDULED_EVAL_DATASET_NAMES = [
|
||||
name.strip()
|
||||
for name in os.environ.get("SCHEDULED_EVAL_DATASET_NAMES", "").split(",")
|
||||
if name.strip()
|
||||
]
|
||||
# Email address to use for search permissions during scheduled evals
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL = os.environ.get(
|
||||
"SCHEDULED_EVAL_PERMISSIONS_EMAIL", "roshan@onyx.app"
|
||||
)
|
||||
# Braintrust project name to use for scheduled evals
|
||||
SCHEDULED_EVAL_PROJECT = os.environ.get("SCHEDULED_EVAL_PROJECT", "st-dev")
|
||||
|
||||
#####
|
||||
# Langfuse Configuration
|
||||
@@ -979,3 +996,9 @@ COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
|
||||
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
|
||||
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")
|
||||
|
||||
INSTANCE_TYPE = (
|
||||
"managed"
|
||||
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
|
||||
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
|
||||
ONYX_UTM_SOURCE = "onyx_app"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -146,11 +147,19 @@ CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600  # 1 hour (in seconds)

CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300  # 5 min

# Doc ID migration can be long-running; use a longer TTL and renew periodically
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT = 10 * 60  # 10 minutes (in seconds)

CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60  # 30 minutes (in seconds)

# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60  # 1 minute (in seconds)

# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500

CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60  # 5 minutes (in seconds)

DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -237,6 +246,8 @@ class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -365,9 +376,6 @@ class OnyxCeleryQueues:
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# Indexing queue
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# User file processing queue
|
||||
USER_FILE_PROCESSING = "user_file_processing"
|
||||
USER_FILE_PROJECT_SYNC = "user_file_project_sync"
|
||||
@@ -422,11 +430,16 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
USER_FILE_DOCID_MIGRATION_LOCK = "da_lock:user_file_docid_migration"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
@@ -533,7 +546,6 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
USER_FILE_DOCID_MIGRATION = "user_file_docid_migration"
|
||||
|
||||
# chat retention
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
@@ -542,6 +554,7 @@ class OnyxCeleryTask:
|
||||
GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task"
|
||||
|
||||
EVAL_RUN_TASK = "eval_run_task"
|
||||
SCHEDULED_EVAL_TASK = "scheduled_eval_task"
|
||||
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
@@ -93,7 +93,7 @@ if __name__ == "__main__":
|
||||
#### Docs Changes
|
||||
|
||||
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
|
||||
connector in Onyx. Then create a Pull Request in https://github.com/onyx-dot-app/onyx-docs.
|
||||
connector in Onyx. Then create a Pull Request in [https://github.com/onyx-dot-app/documentation](https://github.com/onyx-dot-app/documentation).
|
||||
|
||||
### Before opening PR
|
||||
|
||||
|
||||
@@ -901,13 +901,16 @@ class OnyxConfluence:
|
||||
space_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
This is a confluence server specific method that can be used to
|
||||
This is a confluence server/data center specific method that can be used to
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
TODO: Make this call these endpoints for newer confluence versions:
|
||||
- /rest/api/space/{spaceKey}/permissions
|
||||
- /rest/api/space/{spaceKey}/permissions/anonymous
|
||||
|
||||
NOTE: This uses the JSON-RPC API which is the ONLY way to get space permissions
|
||||
on Confluence Server/Data Center. The REST API equivalent (expand=permissions)
|
||||
is Cloud-only and not available on Data Center as of version 8.9.x.
|
||||
|
||||
If this fails with 401 Unauthorized, the customer needs to enable JSON-RPC:
|
||||
Confluence Admin -> General Configuration -> Further Configuration
|
||||
-> Enable "Remote API (XML-RPC & SOAP)"
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
@@ -916,7 +919,18 @@ class OnyxConfluence:
|
||||
"id": 7,
|
||||
"params": [space_key],
|
||||
}
|
||||
response = self.post(url, data=data)
|
||||
try:
|
||||
response = self.post(url, data=data)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 401:
|
||||
raise HTTPError(
|
||||
"Unauthorized (401) when calling JSON-RPC API for space permissions. "
|
||||
"This is likely because the Remote API is disabled. "
|
||||
"To fix: Confluence Admin -> General Configuration -> Further Configuration "
|
||||
"-> Enable 'Remote API (XML-RPC & SOAP)'",
|
||||
response=e.response,
|
||||
) from e
|
||||
raise
|
||||
logger.debug(f"jsonrpc response: {response}")
|
||||
if not response.get("result"):
|
||||
logger.warning(
|
||||
@@ -961,14 +975,20 @@ def get_user_email_from_username__server(
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else "N/A"
|
||||
logger.warning(
|
||||
f"Failed to get confluence email for {user_name}: "
|
||||
f"HTTP {status_code} - {e}"
|
||||
)
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to get confluence email for {user_name}: {type(e).__name__} - {e}"
|
||||
)
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -97,10 +97,17 @@ def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
"""Gets string representations of experts supplied.
|
||||
|
||||
If an expert cannot be represented as a string, it is omitted from the
|
||||
result.
|
||||
"""
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
reps: list[str | None] = [
|
||||
basic_expert_info_representation(owner) for owner in experts
|
||||
]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing_extensions import override
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from onyx.configs.app_configs import JIRA_SLIM_PAGE_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
is_atlassian_date_error,
|
||||
@@ -57,7 +58,6 @@ logger = setup_logger()
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
# Constants for Jira field names
|
||||
@@ -683,7 +683,7 @@ class JiraConnector(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
max_results=JIRA_SLIM_PAGE_SIZE,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
nextPageToken=checkpoint.cursor,
|
||||
@@ -703,11 +703,11 @@ class JiraConnector(
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
self.update_checkpoint_for_next_run(
|
||||
checkpoint, current_offset, prev_offset, _JIRA_SLIM_PAGE_SIZE
|
||||
checkpoint, current_offset, prev_offset, JIRA_SLIM_PAGE_SIZE
|
||||
)
|
||||
prev_offset = current_offset
|
||||
|
||||
|
||||
@@ -566,6 +566,23 @@ def extract_content_words_from_recency_query(
    return content_words_filtered[:MAX_CONTENT_WORDS]


def _is_valid_keyword_query(line: str) -> bool:
    """Check if a line looks like a valid keyword query vs explanatory text.

    Returns False for lines that appear to be LLM explanations rather than keywords.
    """
    # Reject lines that start with parentheses (explanatory notes)
    if line.startswith("("):
        return False

    # Reject lines that are too long (likely sentences, not keywords)
    # Keywords should be short - reject if > 50 chars or > 6 words
    if len(line) > 50 or len(line.split()) > 6:
        return False

    return True
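A few illustrative inputs for _is_valid_keyword_query; these example strings are made up for clarity and do not come from the repository.

# Hypothetical examples showing what the filter accepts and rejects.
assert _is_valid_keyword_query("vespa chunk count") is True
assert _is_valid_keyword_query("(note: these target the indexing docs)") is False  # starts with "("
assert (
    _is_valid_keyword_query("Here are several keyword variations that should match your query")
    is False
)  # longer than 6 words, reads like an explanation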
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -586,10 +603,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
rephrased_queries = [
|
||||
raw_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
@@ -181,7 +182,11 @@ def get_chat_sessions_by_user(
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
stmt = stmt.where(non_system_message_exists_subq)
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
@@ -6,21 +6,15 @@ from fastapi import HTTPException
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import lateral
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import USER_FILE_INDEXING_LIMIT
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
@@ -120,7 +114,6 @@ def get_connector_credential_pairs_for_user(
|
||||
eager_load_connector: bool = False,
|
||||
eager_load_credential: bool = False,
|
||||
eager_load_user: bool = False,
|
||||
include_user_files: bool = False,
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
@@ -149,9 +142,6 @@ def get_connector_credential_pairs_for_user(
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file.is_(False))
|
||||
|
||||
if order_by_desc:
|
||||
stmt = stmt.order_by(desc(ConnectorCredentialPair.id))
|
||||
|
||||
@@ -186,16 +176,13 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
|
||||
|
||||
def get_connector_credential_pairs(
|
||||
db_session: Session, ids: list[int] | None = None, include_user_files: bool = False
|
||||
db_session: Session, ids: list[int] | None = None
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
@@ -242,15 +229,12 @@ def get_connector_credential_pair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
include_user_files: bool = False,
|
||||
get_editable: bool = True,
|
||||
) -> ConnectorCredentialPair | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
|
||||
if not include_user_files:
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file != True) # noqa: E712
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@@ -377,8 +361,6 @@ def _update_connector_credential_pair(
|
||||
cc_pair.total_docs_indexed += net_docs
|
||||
if status is not None:
|
||||
cc_pair.status = status
|
||||
if cc_pair.is_user_file:
|
||||
cc_pair.status = ConnectorCredentialPairStatus.PAUSED
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -444,27 +426,10 @@ def set_cc_pair_repeated_error_state(
|
||||
cc_pair_id: int,
|
||||
in_repeated_error_state: bool,
|
||||
) -> None:
|
||||
values: dict = {"in_repeated_error_state": in_repeated_error_state}
|
||||
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# However, don't pause if there's an active manual indexing trigger,
|
||||
# which indicates the user wants to retry immediately.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
# models. Also, they are more prone to repeated failures -> eventual success.
|
||||
if in_repeated_error_state and AUTH_TYPE == AuthType.CLOUD:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
# Only pause if there's no manual indexing trigger active
|
||||
if cc_pair and cc_pair.indexing_trigger is None:
|
||||
values["status"] = ConnectorCredentialPairStatus.PAUSED
|
||||
|
||||
stmt = (
|
||||
update(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.values(**values)
|
||||
.values(in_repeated_error_state=in_repeated_error_state)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
@@ -536,7 +501,6 @@ def add_credential_to_connector(
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.SCHEDULED,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
seeding_flow: bool = False,
|
||||
is_user_file: bool = False,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
|
||||
@@ -602,7 +566,6 @@ def add_credential_to_connector(
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
is_user_file=is_user_file,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
@@ -699,67 +662,12 @@ def fetch_indexable_standard_connector_credential_pair_ids(
|
||||
)
|
||||
)
|
||||
|
||||
# Exclude user files. NOTE: some cc pairs have null for is_user_file instead of False
|
||||
stmt = stmt.where(ConnectorCredentialPair.is_user_file.is_not(True))
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return list(db_session.scalars(stmt))
|
||||
|
||||
|
||||
def fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session: Session,
|
||||
search_settings_id: int,
|
||||
limit: int | None = USER_FILE_INDEXING_LIMIT,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Return up to `limit` user file connector_credential_pair IDs that still
|
||||
need indexing for the given `search_settings_id`
|
||||
|
||||
A cc_pair is considered "needs indexing" if its most recent IndexAttempt
|
||||
for this search_settings_id is either:
|
||||
- Missing entirely (no attempts yet)
|
||||
- Present but not SUCCESS status
|
||||
|
||||
Implementation details:
|
||||
- Uses a LEFT JOIN LATERAL subquery to fetch only the single newest attempt
|
||||
per cc_pair (`ORDER BY time_updated DESC LIMIT 1`), instead of joining all
|
||||
attempts. This avoids scanning thousands of historical attempts and
|
||||
keeps memory/CPU usage low
|
||||
- `ON TRUE` is required in the lateral join because the correlation to
|
||||
ConnectorCredentialPair.id happens inside the subquery itself
|
||||
- NOTE: Shares some redundant logic with should_index() (TODO: combine)
|
||||
|
||||
Returns:
|
||||
list[int]: connector_credential_pair IDs that should be indexed next
|
||||
"""
|
||||
latest_attempt = lateral(
|
||||
select(IndexAttempt.status)
|
||||
.where(
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
IndexAttempt.search_settings_id == search_settings_id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.limit(1)
|
||||
).alias("latest_attempt")
|
||||
|
||||
stmt = (
|
||||
select(ConnectorCredentialPair.id)
|
||||
.outerjoin(latest_attempt, true()) # ON TRUE, Postgres-style lateral join
|
||||
.where(
|
||||
ConnectorCredentialPair.is_user_file.is_(True),
|
||||
or_(
|
||||
latest_attempt.c.status.is_(None), # no attempts at all
|
||||
latest_attempt.c.status != IndexingStatus.SUCCESS, # latest != SUCCESS
|
||||
),
|
||||
)
|
||||
.limit(limit) # Always apply a limit when fetching user file cc pairs
|
||||
)
|
||||
|
||||
return list(db_session.scalars(stmt))
|
||||
|
||||
|
||||
def fetch_connector_credential_pair_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
|
||||
@@ -444,6 +444,8 @@ def upsert_documents(
|
||||
logger.info("No documents to upsert. Skipping.")
|
||||
return
|
||||
|
||||
includes_permissions = any(doc.external_access for doc in seen_documents.values())
|
||||
|
||||
insert_stmt = insert(DbDocument).values(
|
||||
[
|
||||
model_to_dict(
|
||||
@@ -479,21 +481,38 @@ def upsert_documents(
|
||||
]
|
||||
)
|
||||
|
||||
update_set = {
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
}
|
||||
if includes_permissions:
|
||||
# Use COALESCE to preserve existing permissions when new values are NULL.
|
||||
# This prevents subsequent indexing runs (which don't fetch permissions)
|
||||
# from overwriting permissions set by permission sync jobs.
|
||||
update_set.update(
|
||||
{
|
||||
"external_user_emails": func.coalesce(
|
||||
insert_stmt.excluded.external_user_emails,
|
||||
DbDocument.external_user_emails,
|
||||
),
|
||||
"external_user_group_ids": func.coalesce(
|
||||
insert_stmt.excluded.external_user_group_ids,
|
||||
DbDocument.external_user_group_ids,
|
||||
),
|
||||
"is_public": func.coalesce(
|
||||
insert_stmt.excluded.is_public,
|
||||
DbDocument.is_public,
|
||||
),
|
||||
}
|
||||
)
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"external_user_emails": insert_stmt.excluded.external_user_emails,
|
||||
"external_user_group_ids": insert_stmt.excluded.external_user_group_ids,
|
||||
"is_public": insert_stmt.excluded.is_public,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
},
|
||||
index_elements=["id"], set_=update_set # Conflict target
|
||||
)
|
||||
db_session.execute(on_conflict_stmt)
|
||||
db_session.commit()
|
||||
|
||||
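The permission-preserving upsert above leans on Postgres `INSERT ... ON CONFLICT DO UPDATE` with `COALESCE(excluded.col, existing.col)`. A self-contained sketch of the same pattern on a toy table (table and column names are illustrative, not the real schema):

from sqlalchemy import Column, MetaData, String, Table, func
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
# Toy stand-in for the real document table.
document_t = Table(
    "document",
    metadata,
    Column("id", String, primary_key=True),
    Column("semantic_id", String),
    Column("external_user_emails", postgresql.ARRAY(String), nullable=True),
)

insert_stmt = insert(document_t).values(
    [{"id": "doc-1", "semantic_id": "Doc 1", "external_user_emails": None}]
)
on_conflict_stmt = insert_stmt.on_conflict_do_update(
    index_elements=["id"],
    set_={
        # Always take the freshly indexed value.
        "semantic_id": insert_stmt.excluded.semantic_id,
        # Keep the existing permissions when the new value is NULL.
        "external_user_emails": func.coalesce(
            insert_stmt.excluded.external_user_emails,
            document_t.c.external_user_emails,
        ),
    },
)
print(on_conflict_stmt.compile(dialect=postgresql.dialect()))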
@@ -374,7 +374,7 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
only_public: bool = False,
|
||||
exclude_image_generation_providers: bool = False,
|
||||
exclude_image_generation_providers: bool = True,
|
||||
) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers with optional filtering.
|
||||
|
||||
@@ -585,13 +585,12 @@ def update_default_vision_provider(
|
||||
|
||||
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers that are in Auto mode."""
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.is_auto_mode == True) # noqa: E712
|
||||
.options(selectinload(LLMProviderModel.model_configurations))
|
||||
).all()
|
||||
query = (
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.is_auto_mode.is_(True))
|
||||
.options(selectinload(LLMProviderModel.model_configurations))
|
||||
)
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def sync_auto_mode_models(
|
||||
@@ -620,7 +619,9 @@ def sync_auto_mode_models(
|
||||
|
||||
# Build the list of all visible models from the config
|
||||
# All models in the config are visible (default + additional_visible_models)
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(provider.name)
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(
|
||||
provider.provider
|
||||
)
|
||||
recommended_visible_model_names = [
|
||||
model.name for model in recommended_visible_models
|
||||
]
|
||||
@@ -635,11 +636,12 @@ def sync_auto_mode_models(
|
||||
).all()
|
||||
}
|
||||
|
||||
# Remove models that are no longer in GitHub config
|
||||
# Mark models that are no longer in GitHub config as not visible
|
||||
for model_name, model in existing_models.items():
|
||||
if model_name not in recommended_visible_model_names:
|
||||
db_session.delete(model)
|
||||
changes += 1
|
||||
if model.is_visible:
|
||||
model.is_visible = False
|
||||
changes += 1
|
||||
|
||||
# Add or update models from GitHub config
|
||||
for model_config in recommended_visible_models:
|
||||
@@ -669,7 +671,7 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.name)
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
@@ -369,12 +369,25 @@ class Notification(Base):
|
||||
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
title: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="notifications")
|
||||
additional_data: Mapped[dict | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Unique constraint ix_notification_user_type_data on (user_id, notif_type, additional_data)
|
||||
# ensures notification deduplication for batch inserts. Defined in migration 8405ca81cc83.
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_notification_user_sort",
|
||||
"user_id",
|
||||
"dismissed",
|
||||
desc("first_shown"),
|
||||
),
|
||||
)
|
||||
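The deduplication behavior referenced in the comment above hinges on that functional unique index. The migration itself is not shown in this diff, so the following is a hypothetical reconstruction based on the COALESCE logic used elsewhere in this changeset:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Hypothetical reconstruction of migration 8405ca81cc83 (not shown in this
    # diff): at most one notification per (user_id, notif_type, additional_data),
    # with NULL additional_data normalized to '{}' so NULLs cannot bypass the
    # unique index.
    op.create_index(
        "ix_notification_user_type_data",
        "notification",
        [
            "user_id",
            "notif_type",
            sa.text("COALESCE(additional_data, '{}'::jsonb)"),
        ],
        unique=True,
    )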
|
||||
|
||||
"""
|
||||
Association Tables
|
||||
@@ -532,7 +545,6 @@ class ConnectorCredentialPair(Base):
|
||||
"""
|
||||
|
||||
__tablename__ = "connector_credential_pair"
|
||||
is_user_file: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# NOTE: this `id` column has to use `Sequence` instead of `autoincrement=True`
|
||||
# due to some SQLAlchemy quirks + this not being a primary key column
|
||||
id: Mapped[int] = mapped_column(
|
||||
@@ -2604,6 +2616,7 @@ class Tool(Base):
|
||||
__tablename__ = "tool"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# The name of the tool that the LLM will see
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
@@ -3569,9 +3582,6 @@ class UserFile(Base):
|
||||
back_populates="user_files",
|
||||
)
|
||||
file_id: Mapped[str] = mapped_column(nullable=False)
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
nullable=False
|
||||
) # TODO(subash): legacy document_id, will be removed in a future migration
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
default=datetime.datetime.utcnow
|
||||
@@ -3599,9 +3609,6 @@ class UserFile(Base):
|
||||
|
||||
link_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
content_type: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
document_id_migrated: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
|
||||
projects: Mapped[list["UserProject"]] = relationship(
|
||||
"UserProject",
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
@@ -14,37 +19,52 @@ def create_notification(
|
||||
user_id: UUID | None,
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
autocommit: bool = True,
|
||||
) -> Notification:
|
||||
# Check if an undismissed notification of the same type and data exists
|
||||
# Previously, we only matched the first identical, undismissed notification
|
||||
# Now, we assume some uniqueness to notifications
|
||||
# If we previously issued a notification that was dismissed, we no longer issue a new one
|
||||
|
||||
# Normalize additional_data to match the unique index behavior
|
||||
# The index uses COALESCE(additional_data, '{}'::jsonb)
|
||||
# We need to match this logic in our query
|
||||
additional_data_normalized = additional_data if additional_data is not None else {}
|
||||
|
||||
existing_notification = (
|
||||
db_session.query(Notification)
|
||||
.filter_by(
|
||||
user_id=user_id,
|
||||
notif_type=notif_type,
|
||||
dismissed=False,
|
||||
.filter_by(user_id=user_id, notif_type=notif_type)
|
||||
.filter(
|
||||
func.coalesce(Notification.additional_data, cast({}, postgresql.JSONB))
|
||||
== additional_data_normalized
|
||||
)
|
||||
.filter(Notification.additional_data == additional_data)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_notification:
|
||||
# Update the last_shown timestamp
|
||||
existing_notification.last_shown = func.now()
|
||||
db_session.commit()
|
||||
# Update the last_shown timestamp if the notification is not dismissed
|
||||
if not existing_notification.dismissed:
|
||||
existing_notification.last_shown = func.now()
|
||||
if autocommit:
|
||||
db_session.commit()
|
||||
return existing_notification
|
||||
|
||||
# Create a new notification if none exists
|
||||
notification = Notification(
|
||||
user_id=user_id,
|
||||
notif_type=notif_type,
|
||||
title=title,
|
||||
description=description,
|
||||
dismissed=False,
|
||||
last_shown=func.now(),
|
||||
first_shown=func.now(),
|
||||
additional_data=additional_data,
|
||||
)
|
||||
db_session.add(notification)
|
||||
db_session.commit()
|
||||
if autocommit:
|
||||
db_session.commit()
|
||||
return notification
|
||||
|
||||
|
||||
@@ -77,6 +97,11 @@ def get_notifications(
|
||||
query = query.where(Notification.dismissed.is_(False))
|
||||
if notif_type:
|
||||
query = query.where(Notification.notif_type == notif_type)
|
||||
# Sort: undismissed first, then by date (newest first)
|
||||
query = query.order_by(
|
||||
Notification.dismissed.asc(),
|
||||
Notification.first_shown.desc(),
|
||||
)
|
||||
return list(db_session.execute(query).scalars().all())
|
||||
|
||||
|
||||
@@ -95,6 +120,63 @@ def dismiss_notification(notification: Notification, db_session: Session) -> Non
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def batch_dismiss_notifications(
|
||||
notifications: list[Notification],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for notification in notifications:
|
||||
notification.dismissed = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def batch_create_notifications(
|
||||
user_ids: list[UUID],
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Create notifications for multiple users in a single batch operation.
|
||||
Uses ON CONFLICT DO NOTHING for atomic idempotent inserts - if a user already
|
||||
has a notification with the same (user_id, notif_type, additional_data), the
|
||||
insert is silently skipped.
|
||||
|
||||
Returns the number of notifications created.
|
||||
|
||||
Relies on unique index on (user_id, notif_type, COALESCE(additional_data, '{}'))
|
||||
"""
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Use empty dict instead of None to match COALESCE behavior in the unique index
|
||||
additional_data_normalized = additional_data if additional_data is not None else {}
|
||||
|
||||
values = [
|
||||
{
|
||||
"user_id": uid,
|
||||
"notif_type": notif_type.value,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"dismissed": False,
|
||||
"last_shown": now,
|
||||
"first_shown": now,
|
||||
"additional_data": additional_data_normalized,
|
||||
}
|
||||
for uid in user_ids
|
||||
]
|
||||
|
||||
stmt = insert(Notification).values(values).on_conflict_do_nothing()
|
||||
result = db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
# rowcount returns number of rows inserted (excludes conflicts)
|
||||
# CursorResult has rowcount but session.execute type hints are too broad
|
||||
return result.rowcount if result.rowcount >= 0 else 0 # type: ignore[attr-defined]
|
||||
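Because the insert runs ON CONFLICT DO NOTHING against the unique index on (user_id, notif_type, additional_data), the helper above is idempotent. A usage sketch (the session, title, and IDs are illustrative; assumes the imports already present in this module):

from uuid import uuid4

# Assumes an open `db_session` plus the imports already used in this module.
user_ids = [uuid4(), uuid4()]

created = batch_create_notifications(
    user_ids=user_ids,
    notif_type=NotificationType.RELEASE_NOTES,
    db_session=db_session,
    title="Onyx v2.9.9 is out",
    description="Check out what's new in v2.9.9",
    additional_data={"version": "v2.9.9"},
)
# The first call inserts one row per user; an identical second call conflicts
# on (user_id, notif_type, additional_data) and reports 0 new notifications.
assert 0 <= created <= len(user_ids)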
|
||||
|
||||
def update_notification_last_shown(
|
||||
notification: Notification, db_session: Session
|
||||
) -> None:
|
||||
|
||||
@@ -187,13 +187,25 @@ def _get_persona_by_name(
|
||||
return result
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
def update_persona_access(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Updates the access settings for a persona including public status and user shares.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
@@ -205,17 +217,22 @@ def make_persona_private(
|
||||
create_notification(
|
||||
user_id=user_uuid,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
title="A new agent was shared with you!",
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
|
||||
# there shouldn't be any) but raise an error if trying to add actual groups.
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
# May cause error if someone switches down to MIT from EE
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support private Personas")
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support group-based sharing")
|
||||
|
||||
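A short usage sketch of the three sharing modes described above (`None` = leave unchanged, `[]` = clear all shares, non-empty list = replace). The persona ID, user ID, and session are illustrative:

from uuid import UUID

# Assumes an open `db_session`; per the docstring, callers commit afterwards.
teammate_id = UUID("00000000-0000-0000-0000-000000000001")

# Share with one user; leave group shares and public status untouched.
update_persona_access(
    persona_id=42,
    creator_user_id=None,
    db_session=db_session,
    user_ids=[teammate_id],
    group_ids=None,
)

# Clear all user shares and make the persona public instead.
update_persona_access(
    persona_id=42,
    creator_user_id=None,
    db_session=db_session,
    is_public=True,
    user_ids=[],
)
db_session.commit()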
|
||||
def create_update_persona(
|
||||
@@ -281,20 +298,21 @@ def create_update_persona(
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
)
|
||||
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
versioned_update_persona_access(
|
||||
persona_id=persona.id,
|
||||
creator_user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
user_ids=create_persona_request.users,
|
||||
group_ids=create_persona_request.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
@@ -303,11 +321,13 @@ def create_update_persona(
|
||||
return FullPersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
def update_persona_shared_users(
|
||||
def update_persona_shared(
|
||||
persona_id: int,
|
||||
user_ids: list[UUID],
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -316,22 +336,25 @@ def update_persona_shared_users(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if persona.is_public:
|
||||
raise HTTPException(status_code=400, detail="Cannot share public persona")
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
)
|
||||
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
versioned_update_persona_access(
|
||||
persona_id=persona_id,
|
||||
creator_user_id=user.id if user else None,
|
||||
user_ids=user_ids,
|
||||
group_ids=None,
|
||||
db_session=db_session,
|
||||
is_public=is_public,
|
||||
user_ids=user_ids,
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_public_status(
|
||||
persona_id: int,
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserProject
|
||||
from onyx.server.documents.connector import upload_files
|
||||
from onyx.server.features.projects.projects_file_utils import categorize_uploaded_files
|
||||
from onyx.server.features.projects.projects_file_utils import RejectedFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -29,8 +30,7 @@ logger = setup_logger()
|
||||
|
||||
class CategorizedFilesResult(BaseModel):
|
||||
user_files: list[UserFile]
|
||||
non_accepted_files: list[str]
|
||||
unsupported_files: list[str]
|
||||
rejected_files: list[RejectedFile]
|
||||
id_to_temp_id: dict[str, str]
|
||||
# Allow SQLAlchemy ORM models inside this result container
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
@@ -56,8 +56,7 @@ def create_user_files(
|
||||
# Should revisit to decide whether this should be a feature.
|
||||
upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE)
|
||||
user_files = []
|
||||
non_accepted_files = categorized_files.non_accepted
|
||||
unsupported_files = categorized_files.unsupported
|
||||
rejected_files = categorized_files.rejected
|
||||
id_to_temp_id: dict[str, str] = {}
|
||||
# Pair returned storage paths with the same set of acceptable files we uploaded
|
||||
for file_path, file in zip(
|
||||
@@ -73,7 +72,6 @@ def create_user_files(
|
||||
id=new_id,
|
||||
user_id=user.id if user else None,
|
||||
file_id=file_path,
|
||||
document_id=str(new_id),
|
||||
name=file.filename,
|
||||
token_count=categorized_files.acceptable_file_to_token_count[
|
||||
file.filename or ""
|
||||
@@ -96,8 +94,7 @@ def create_user_files(
|
||||
db_session.commit()
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
non_accepted_files=non_accepted_files,
|
||||
unsupported_files=unsupported_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
)
|
||||
|
||||
@@ -122,17 +119,14 @@ def upload_files_to_user_files_with_indexing(
|
||||
temp_id_map=temp_id_map,
|
||||
)
|
||||
user_files = categorized_files_result.user_files
|
||||
non_accepted_files = categorized_files_result.non_accepted_files
|
||||
unsupported_files = categorized_files_result.unsupported_files
|
||||
rejected_files = categorized_files_result.rejected_files
|
||||
id_to_temp_id = categorized_files_result.id_to_temp_id
|
||||
# Trigger per-file processing immediately for the current tenant
|
||||
tenant_id = get_current_tenant_id()
|
||||
if non_accepted_files:
|
||||
for filename in non_accepted_files:
|
||||
logger.warning(f"Non-accepted file: {filename}")
|
||||
if unsupported_files:
|
||||
for filename in unsupported_files:
|
||||
logger.warning(f"Unsupported file: {filename}")
|
||||
for rejected_file in rejected_files:
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
@@ -146,8 +140,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
non_accepted_files=non_accepted_files,
|
||||
unsupported_files=unsupported_files,
|
||||
rejected_files=rejected_files,
|
||||
id_to_temp_id=id_to_temp_id,
|
||||
)
|
||||
|
||||
|
||||
94
backend/onyx/db/release_notes.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Database functions for release notes functionality."""
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
from onyx.server.features.release_notes.models import ReleaseNoteEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_release_notifications_for_versions(
|
||||
db_session: Session,
|
||||
release_note_entries: list[ReleaseNoteEntry],
|
||||
) -> int:
|
||||
"""
|
||||
Create release notes notifications for each release note entry.
|
||||
Uses batch_create_notifications for efficient bulk insertion.
|
||||
|
||||
If a user already has a notification for a specific version (dismissed or not),
|
||||
no new one is created (handled by unique constraint on additional_data).
|
||||
|
||||
Note: Entries should already be filtered by app_version before calling this
|
||||
function. The filtering happens in _parse_mdx_to_release_note_entries().
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
release_note_entries: List of release note entries to notify about (pre-filtered)
|
||||
|
||||
Returns:
|
||||
Total number of notifications created across all versions.
|
||||
"""
|
||||
if not release_note_entries:
|
||||
logger.debug("No release note entries to notify about")
|
||||
return 0
|
||||
|
||||
# Get active users and exclude API key users
|
||||
user_ids = list(
|
||||
db_session.scalars(
|
||||
select(User.id).where( # type: ignore
|
||||
User.is_active == True, # noqa: E712
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
|
||||
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
|
||||
)
|
||||
).all()
|
||||
)
|
||||
|
||||
total_created = 0
|
||||
for entry in release_note_entries:
|
||||
# Convert version to anchor format for external docs links
|
||||
# v2.7.0 -> v2-7-0
|
||||
version_anchor = entry.version.replace(".", "-")
|
||||
|
||||
# Build UTM parameters for tracking
|
||||
utm_params = {
|
||||
"utm_source": ONYX_UTM_SOURCE,
|
||||
"utm_medium": "notification",
|
||||
"utm_campaign": INSTANCE_TYPE,
|
||||
"utm_content": f"release_notes-{entry.version}",
|
||||
}
|
||||
|
||||
link = f"{DOCS_CHANGELOG_BASE_URL}#{version_anchor}?{urlencode(utm_params)}"
|
||||
|
||||
additional_data: dict[str, str] = {
|
||||
"version": entry.version,
|
||||
"link": link,
|
||||
}
|
||||
|
||||
created_count = batch_create_notifications(
|
||||
user_ids,
|
||||
NotificationType.RELEASE_NOTES,
|
||||
db_session,
|
||||
title=entry.title,
|
||||
description=f"Check out what's new in {entry.version}",
|
||||
additional_data=additional_data,
|
||||
)
|
||||
total_created += created_count
|
||||
|
||||
logger.debug(
|
||||
f"Created {created_count} release notes notifications "
|
||||
f"(version {entry.version}, {len(user_ids)} eligible users)"
|
||||
)
|
||||
|
||||
return total_created
|
||||
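The notification link built above combines a docs changelog anchor with UTM tracking parameters. A small sketch of that string construction with stand-in values (the real base URL and UTM constants come from the imported configuration; the query string is appended after the anchor, mirroring the code above):

from urllib.parse import urlencode

# Stand-in values; the real ones come from DOCS_CHANGELOG_BASE_URL,
# ONYX_UTM_SOURCE, and INSTANCE_TYPE.
base_url = "https://docs.example.com/changelog"
version = "v2.7.0"

version_anchor = version.replace(".", "-")  # "v2-7-0"
utm_params = {
    "utm_source": "onyx",
    "utm_medium": "notification",
    "utm_campaign": "self_hosted",
    "utm_content": f"release_notes-{version}",
}
link = f"{base_url}#{version_anchor}?{urlencode(utm_params)}"
print(link)
# https://docs.example.com/changelog#v2-7-0?utm_source=onyx&utm_medium=notification&...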
@@ -126,7 +126,7 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
did change.
|
||||
"""
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session, include_user_files=True)
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
new_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def create_or_add_document_tag(
|
||||
is_list=False,
|
||||
)
|
||||
insert_stmt = insert_stmt.on_conflict_do_nothing(
|
||||
index_elements=["tag_key", "tag_value", "source", "is_list"]
|
||||
constraint="_tag_key_value_source_list_uc"
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
|
||||
@@ -98,7 +98,7 @@ def create_or_add_document_tag_list(
|
||||
is_list=True,
|
||||
)
|
||||
insert_stmt = insert_stmt.on_conflict_do_nothing(
|
||||
index_elements=["tag_key", "tag_value", "source", "is_list"]
|
||||
constraint="_tag_key_value_source_list_uc"
|
||||
)
|
||||
db_session.execute(insert_stmt)
|
||||
|
||||
|
||||
@@ -113,7 +113,6 @@ def upsert_web_search_provider(
|
||||
if activate:
|
||||
set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
@@ -269,7 +268,6 @@ def upsert_web_content_provider(
|
||||
if activate:
|
||||
set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
|
||||
@@ -150,6 +149,9 @@ def generate_final_report(
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
# Save citation mapping to state_container so citations are persisted
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
final_report = llm_step_result.answer
|
||||
if final_report is None:
|
||||
raise ValueError("LLM failed to generate the final deep research report")
|
||||
@@ -217,35 +219,90 @@ def run_deep_research_llm_loop(
|
||||
else ""
|
||||
)
|
||||
if not skip_clarification:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
internal_search_clarification_guidance=internal_search_clarification_guidance,
|
||||
)
|
||||
with function_span("clarification_step") as span:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
internal_search_clarification_guidance=internal_search_clarification_guidance,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
llm_step_result, _ = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
span.span_data.output = "clarification_required"
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop"),
|
||||
)
|
||||
)
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
with function_span("research_plan_step") as span:
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
)
|
||||
|
||||
llm_step_result, _ = run_llm_step(
|
||||
emitter=emitter,
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
@@ -253,301 +310,177 @@ def run_deep_research_llm_loop(
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0), obj=OverallStop(type="stop")
|
||||
)
|
||||
)
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
# The LLM response from this prompt is the research plan
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
# The LLM response from this prompt is the research plan
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, reasoned = e.value
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanStart(),
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(
|
||||
turn_index=1 if reasoned else 0,
|
||||
),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, reasoned = e.value
|
||||
emitter.emit(
|
||||
Packet(
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(
|
||||
turn_index=1 if reasoned else 0,
|
||||
),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
if reasoned:
|
||||
orchestrator_start_turn_index += 1
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
if reasoned:
|
||||
orchestrator_start_turn_index += 1
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
research_plan = llm_step_result.answer
|
||||
research_plan = llm_step_result.answer
|
||||
span.span_data.output = research_plan if research_plan else None
|
||||
|
||||
#########################################################
|
||||
# RESEARCH EXECUTION STEP
|
||||
#########################################################
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
with function_span("research_execution_step") as span:
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
max_orchestrator_cycles = (
|
||||
MAX_ORCHESTRATOR_CYCLES
|
||||
if not is_reasoning_model
|
||||
else MAX_ORCHESTRATOR_CYCLES_REASONING
|
||||
)
|
||||
max_orchestrator_cycles = (
|
||||
MAX_ORCHESTRATOR_CYCLES
|
||||
if not is_reasoning_model
|
||||
else MAX_ORCHESTRATOR_CYCLES_REASONING
|
||||
)
|
||||
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT
|
||||
if not is_reasoning_model
|
||||
else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT
|
||||
if not is_reasoning_model
|
||||
else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
|
||||
internal_search_research_task_guidance = (
|
||||
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
|
||||
if include_internal_search_tunings
|
||||
else ""
|
||||
)
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=1,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
reasoning_cycles = 0
|
||||
most_recent_reasoning: str | None = None
|
||||
citation_mapping: CitationMapping = {}
|
||||
final_turn_index: int = (
|
||||
orchestrator_start_turn_index # Track the final turn_index for stop packet
|
||||
)
|
||||
for cycle in range(max_orchestrator_cycles):
|
||||
if cycle == max_orchestrator_cycles - 1:
|
||||
# If it's the last cycle, forcibly generate the final report
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
internal_search_research_task_guidance = (
|
||||
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
|
||||
if include_internal_search_tunings
|
||||
else ""
|
||||
)
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
current_cycle_count=1,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
reasoning_cycles = 0
|
||||
most_recent_reasoning: str | None = None
|
||||
citation_mapping: CitationMapping = {}
|
||||
final_turn_index: int = (
|
||||
orchestrator_start_turn_index # Track the final turn_index for stop packet
|
||||
)
|
||||
for cycle in range(max_orchestrator_cycles):
|
||||
if cycle == max_orchestrator_cycles - 1:
|
||||
# If it's the last cycle, forcibly generate the final report
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
create_think_tool_token_processor() if not is_reasoning_model else None
|
||||
)
|
||||
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_orchestrator_tools(
|
||||
include_think_tool=not is_reasoning_model
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
placement=Placement(
|
||||
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
custom_token_processor=custom_processor,
|
||||
is_deep_research=True,
|
||||
)
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if not tool_calls and cycle == 0:
|
||||
raise RuntimeError(
|
||||
"Deep Research failed to generate any research tasks for the agents."
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurrence and hopefully multiple research
# cycles have already run
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
create_think_tool_token_processor()
|
||||
if not is_reasoning_model
|
||||
else None
|
||||
)
|
||||
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
report_turn_index = (
|
||||
special_tool_calls.generate_report_tool_call.placement.turn_index
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_orchestrator_tools(
|
||||
include_think_tool=not is_reasoning_model
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
placement=Placement(
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles
|
||||
),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
custom_token_processor=custom_processor,
|
||||
is_deep_research=True,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
elif special_tool_calls.think_tool_call:
|
||||
think_tool_call = special_tool_calls.think_tool_call
|
||||
# Only process the THINK_TOOL and skip all other tool calls
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
with function_span("think_tool") as span:
|
||||
span.span_data.input = str(think_tool_call.tool_args)
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
think_tool_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=token_counter(tool_call_message),
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if not tool_calls and cycle == 0:
|
||||
raise RuntimeError(
|
||||
"Deep Research failed to generate any research tasks for the agents."
|
||||
)
|
||||
simple_chat_history.append(think_tool_msg)
|
||||
|
||||
think_tool_response_msg = ChatMessageSimple(
|
||||
message=THINK_TOOL_RESPONSE_MESSAGE,
|
||||
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
|
||||
logger.warning(f"Unexpected tool call: {tool_call.tool_name}")
|
||||
continue
|
||||
|
||||
research_agent_calls.append(tool_call)
|
||||
|
||||
if not research_agent_calls:
|
||||
logger.warning(
|
||||
"No research agent tool calls found, this should not happen."
|
||||
)
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurrence and hopefully multiple research
# cycles have already run
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
@@ -564,91 +497,177 @@ def run_deep_research_llm_loop(
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
if len(research_agent_calls) > 1:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(
|
||||
turn_index=research_agent_calls[0].placement.turn_index
|
||||
),
|
||||
obj=TopLevelBranching(
|
||||
num_parallel_branches=len(research_agent_calls)
|
||||
),
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
report_turn_index = (
|
||||
special_tool_calls.generate_report_tool_call.placement.turn_index
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
elif special_tool_calls.think_tool_call:
|
||||
think_tool_call = special_tool_calls.think_tool_call
|
||||
# Only process the THINK_TOOL and skip all other tool calls
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
with function_span("think_tool") as span:
|
||||
span.span_data.input = str(think_tool_call.tool_args)
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
|
||||
think_tool_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=token_counter(tool_call_message),
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
)
|
||||
simple_chat_history.append(think_tool_msg)
|
||||
|
||||
research_results = run_research_agent_calls(
|
||||
# The tool calls here contain the placement information
|
||||
research_agent_calls=research_agent_calls,
|
||||
parent_tool_call_ids=[
|
||||
tool_call.tool_call_id for tool_call in tool_calls
|
||||
],
|
||||
tools=allowed_tools,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
llm=llm,
|
||||
is_reasoning_model=is_reasoning_model,
|
||||
token_counter=token_counter,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
citation_mapping = research_results.citation_mapping
|
||||
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
if report is None:
|
||||
# The LLM will not see that this research was even attempted, it may try
|
||||
# something similar again but this is not bad.
|
||||
logger.error(
|
||||
f"Research agent call at tab_index {tab_index} failed, skipping"
|
||||
think_tool_response_msg = ChatMessageSimple(
|
||||
message=THINK_TOOL_RESPONSE_MESSAGE,
|
||||
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
continue
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
|
||||
logger.warning(
|
||||
f"Unexpected tool call: {tool_call.tool_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
current_tool_call = research_agent_calls[tab_index]
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles,
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_DB_NAME, db_session=db_session
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
tool_call_response=report,
|
||||
search_docs=None, # Intermediate docs are not saved/shown
|
||||
generated_images=None,
    research_agent_calls.append(tool_call)

if not research_agent_calls:
    logger.warning(
        "No research agent tool calls found, this should not happen."
    )
    report_turn_index = (
        orchestrator_start_turn_index + cycle + reasoning_cycles
    )
    report_reasoned = generate_final_report(
        history=simple_chat_history,
        llm=llm,
        token_counter=token_counter,
        state_container=state_container,
        emitter=emitter,
        turn_index=report_turn_index,
        citation_mapping=citation_mapping,
        user_identity=user_identity,
    )
    final_turn_index = report_turn_index + (
        1 if report_reasoned else 0
    )
    break

if len(research_agent_calls) > 1:
    emitter.emit(
        Packet(
            placement=Placement(
                turn_index=research_agent_calls[
                    0
                ].placement.turn_index
            ),
            obj=TopLevelBranching(
                num_parallel_branches=len(research_agent_calls)
            ),
        )
    )

research_results = run_research_agent_calls(
    # The tool calls here contain the placement information
    research_agent_calls=research_agent_calls,
    parent_tool_call_ids=[
        tool_call.tool_call_id for tool_call in tool_calls
    ],
    tools=allowed_tools,
    emitter=emitter,
    state_container=state_container,
    llm=llm,
    is_reasoning_model=is_reasoning_model,
    token_counter=token_counter,
    citation_mapping=citation_mapping,
    user_identity=user_identity,
)
state_container.add_tool_call(tool_call_info)

tool_call_message = current_tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
citation_mapping = research_results.citation_mapping

tool_call_msg = ChatMessageSimple(
    message=tool_call_message,
    token_count=tool_call_token_count,
    message_type=MessageType.TOOL_CALL,
    tool_call_id=current_tool_call.tool_call_id,
    image_files=None,
)
simple_chat_history.append(tool_call_msg)
for tab_index, report in enumerate(
    research_results.intermediate_reports
):
    if report is None:
        # The LLM will not see that this research was even attempted, it may try
        # something similar again but this is not bad.
        logger.error(
            f"Research agent call at tab_index {tab_index} failed, skipping"
        )
        continue

    tool_call_response_msg = ChatMessageSimple(
        message=report,
        token_count=token_counter(report),
        message_type=MessageType.TOOL_CALL_RESPONSE,
        tool_call_id=current_tool_call.tool_call_id,
        image_files=None,
    )
    simple_chat_history.append(tool_call_response_msg)
    current_tool_call = research_agent_calls[tab_index]
    tool_call_info = ToolCallInfo(
        parent_tool_call_id=None,
        turn_index=orchestrator_start_turn_index
        + cycle
        + reasoning_cycles,
        tab_index=tab_index,
        tool_name=current_tool_call.tool_name,
        tool_call_id=current_tool_call.tool_call_id,
        tool_id=get_tool_by_name(
            tool_name=RESEARCH_AGENT_TOOL_NAME,
            db_session=db_session,
        ).id,
        reasoning_tokens=llm_step_result.reasoning
        or most_recent_reasoning,
        tool_call_arguments=current_tool_call.tool_args,
        tool_call_response=report,
        search_docs=None,  # Intermediate docs are not saved/shown
        generated_images=None,
    )
    state_container.add_tool_call(tool_call_info)

    # If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
    most_recent_reasoning = None
    tool_call_message = current_tool_call.to_msg_str()
    tool_call_token_count = token_counter(tool_call_message)

    tool_call_msg = ChatMessageSimple(
        message=tool_call_message,
        token_count=tool_call_token_count,
        message_type=MessageType.TOOL_CALL,
        tool_call_id=current_tool_call.tool_call_id,
        image_files=None,
    )
    simple_chat_history.append(tool_call_msg)

    tool_call_response_msg = ChatMessageSimple(
        message=report,
        token_count=token_counter(report),
        message_type=MessageType.TOOL_CALL_RESPONSE,
        tool_call_id=current_tool_call.tool_call_id,
        image_files=None,
    )
    simple_chat_history.append(tool_call_response_msg)

# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
most_recent_reasoning = None

emitter.emit(
    Packet(
@@ -1,6 +1,6 @@
GENERATE_PLAN_TOOL_NAME = "generate_plan"

RESEARCH_AGENT_DB_NAME = "ResearchAgent"
RESEARCH_AGENT_IN_CODE_ID = "ResearchAgent"
RESEARCH_AGENT_TOOL_NAME = "research_agent"
RESEARCH_AGENT_TASK_KEY = "task"

@@ -109,10 +109,6 @@ class VespaDocumentFields:
    hidden: bool | None = None
    aggregated_chunk_boost_factor: float | None = None

    # document_id is added for migration purposes, ideally we should not be updating this field
    # TODO(subash): remove this field in a future migration
    document_id: str | None = None


@dataclass
class VespaDocumentUserFields:

@@ -282,10 +282,6 @@ class Updatable(abc.ABC):
    def update(
        self,
        update_requests: list[MetadataUpdateRequest],
        # TODO(andrei), WARNING: Very temporary, this is not the interface we want
        # in Updatable, we only have this to continue supporting
        # user_file_docid_migration_task for Vespa which should be done soon.
        old_doc_id_to_new_doc_id: dict[str, str],
    ) -> None:
        """
        Updates some set of chunks. The document and fields to update are specified in the update
backend/onyx/document_index/opensearch/README.md (new file, 62 lines)
@@ -0,0 +1,62 @@
# Opensearch Idiosyncrasies

## How it works at a high level
Opensearch has 2 phases, a `Search` phase and a `Fetch` phase. The `Search` phase works by getting the document scores on each
shard separately, then typically a fetch phase grabs all of the relevant fields/data for returning to the user. There is also
an intermediate phase (seemingly built specifically to handle hybrid search queries) which can run in between as a processor.
References:
https://docs.opensearch.org/latest/search-plugins/search-pipelines/search-processors/
https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/

## How Hybrid queries work
Hybrid queries are essentially parallel queries that each run through their own `Search` phase and do not interact in any way.
They also run across all the shards. It is not entirely clear what happens if a combination pipeline is not specified for them;
perhaps the scores are just summed.

When the normalization processor is applied to keyword/vector hybrid searches, documents that show up due to a keyword match may
not have also shown up in the vector search, and vice versa. In these situations, the document simply receives a 0 score for the
missing query component. Opensearch does not run another phase to recapture those missing values. The impact of this is that after
normalizing, the missing score is 0, which is a higher score than the document would have had if it had actually been scored.

This may not be immediately obvious, so an explanation is included here. If the document had received a non-zero score instead,
that score must be lower than all of the other scores in the list (otherwise it would have shown up). It would therefore affect
the normalization and push the other scores higher, so that it is not only still the lowest score, but now a clearly differentiated
lowest score. This is not strictly the case in a multi-node setup, but the high level concept approximately holds. So basically
the 0 score is a form of "minimum value clipping".

## On time decay and boosting
Embedding models do not have a uniform score distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8, and
the range also varies between models and even between queries. It is not a safe assumption to pre-normalize the scores, so we also
cannot apply any additive or multiplicative boost to them. I.e. if the results for a doc cluster around 0.6 to 0.8 and we give a 50%
penalty to the score, it doesn't bring a result from the top of the range to the 50th percentile; it brings it under 0.6 and it is
now the worst match. The same logic applies to additive boosting.

So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
and only applies to the results of the completely independent `Search` phase queries. So if a time based boost (a separate
query which filters on recently updated documents) is added, it would not be able to introduce any new documents
to the set (since the new documents would have no keyword/vector score or would already be present), because the 0 scores on keyword
and vector would make docs that only came from the time filter score very low. It can, however, make some of the lower scored
documents from the union of all the `Search` phase documents show up higher and potentially not get dropped before
being fetched and returned to the user. But there are other issues with including these:
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
  contents. If there are lots of recent updates, the filter may miss documents.
- There is not a good way to normalize this field; the best option is to clip it at the bottom.
- This would require using min-max norm, but z-score norm is better for the other functions: it is less sensitive to outliers,
  handles distribution drift better (min-max assumes stable, meaningful ranges), and is better for comparing "unusual-ness"
  across distributions.

So while it is possible to apply time based boosting at the normalization stage (or specifically to the keyword score), we have
decided it is better to not apply it during the OpenSearch query.

Because of these limitations, Onyx applies further refinements, boosts, etc. in code, with OpenSearch providing an initial
filtering. The impact of time decay and boost should not be so big that we would need orders of magnitude more results back
from OpenSearch.

## Other concepts to be aware of
Within the `Search` phase, there are optional steps like Rescore, but these are not useful for the combination/normalization
work that is relevant for the hybrid search. Since the Rescore happens prior to normalization, it is not able to provide any
meaningful operations to the query for our usage.

Because the Title is included in the Contents for both embedding and keyword searches, the Title scores are very low relative to
the actual full contents scoring. Title matching is seen as a boost rather than a core scoring component. Time decay works similarly.
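To make the hybrid-search and normalization discussion above concrete, here is a rough sketch of registering a min-max normalization search pipeline and issuing a hybrid keyword/kNN query with the opensearch-py client. The pipeline name happens to mirror a constant defined later in this compare, but the index name, field names, weights, and client setup are assumptions rather than the exact Onyx configuration:

```python
from opensearchpy import OpenSearch  # assumed client wiring; Onyx configures this elsewhere

client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])

# Register a search pipeline whose phase-results processor normalizes the
# per-subquery scores before combining them (see the normalization-processor docs).
client.transport.perform_request(
    "PUT",
    "/_search/pipeline/normalization_pipeline_min_max",
    body={
        "description": "Normalize keyword and vector scores using min-max",
        "phase_results_processors": [
            {
                "normalization-processor": {
                    "normalization": {"technique": "min_max"},
                    "combination": {"technique": "arithmetic_mean"},
                }
            }
        ],
    },
)

# A hybrid query: the keyword and kNN subqueries run through independent Search
# phases; a document missing from one subquery gets a 0 for that component.
hybrid_body = {
    "query": {
        "hybrid": {
            "queries": [
                {"match": {"content": "quarterly revenue report"}},
                {"knn": {"content_vector": {"vector": [0.1] * 768, "k": 50}}},
            ]
        }
    }
}

results = client.search(
    index="example_chunks",  # hypothetical index name
    body=hybrid_body,
    params={"search_pipeline": "normalization_pipeline_min_max"},
)
```

A z-score pipeline would look the same apart from the normalization technique; the trade-offs between the two are exactly the ones described in the boosting section above.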
@@ -1,6 +1,11 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -42,6 +47,7 @@ from onyx.document_index.opensearch.search import (
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
@@ -53,64 +59,39 @@ def _convert_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
) -> InferenceChunkUncleaned:
|
||||
return InferenceChunkUncleaned(
|
||||
chunk_id=chunk.chunk_index,
|
||||
# TODO(andrei): Do this in a followup. This is the top of the doc, we
|
||||
# use it in the UI so it's needed for now but we should show match
|
||||
# highlights instead really.
|
||||
blurb="",
|
||||
blurb=chunk.blurb,
|
||||
content=chunk.content,
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document, just do this in a
|
||||
# followup.
|
||||
source_links=None,
|
||||
image_file_id=chunk.image_file_name,
|
||||
# TODO(andrei) Yuhong says he doesn't think we need that anymore. Used
|
||||
# if a section needed to be split into diff chunks. A section is a part
|
||||
# of a doc that a link will take you to. But don't chunks have their own
|
||||
# links? Look at this in a followup.
|
||||
source_links=json.loads(chunk.source_links) if chunk.source_links else None,
|
||||
image_file_id=chunk.image_file_id,
|
||||
# Deprecated. Fill in some reasonable default.
|
||||
section_continuation=False,
|
||||
document_id=chunk.document_id,
|
||||
source_type=DocumentSource(chunk.source_type),
|
||||
# TODO(andrei): Yuhong says this should never be None. I'll followup up
|
||||
# on that later.
|
||||
semantic_identifier=(
|
||||
chunk.semantic_identifier if chunk.semantic_identifier else ""
|
||||
),
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
title=chunk.title,
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document. Yuhong thinks OpenSearch
|
||||
# has some thing out of the box for this. Just need to look at it in a
|
||||
# followup.
|
||||
boost=1,
|
||||
# TODO(andrei): Do in a followup.
|
||||
boost=chunk.global_boost,
|
||||
# TODO(andrei): Do in a followup. We should be able to get this from
|
||||
# OpenSearch.
|
||||
recency_bias=1.0,
|
||||
# TODO(andrei): This is how good the match is, we need this, key insight
|
||||
# is we can order chunks by this. Should not be hard to plumb this from
|
||||
# a search result, do that in a followup.
|
||||
score=None,
|
||||
hidden=chunk.hidden,
|
||||
# TODO(andrei): Don't worry about these for now.
|
||||
# is_relevant
|
||||
# relevance_explanation
|
||||
# metadata
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document.
|
||||
metadata={},
|
||||
metadata=json.loads(chunk.metadata),
|
||||
# TODO(andrei): The vector DB needs to supply this. I vaguely know
|
||||
# OpenSearch can from the documentation I've seen till now, look at this
|
||||
# in a followup.
|
||||
match_highlights=[],
|
||||
# TODO(andrei) Summary of the entire doc, specifically if you enable
|
||||
# contextual retrieval. Look at this in a followup.
|
||||
doc_summary="",
|
||||
# TODO(andrei) Consider storing a chunk content index instead of a full
|
||||
# string when working on chunk content augmentation.
|
||||
doc_summary=chunk.doc_summary,
|
||||
# TODO(andrei) Same thing as contx ret above, LLM gens context for each
|
||||
# chunk.
|
||||
chunk_context="",
|
||||
chunk_context=chunk.chunk_context,
|
||||
updated_at=chunk.last_updated,
|
||||
# primary_owners TODO(andrei)
|
||||
# secondary_owners TODO(andrei)
|
||||
# large_chunk_reference_ids TODO(andrei): Don't worry about this one.
|
||||
# TODO(andrei): Slack is special.
|
||||
is_federated=False,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
# TODO(andrei): This is the suffix appended to the end of the chunk
|
||||
# content to assist querying. There are better ways we can do this, for
|
||||
# ex. keeping an index of where to string split from.
|
||||
@@ -135,49 +116,31 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
title_vector=chunk.title_embedding,
|
||||
content=chunk.content,
|
||||
content_vector=chunk.embeddings.full_embedding,
|
||||
# TODO(andrei): We should know this. Reason to have this is convenience,
|
||||
# but it could also change when you change your embedding model, maybe
|
||||
# we can remove it, Yuhong to look at this. Hardcoded to some nonsense
|
||||
# value for now.
|
||||
num_tokens=0,
|
||||
source_type=chunk.source_document.source.value,
|
||||
# TODO(andrei): This is just represented a bit differently in
|
||||
# DocumentBase than how we expect it in the schema currently. Look at
|
||||
# this closer in a followup. Always defaults to None for now.
|
||||
# metadata=chunk.source_document.metadata,
|
||||
metadata=json.dumps(chunk.source_document.metadata),
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
# TODO(andrei): Don't currently see an easy way of porting this, and
|
||||
# besides some connectors genuinely don't have this data. Look at this
|
||||
# closer in a followup. Always defaults to None for now.
|
||||
# created_at=None,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): Implement ACL in a followup, currently none of the
|
||||
# methods in OpenSearchDocumentIndex support it anyway. Always defaults
|
||||
# to None for now.
|
||||
# access_control_list=chunk.access.to_acl(),
|
||||
# TODO(andrei): Look at this in a followup. Always defaults to False for
|
||||
# now.
|
||||
hidden=False,
|
||||
# TODO(andrei): This doesn't work bc global_boost is float, presumably
|
||||
# between 0.0 and inf (check this) and chunk.boost is an int from -inf
|
||||
# to +inf. Look at how the scaling compares between these in a followup.
|
||||
# Always defaults to 1.0 for now.
|
||||
# global_boost=chunk.boost,
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
# TODO(andrei): Ask Chris more about this later. Always defaults to None
|
||||
# for now.
|
||||
# image_file_name=None,
|
||||
# TODO(andrei): Look at this in a followup. Don't think it'll be too
|
||||
# crazy just that source_document represents source links a bit
|
||||
# differently than we expect here. Always defaults to None for now.
|
||||
# source_links=chunk.source_document.source_links,
|
||||
image_file_id=chunk.image_file_id,
|
||||
source_links=json.dumps(chunk.source_links) if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
doc_summary=chunk.doc_summary,
|
||||
chunk_context=chunk.chunk_context,
|
||||
document_sets=list(chunk.document_sets) if chunk.document_sets else None,
|
||||
project_ids=list(chunk.user_project) if chunk.user_project else None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
secondary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.secondary_owners
|
||||
),
|
||||
# TODO(andrei): Consider not even getting this from
|
||||
# DocMetadataAwareIndexChunk and instead using OpenSearchDocumentIndex's
|
||||
# instance variable. One source of truth -> less chance of a very bad
|
||||
# bug in prod.
|
||||
tenant_id=chunk.tenant_id,
|
||||
tenant_id=TenantState(tenant_id=chunk.tenant_id, multitenant=MULTI_TENANT),
|
||||
)
|
||||
|
||||
|
||||
@@ -322,7 +285,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
),
|
||||
)
|
||||
|
||||
return self._real_index.update([update_request], old_doc_id_to_new_doc_id={})
|
||||
return self._real_index.update([update_request])
|
||||
|
||||
def update(
|
||||
self,
|
||||
@@ -530,10 +493,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def update(
|
||||
self,
|
||||
update_requests: list[MetadataUpdateRequest],
|
||||
# TODO(andrei), WARNING: Very temporary, this is not the interface we
|
||||
# want in Updatable, we only have this to continue supporting
|
||||
# user_file_docid_migration_task for Vespa which should be done soon.
|
||||
old_doc_id_to_new_doc_id: dict[str, str],
|
||||
) -> None:
|
||||
logger.info("[ANDREI]: Updating documents...")
|
||||
# TODO(andrei): This needs to be implemented. I explicitly do not raise
|
||||
|
||||
@@ -4,30 +4,35 @@ from typing import Any
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_serializer
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_serializer
|
||||
from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
from onyx.document_index.opensearch.constants import EF_SEARCH
|
||||
from onyx.document_index.opensearch.constants import M
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
TITLE_FIELD_NAME = "title"
|
||||
TITLE_VECTOR_FIELD_NAME = "title_vector"
|
||||
CONTENT_FIELD_NAME = "content"
|
||||
CONTENT_VECTOR_FIELD_NAME = "content_vector"
|
||||
NUM_TOKENS_FIELD_NAME = "num_tokens"
|
||||
SOURCE_TYPE_FIELD_NAME = "source_type"
|
||||
METADATA_FIELD_NAME = "metadata"
|
||||
LAST_UPDATED_FIELD_NAME = "last_updated"
|
||||
CREATED_AT_FIELD_NAME = "created_at"
|
||||
PUBLIC_FIELD_NAME = "public"
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME = "access_control_list"
|
||||
HIDDEN_FIELD_NAME = "hidden"
|
||||
GLOBAL_BOOST_FIELD_NAME = "global_boost"
|
||||
SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier"
|
||||
IMAGE_FILE_NAME_FIELD_NAME = "image_file_name"
|
||||
IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
PROJECT_IDS_FIELD_NAME = "project_ids"
|
||||
@@ -35,6 +40,11 @@ DOCUMENT_ID_FIELD_NAME = "document_id"
|
||||
CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
TENANT_ID_FIELD_NAME = "tenant_id"
|
||||
BLURB_FIELD_NAME = "blurb"
|
||||
DOC_SUMMARY_FIELD_NAME = "doc_summary"
|
||||
CHUNK_CONTEXT_FIELD_NAME = "chunk_context"
|
||||
PRIMARY_OWNERS_FIELD_NAME = "primary_owners"
|
||||
SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
|
||||
|
||||
|
||||
def get_opensearch_doc_chunk_id(
|
||||
@@ -51,12 +61,27 @@ def get_opensearch_doc_chunk_id(
|
||||
return f"{document_id}__{max_chunk_size}__{chunk_index}"
|
||||
|
||||
|
||||
def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
|
||||
if value.tzinfo is None:
|
||||
# astimezone will raise if value does not have a timezone set.
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Does appropriate time conversion if value was set in a different
|
||||
# timezone.
|
||||
value = value.astimezone(timezone.utc)
|
||||
return value
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
"""
|
||||
Represents a chunk of a document in the OpenSearch index.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
|
||||
WARNING: Relies on MULTI_TENANT which is global state. Also uses
|
||||
get_current_tenant_id. Generally relying on global state is bad, in this
|
||||
case we accept it because of the importance of validating tenant logic.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
@@ -74,38 +99,44 @@ class DocumentChunk(BaseModel):
|
||||
title_vector: list[float] | None = None
|
||||
content: str
|
||||
content_vector: list[float]
|
||||
# The actual number of tokens in the chunk.
|
||||
num_tokens: int
|
||||
|
||||
source_type: str
|
||||
# Application logic should store these strings the format key:::value.
|
||||
metadata: list[str] | None = None
|
||||
# Contains a string representation of a dict which maps string key to either
|
||||
# string value or list of string values.
|
||||
# TODO(andrei): When we augment content with metadata this can just be an
|
||||
# index pointer, and when we support metadata list that will just be a list
|
||||
# of strings.
|
||||
metadata: str
|
||||
# If it exists, time zone should always be UTC.
|
||||
last_updated: datetime | None = None
|
||||
created_at: datetime | None = None
|
||||
|
||||
public: bool = False
|
||||
access_control_list: list[str] | None = None
|
||||
public: bool
|
||||
access_control_list: list[str]
|
||||
# Defaults to False, currently gets written during update not index.
|
||||
hidden: bool = False
|
||||
|
||||
global_boost: float = 1.0
|
||||
global_boost: int
|
||||
|
||||
# TODO(andrei): Make this non-nullable in a followup.
|
||||
semantic_identifier: str | None = None
|
||||
image_file_name: str | None = None
|
||||
source_links: list[str] | None = None
|
||||
semantic_identifier: str
|
||||
image_file_id: str | None = None
|
||||
# Contains a string representation of a dict which maps offset into the raw
|
||||
# chunk text to the link corresponding to that point.
|
||||
source_links: str | None = None
|
||||
blurb: str
|
||||
doc_summary: str
|
||||
chunk_context: str
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
# User projects.
|
||||
project_ids: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
tenant_id: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_num_tokens_fits_within_max_chunk_size(self) -> Self:
|
||||
if self.num_tokens > self.max_chunk_size:
|
||||
raise ValueError(
|
||||
"Bug: Num tokens must be less than or equal to max chunk size."
|
||||
)
|
||||
return self
|
||||
tenant_id: TenantState = Field(
|
||||
default_factory=lambda: TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
@@ -116,25 +147,116 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
@field_serializer("last_updated", "created_at", mode="plain")
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(
|
||||
self, handler: SerializerFunctionWrapHandler
|
||||
) -> dict[str, object]:
|
||||
"""Invokes pydantic's serialization logic, then excludes Nones.
|
||||
|
||||
We do this because .model_dump(exclude_none=True) does not work after
|
||||
@field_serializer logic, so for some field serializers which return None
|
||||
and which we would like to exclude from the final dump, they would be
|
||||
included without this.
|
||||
|
||||
Args:
|
||||
handler: Callable from pydantic which takes the instance of the
|
||||
model as an argument and performs standard serialization.
|
||||
|
||||
Returns:
|
||||
The return of handler but with None items excluded.
|
||||
"""
|
||||
serialized: dict[str, object] = handler(self)
|
||||
serialized_exclude_none = {k: v for k, v in serialized.items() if v is not None}
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
self, value: datetime | None
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
# astimezone will raise if value does not have a timezone set.
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Does appropriate time conversion if value was set in a different
|
||||
# timezone.
|
||||
value = value.astimezone(timezone.utc)
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
The datetime returned will be in UTC.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
return value
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
|
||||
self, value: TenantState, handler: SerializerFunctionWrapHandler
|
||||
) -> str | None:
|
||||
"""
|
||||
Serializes tenant_state to the tenant str if multitenant, or None if
|
||||
not.
|
||||
|
||||
The idea is that in single tenant mode, the schema does not have a
|
||||
tenant_id field, so we don't want to supply it in our serialized
|
||||
DocumentChunk. This assumes the final serialized model excludes None
|
||||
fields, which serialize_model should enforce.
|
||||
"""
|
||||
if not value.multitenant:
|
||||
return None
|
||||
else:
|
||||
return value.tenant_id
|
||||
|
||||
@field_validator("tenant_id", mode="before")
|
||||
@classmethod
|
||||
def parse_tenant_id(cls, value: Any) -> TenantState:
|
||||
"""
|
||||
Generates a TenantState from OpenSearch's tenant_id if it exists, or
|
||||
generates a default state if it does not (implies we are in single
|
||||
tenant mode).
|
||||
"""
|
||||
if value is None:
|
||||
if MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: No tenant_id was supplied but multi-tenant mode is enabled."
|
||||
)
|
||||
return TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
elif isinstance(value, TenantState):
|
||||
if MULTI_TENANT != value.multitenant:
|
||||
raise ValueError(
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
|
||||
f"({value.multitenant}) does not match the program's current global tenancy state."
|
||||
)
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
|
||||
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
|
||||
)
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
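The serializer and validator pairs above are easier to follow in isolation. Below is a minimal, self-contained sketch, not the real schema, of the same two patterns: a wrap-mode model serializer that excludes None entries after the field serializers have run, and an epoch-millis round trip for a datetime field. The model and field names here are illustrative only:

```python
from datetime import datetime, timezone
from typing import Any

from pydantic import BaseModel, SerializerFunctionWrapHandler
from pydantic import field_serializer, field_validator, model_serializer


class ChunkSketch(BaseModel):
    content: str
    last_updated: datetime | None = None
    tenant_id: str | None = None  # pretend single-tenant docs omit this entirely

    @model_serializer(mode="wrap")
    def _drop_nones(self, handler: SerializerFunctionWrapHandler) -> dict[str, Any]:
        # handler(self) runs the normal serialization, including the field
        # serializers below; anything they turned into None is then excluded.
        return {k: v for k, v in handler(self).items() if v is not None}

    @field_serializer("last_updated", mode="plain")
    def _dt_to_millis(self, value: datetime | None) -> int | None:
        if value is None:
            return None
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return int(value.astimezone(timezone.utc).timestamp() * 1000)

    @field_validator("last_updated", mode="before")
    @classmethod
    def _millis_to_dt(cls, value: Any) -> datetime | None:
        if value is None or isinstance(value, datetime):
            return value
        return datetime.fromtimestamp(int(value) / 1000, tz=timezone.utc)


doc = ChunkSketch(content="hello", last_updated=1700000000000)
assert doc.last_updated == datetime.fromtimestamp(1700000000, tz=timezone.utc)
# tenant_id (None) is omitted from the dump, mirroring the single-tenant case above.
assert doc.model_dump() == {"content": "hello", "last_updated": 1700000000000}
```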
|
||||
|
||||
|
||||
class DocumentSchema:
|
||||
"""
|
||||
@@ -172,13 +294,19 @@ class DocumentSchema:
|
||||
OpenSearch client. The structure of this dictionary is
|
||||
determined by OpenSearch documentation.
|
||||
"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
# By default OpenSearch allows dynamically adding new properties
|
||||
# based on indexed documents. This is awful and we disable it here.
|
||||
# An exception will be raised if you try to index a new doc which
|
||||
# contains unexpected fields.
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
# values longer than 256 chars.
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
"keyword": {"type": "keyword", "ignore_above": 256}
|
||||
},
|
||||
},
|
||||
@@ -196,6 +324,8 @@ class DocumentSchema:
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
# TODO(andrei): This is a tensor in Vespa. Also look at feature
|
||||
# parity for these other method fields.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"type": "knn_vector",
|
||||
"dimension": vector_dimension,
|
||||
@@ -206,11 +336,10 @@ class DocumentSchema:
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
# Number of tokens in the chunk's content.
|
||||
NUM_TOKENS_FIELD_NAME: {"type": "integer", "store": True},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
# Application logic should store in the format key:::value.
|
||||
METADATA_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
@@ -218,13 +347,6 @@ class DocumentSchema:
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
},
|
||||
CREATED_AT_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
# For some reason date defaults to False, even though it
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
},
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
@@ -237,7 +359,7 @@ class DocumentSchema:
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "float"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
# doc in the UI and is not used for searching. Disabling these
|
||||
# features to increase perf.
|
||||
@@ -248,7 +370,7 @@ class DocumentSchema:
|
||||
"store": False,
|
||||
},
|
||||
# Same as above; used to display an image along with the doc.
|
||||
IMAGE_FILE_NAME_FIELD_NAME: {
|
||||
IMAGE_FILE_ID_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
@@ -261,15 +383,43 @@ class DocumentSchema:
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above; used to quickly summarize the doc in the UI.
|
||||
BLURB_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above.
|
||||
# TODO(andrei): If we want to search on this this needs to be
|
||||
# changed.
|
||||
DOC_SUMMARY_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above.
|
||||
# TODO(andrei): If we want to search on this this needs to be
|
||||
# changed.
|
||||
CHUNK_CONTEXT_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
PROJECT_IDS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if multitenant:
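As a rough illustration of how a strict-dynamic mapping like the one above could be applied when creating an index, here is a hedged sketch with opensearch-py; the index name, field subset, and client setup are assumptions rather than the Onyx schema:

```python
from opensearchpy import OpenSearch

client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])

client.indices.create(
    index="example_chunks",  # hypothetical index name
    body={
        "settings": {"index": {"knn": True}},  # required for knn_vector fields
        "mappings": {
            # Reject documents containing fields not declared below instead of
            # silently widening the schema.
            "dynamic": "strict",
            "properties": {
                "title": {"type": "text"},
                "content": {"type": "text"},
                "content_vector": {
                    "type": "knn_vector",
                    "dimension": 768,
                    "method": {"name": "hnsw", "space_type": "l2", "engine": "lucene"},
                },
                "hidden": {"type": "boolean"},
            },
        },
    },
)

# Indexing a doc with an unexpected field now raises a
# strict_dynamic_mapping_exception rather than mutating the mapping:
# client.index(index="example_chunks", body={"content": "x", "surprise_field": 1})
```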
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using min-max",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
@@ -49,7 +49,7 @@ MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
}
|
||||
|
||||
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using z-score",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
@@ -140,7 +140,7 @@ class DocumentQuery:
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.tenant_id is not None:
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
@@ -199,7 +199,7 @@ class DocumentQuery:
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.tenant_id is not None:
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
@@ -316,6 +316,7 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
"type": "best_fields",
|
||||
}
|
||||
@@ -340,7 +341,7 @@ class DocumentQuery:
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
if tenant_state.tenant_id is not None:
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
@@ -686,12 +686,7 @@ class VespaIndex(DocumentIndex):
|
||||
project_ids=project_ids,
|
||||
)
|
||||
|
||||
old_doc_id_to_new_doc_id: dict[str, str] = dict()
|
||||
if fields is not None and fields.document_id is not None:
|
||||
old_doc_id_to_new_doc_id[doc_id] = fields.document_id
|
||||
vespa_document_index.update(
|
||||
[update_request], old_doc_id_to_new_doc_id=old_doc_id_to_new_doc_id
|
||||
)
|
||||
vespa_document_index.update([update_request])
|
||||
|
||||
def delete_single(
|
||||
self,
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
from onyx.document_index.vespa.shared_utils.utils import remove_invalid_unicode_chars
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
@@ -56,6 +55,7 @@ from onyx.document_index.vespa_constants import TITLE_EMBEDDING
|
||||
from onyx.document_index.vespa_constants import USER_PROJECT
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import remove_invalid_unicode_chars
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
import time
|
||||
from typing import cast
|
||||
|
||||
@@ -53,15 +52,6 @@ def replace_invalid_doc_id_characters(text: str) -> str:
|
||||
return text.replace("'", "_")
|
||||
|
||||
|
||||
def remove_invalid_unicode_chars(text: str) -> str:
|
||||
"""Vespa does not take in unicode chars that aren't valid for XML.
|
||||
This removes them."""
|
||||
_illegal_xml_chars_RE: re.Pattern = re.compile(
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1f\ud800-\udfff\ufdd0-\ufdef\ufffe\uffff]"
|
||||
)
|
||||
return _illegal_xml_chars_RE.sub("", text)
|
||||
|
||||
|
||||
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
|
||||
"""
|
||||
Configures and returns an HTTP client for communicating with Vespa,
|
||||
|
||||
@@ -215,7 +215,6 @@ def _update_single_chunk(
|
||||
doc_id: str,
|
||||
http_client: httpx.Client,
|
||||
update_request: MetadataUpdateRequest,
|
||||
new_doc_id: str | None,
|
||||
) -> None:
|
||||
"""Updates a single document chunk in Vespa.
|
||||
|
||||
@@ -251,11 +250,6 @@ def _update_single_chunk(
|
||||
model_config = {"frozen": True}
|
||||
assign: list[int]
|
||||
|
||||
# TODO(andrei): Very temporary, delete soon.
|
||||
class _DocumentId(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
assign: str
|
||||
|
||||
class _VespaPutFields(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
# The names of these fields are based the Vespa schema. Changes to the
|
||||
@@ -266,8 +260,6 @@ def _update_single_chunk(
|
||||
access_control_list: _AccessControl | None = None
|
||||
hidden: _Hidden | None = None
|
||||
user_project: _UserProjects | None = None
|
||||
# TODO(andrei): Very temporary, delete soon.
|
||||
document_id: _DocumentId | None = None
|
||||
|
||||
class _VespaPutRequest(BaseModel):
|
||||
model_config = {"frozen": True}
|
||||
@@ -302,10 +294,6 @@ def _update_single_chunk(
|
||||
if update_request.project_ids is not None
|
||||
else None
|
||||
)
|
||||
# TODO(andrei): Very temporary, delete soon.
|
||||
document_id_update: _DocumentId | None = (
|
||||
_DocumentId(assign=new_doc_id) if new_doc_id is not None else None
|
||||
)
|
||||
|
||||
vespa_put_fields = _VespaPutFields(
|
||||
boost=boost_update,
|
||||
@@ -313,8 +301,6 @@ def _update_single_chunk(
|
||||
access_control_list=access_update,
|
||||
hidden=hidden_update,
|
||||
user_project=user_projects_update,
|
||||
# TODO(andrei): Very temporary, delete soon.
|
||||
document_id=document_id_update,
|
||||
)
|
||||
|
||||
vespa_put_request = _VespaPutRequest(
|
||||
@@ -540,10 +526,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
def update(
|
||||
self,
|
||||
update_requests: list[MetadataUpdateRequest],
|
||||
# TODO(andrei), WARNING: Very temporary, this is not the interface we want
|
||||
# in Updatable, we only have this to continue supporting
|
||||
# user_file_docid_migration_task for Vespa which should be done soon.
|
||||
old_doc_id_to_new_doc_id: dict[str, str],
|
||||
) -> None:
|
||||
# WARNING: This method can be called by vespa_metadata_sync_task, which
|
||||
# is kicked off by check_for_vespa_sync_task, notably before a document
|
||||
@@ -584,8 +566,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
doc_id,
|
||||
httpx_client,
|
||||
update_request,
|
||||
# NOTE: The key is the raw ID, not the sanitized ID.
|
||||
new_doc_id=old_doc_id_to_new_doc_id.get(doc_id, None),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -232,10 +232,23 @@ def gather_stream_with_tools(packets: AnswerStream) -> GatherStreamResult:
|
||||
stream_end_time = time.time()
|
||||
|
||||
if message_id is None:
|
||||
raise ValueError("Message ID is required")
|
||||
# If we got a streaming error, include it in the exception
|
||||
if error_msg:
|
||||
raise ValueError(f"Message ID is required. Stream error: {error_msg}")
|
||||
raise ValueError(
|
||||
f"Message ID is required. No MessageResponseIDInfo received. "
|
||||
f"Tools called: {tools_called}"
|
||||
)
|
||||
|
||||
# Allow empty answers for tool-only turns (e.g., in multi-turn evals)
|
||||
# Some turns may only execute tools without generating a text response
|
||||
if answer is None:
|
||||
raise RuntimeError("Answer was not generated")
|
||||
logger.warning(
|
||||
"No answer content generated. Tools called: %s. "
|
||||
"This may be expected for tool-only turns.",
|
||||
tools_called,
|
||||
)
|
||||
answer = ""
|
||||
|
||||
# Calculate timings
|
||||
total_ms = (stream_end_time - stream_start_time) * 1000
|
||||
@@ -484,15 +497,18 @@ def _get_multi_turn_answer_with_tools(
|
||||
if configuration.search_permissions_email
|
||||
else None
|
||||
)
|
||||
# Cache user_id to avoid SQLAlchemy expiration issues
|
||||
user_id = user.id if user else None
|
||||
|
||||
# Create a single chat session for all turns
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="Multi-turn eval session",
|
||||
user_id=user.id if user else None,
|
||||
user_id=user_id,
|
||||
persona_id=DEFAULT_PERSONA_ID,
|
||||
onyxbot_flow=True,
|
||||
)
|
||||
chat_session_id = chat_session.id
|
||||
|
||||
# Process each turn sequentially
|
||||
for turn_idx, msg in enumerate(messages):
|
||||
@@ -539,7 +555,7 @@ def _get_multi_turn_answer_with_tools(
|
||||
# Create request for this turn
|
||||
# Use AUTO_PLACE_AFTER_LATEST_MESSAGE to chain messages
|
||||
request = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=AUTO_PLACE_AFTER_LATEST_MESSAGE,
|
||||
message=msg.message,
|
||||
file_descriptors=[],
|
||||
|
||||
@@ -90,6 +90,10 @@ class EvalConfigurationOptions(BaseModel):
|
||||
search_permissions_email: str
|
||||
dataset_name: str
|
||||
no_send_logs: bool = False
|
||||
# Optional override for Braintrust project (defaults to BRAINTRUST_PROJECT env var)
|
||||
braintrust_project: str | None = None
|
||||
# Optional experiment name for the eval run (shows in Braintrust UI)
|
||||
experiment_name: str | None = None
|
||||
|
||||
def get_configuration(self, db_session: Session) -> EvalConfiguration:
|
||||
persona_override_config = self.persona_override_config or PersonaOverrideConfig(
|
||||
|
||||
@@ -122,11 +122,12 @@ class BraintrustEvalProvider(EvalProvider):
|
||||
return multi_turn_task(eval_input)
|
||||
return task(eval_input)
|
||||
|
||||
project_name = configuration.braintrust_project or BRAINTRUST_PROJECT
|
||||
experiment_name = configuration.experiment_name
|
||||
|
||||
eval_data: Any = None
|
||||
if remote_dataset_name is not None:
|
||||
eval_data = init_dataset(
|
||||
project=BRAINTRUST_PROJECT, name=remote_dataset_name
|
||||
)
|
||||
eval_data = init_dataset(project=project_name, name=remote_dataset_name)
|
||||
else:
|
||||
if data:
|
||||
eval_data = [
|
||||
@@ -150,7 +151,8 @@ class BraintrustEvalProvider(EvalProvider):
|
||||
metadata = configuration.model_dump()
|
||||
|
||||
Eval( # type: ignore[misc]
|
||||
name=BRAINTRUST_PROJECT,
|
||||
name=project_name,
|
||||
experiment_name=experiment_name,
|
||||
data=eval_data,
|
||||
task=dispatch_task,
|
||||
scores=[tool_assertion_scorer],
|
||||
|
||||
@@ -164,7 +164,7 @@ def format_document_soup(
|
||||
|
||||
|
||||
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
soup = bs4.BeautifulSoup(text, "lxml")
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ def web_html_cleanup(
|
||||
additional_element_types_to_discard: list[str] | None = None,
|
||||
) -> ParsedHTML:
|
||||
if isinstance(page_content, str):
|
||||
soup = bs4.BeautifulSoup(page_content, "html.parser")
|
||||
soup = bs4.BeautifulSoup(page_content, "lxml")
|
||||
else:
|
||||
soup = page_content
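The change above swaps the stdlib parser for lxml, which is generally much faster on large pages but requires the lxml package to be installed. A small sketch of the pattern with a fallback for environments where lxml is missing; the fallback is an assumption, not necessarily what Onyx does:

```python
import bs4


def make_soup(html: str) -> bs4.BeautifulSoup:
    # Prefer the C-based lxml parser; fall back to the stdlib parser if the
    # lxml package is not available.
    try:
        return bs4.BeautifulSoup(html, "lxml")
    except bs4.FeatureNotFound:
        return bs4.BeautifulSoup(html, "html.parser")


print(make_soup("<p>hello<br>world</p>").get_text(" ", strip=True))
```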
|
||||
|
||||
|
||||
backend/onyx/file_processing/password_validation.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from collections.abc import Callable
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from typing import IO

from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger

logger = setup_logger()

PASSWORD_PROTECTED_FILES = [
    ".pdf",
    ".docx",
    ".pptx",
    ".xlsx",
]


@contextmanager
def preserve_position(file: IO[Any]) -> Generator[IO[Any], None, None]:
    """Preserves the file's cursor position"""
    pos = file.tell()
    try:
        file.seek(0)
        yield file
    finally:
        file.seek(pos)


def is_pdf_protected(file: IO[Any]) -> bool:
    from pypdf import PdfReader

    with preserve_position(file):
        reader = PdfReader(file)

    return bool(reader.is_encrypted)


def is_docx_protected(file: IO[Any]) -> bool:
    return is_office_file_protected(file)


def is_pptx_protected(file: IO[Any]) -> bool:
    return is_office_file_protected(file)


def is_xlsx_protected(file: IO[Any]) -> bool:
    return is_office_file_protected(file)


def is_office_file_protected(file: IO[Any]) -> bool:
    import msoffcrypto  # type: ignore[import-untyped]

    with preserve_position(file):
        office = msoffcrypto.OfficeFile(file)

    return office.is_encrypted()


def is_file_password_protected(
    file: IO[Any],
    file_name: str,
    extension: str | None = None,
) -> bool:
    extension_to_function: dict[str, Callable[[IO[Any]], bool]] = {
        ".pdf": is_pdf_protected,
        ".docx": is_docx_protected,
        ".pptx": is_pptx_protected,
        ".xlsx": is_xlsx_protected,
    }

    if not extension:
        extension = get_file_ext(file_name)

    if extension not in PASSWORD_PROTECTED_FILES:
        return False

    if extension not in extension_to_function:
        logger.warning(
            f"Extension={extension} can be password protected, but no function found"
        )
        return False

    func = extension_to_function[extension]

    return func(file)
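A brief usage sketch for the helper above; the file name and call site are hypothetical. Note that `preserve_position` restores the caller's cursor, so the file can be handed straight to the normal extraction path afterwards:

```python
from onyx.file_processing.password_validation import is_file_password_protected

with open("quarterly_report.pdf", "rb") as f:
    if is_file_password_protected(f, file_name="quarterly_report.pdf"):
        print("Skipping password protected file")
    else:
        data = f.read()  # proceed with text extraction as usual
```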
|
||||
@@ -9,7 +9,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.logger import setup_logger

if TYPE_CHECKING:
    from unstructured_client.models import operations  # type: ignore
    from unstructured_client.models import operations


logger = setup_logger()
@@ -55,19 +55,19 @@ def _sdk_partition_request(

def unstructured_to_text(file: IO[Any], file_name: str) -> str:
    from unstructured.staging.base import dict_to_elements
    from unstructured_client import UnstructuredClient  # type: ignore
    from unstructured_client import UnstructuredClient

    logger.debug(f"Starting to read file: {file_name}")
    req = _sdk_partition_request(file, file_name, strategy="fast")

    unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())

    response = unstructured_client.general.partition(req)
    elements = dict_to_elements(response.elements)
    response = unstructured_client.general.partition(request=req)

    if response.status_code != 200:
        err = f"Received unexpected status code {response.status_code} from Unstructured API."
        logger.error(err)
        raise ValueError(err)

    elements = dict_to_elements(response.elements or [])
    return "\n\n".join(str(el) for el in elements)
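For reference, a hedged usage sketch of the updated call path; the module path and file name are assumptions, and an Unstructured API key must already be configured for `get_unstructured_api_key` to return one:

```python
from onyx.file_processing.unstructured import unstructured_to_text  # assumed module path

with open("contract.docx", "rb") as f:
    text = unstructured_to_text(f, "contract.docx")

print(text[:500])
```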
|
||||
|
||||
@@ -6,15 +6,19 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.session import TransactionalContext
|
||||
|
||||
from onyx.access.access import get_access_for_user_files
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
@@ -194,6 +198,42 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_token_count=user_file_id_to_token_count,
|
||||
)
|
||||
|
||||
def _notify_assistant_owners_if_files_ready(
|
||||
self, user_files: list[UserFile]
|
||||
) -> None:
|
||||
"""
|
||||
Check if all files for associated assistants are processed and notify owners.
|
||||
Only sends notification when all files for an assistant are COMPLETED.
|
||||
"""
|
||||
for user_file in user_files:
|
||||
if user_file.status == UserFileStatus.COMPLETED:
|
||||
for assistant in user_file.assistants:
|
||||
# Skip assistants without owners
|
||||
if assistant.user_id is None:
|
||||
continue
|
||||
|
||||
# Check if all OTHER files for this assistant are completed
|
||||
# (we already know current file is completed from the outer check)
|
||||
all_files_completed = all(
|
||||
f.status == UserFileStatus.COMPLETED
|
||||
for f in assistant.user_files
|
||||
if f.id != user_file.id
|
||||
)
|
||||
|
||||
if all_files_completed:
|
||||
create_notification(
|
||||
user_id=assistant.user_id,
|
||||
notif_type=NotificationType.ASSISTANT_FILES_READY,
|
||||
db_session=self.db_session,
|
||||
title="Your files are ready!",
|
||||
description=f"All files for agent {assistant.name} have been processed and are now available.",
|
||||
additional_data={
|
||||
"persona_id": assistant.id,
|
||||
"link": f"/assistants/{assistant.id}",
|
||||
},
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
def post_index(
|
||||
self,
|
||||
context: DocumentBatchPrepareContext,
|
||||
@@ -204,7 +244,10 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
user_files = (
|
||||
self.db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
|
||||
self.db_session.query(UserFile)
|
||||
.options(selectinload(UserFile.assistants).selectinload(Persona.user_files))
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
)
|
||||
for user_file in user_files:
|
||||
# don't update the status if the user file is being deleted
|
||||
@@ -217,6 +260,10 @@ class UserFileIndexingAdapter:
|
||||
user_file.token_count = result.user_file_id_to_token_count[
|
||||
str(user_file.id)
|
||||
]
|
||||
|
||||
# Notify assistant owners if all their files are now processed
|
||||
self._notify_assistant_owners_if_files_ready(user_files)
|
||||
|
||||
self.db_session.commit()
|
||||
|
||||
# Store the plaintext in the file store for faster retrieval
|
||||
|
||||
@@ -40,6 +40,7 @@ class BaseChunk(BaseModel):
    source_links: dict[int, str] | None
    image_file_id: str | None
    # True if this Chunk's start is not at the start of a Section
    # TODO(andrei): This is deprecated as of the OpenSearch migration. Remove.
    section_continuation: bool

@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
     LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response  # type: ignore[method-assign]


+def _patch_azure_responses_should_fake_stream() -> None:
+    """
+    Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
+
+    By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
+    not in its database. This causes Azure custom model deployments to buffer the entire
+    response before yielding, resulting in poor time-to-first-token.
+
+    Azure's Responses API supports native streaming, so we override this to always use
+    real streaming (SyncResponsesAPIStreamingIterator).
+    """
+    from litellm.llms.azure.responses.transformation import (
+        AzureOpenAIResponsesAPIConfig,
+    )
+
+    if (
+        getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
+        == "_patched_should_fake_stream"
+    ):
+        return
+
+    def _patched_should_fake_stream(
+        self: Any,
+        model: Optional[str],
+        stream: Optional[bool],
+        custom_llm_provider: Optional[str] = None,
+    ) -> bool:
+        # Azure Responses API supports native streaming - never fake it
+        return False
+
+    _patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
+    AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream  # type: ignore[method-assign]
+
+
 def apply_monkey_patches() -> None:
     """
     Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
     - Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
     - Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
     - Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
     - Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
+    - Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
     """
     _patch_ollama_transform_request()
     _patch_ollama_chunk_parser()
     _patch_openai_responses_chunk_parser()
     _patch_openai_responses_transform_response()
+    _patch_azure_responses_should_fake_stream()


 def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
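Each patch above guards on the replacement function's __name__, so calling apply_monkey_patches() repeatedly is safe. A minimal sketch of that idempotent-patch pattern, using a stand-in class rather than the real LiteLLM config objects (everything below is illustrative, not the actual onyx or LiteLLM code):

# Stand-in illustration of the idempotent monkey-patch pattern used above.
class SomeConfig:
    def should_fake_stream(self) -> bool:  # original behavior
        return True


def patch_should_fake_stream() -> None:
    # Skip if already patched (the real patches check __name__ the same way).
    if getattr(SomeConfig.should_fake_stream, "__name__", "") == "_patched":
        return

    def _patched(self: "SomeConfig") -> bool:
        return False

    _patched.__name__ = "_patched"
    SomeConfig.should_fake_stream = _patched


patch_should_fake_stream()
patch_should_fake_stream()  # second call is a no-op
assert SomeConfig().should_fake_stream() is False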
@@ -160,6 +160,22 @@ def _validate_and_extract_base_fields(
     return str(response_id), str(created), choices[0] or {}


+def _usage_from_usage_data(usage_data: dict[str, Any]) -> Usage:
+    # NOTE: sometimes the usage data dictionary has these keys and the values are None
+    # hence the "or 0" instead of just using default values
+    return Usage(
+        completion_tokens=usage_data.get("completion_tokens") or 0,
+        prompt_tokens=usage_data.get("prompt_tokens") or 0,
+        total_tokens=usage_data.get("total_tokens") or 0,
+        cache_creation_input_tokens=usage_data.get("cache_creation_input_tokens") or 0,
+        cache_read_input_tokens=usage_data.get(
+            "cache_read_input_tokens",
+            (usage_data.get("prompt_tokens_details") or {}).get("cached_tokens"),
+        )
+        or 0,
+    )
+
+
 def from_litellm_model_response_stream(
     response: "LiteLLMModelResponseStream",
 ) -> ModelResponseStream:
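The "or 0" coalescing in the new helper matters because providers sometimes send usage keys whose values are explicitly None, in which case dict.get(key, 0) still returns None. A quick illustration with plain dicts (the real Usage model is omitted here):

# Why "or 0" is used instead of a default value in .get():
usage_data = {"completion_tokens": None, "prompt_tokens": 17}

assert usage_data.get("completion_tokens", 0) is None   # default ignored when the key exists
assert (usage_data.get("completion_tokens") or 0) == 0  # coalesces None to 0
assert (usage_data.get("prompt_tokens") or 0) == 17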
@@ -189,24 +205,7 @@ def from_litellm_model_response_stream(
         id=response_id,
         created=created,
         choice=streaming_choice,
-        usage=(
-            Usage(
-                completion_tokens=usage_data.get("completion_tokens", 0),
-                prompt_tokens=usage_data.get("prompt_tokens", 0),
-                total_tokens=usage_data.get("total_tokens", 0),
-                cache_creation_input_tokens=usage_data.get(
-                    "cache_creation_input_tokens", 0
-                ),
-                cache_read_input_tokens=usage_data.get(
-                    "cache_read_input_tokens",
-                    (usage_data.get("prompt_tokens_details") or {}).get(
-                        "cached_tokens", 0
-                    ),
-                ),
-            )
-            if usage_data
-            else None
-        ),
+        usage=(_usage_from_usage_data(usage_data) if usage_data else None),
     )
@@ -242,22 +241,5 @@ def from_litellm_model_response(
         id=response_id,
         created=created,
         choice=choice,
-        usage=(
-            Usage(
-                completion_tokens=usage_data.get("completion_tokens", 0),
-                prompt_tokens=usage_data.get("prompt_tokens", 0),
-                total_tokens=usage_data.get("total_tokens", 0),
-                cache_creation_input_tokens=usage_data.get(
-                    "cache_creation_input_tokens", 0
-                ),
-                cache_read_input_tokens=usage_data.get(
-                    "cache_read_input_tokens",
-                    (usage_data.get("prompt_tokens_details") or {}).get(
-                        "cached_tokens", 0
-                    ),
-                ),
-            )
-            if usage_data
-            else None
-        ),
+        usage=(_usage_from_usage_data(usage_data) if usage_data else None),
     )
@@ -63,7 +63,7 @@ def process_with_prompt_cache(
         return suffix, None

     # Get provider adapter
-    provider_adapter = get_provider_adapter(llm_config.model_provider)
+    provider_adapter = get_provider_adapter(llm_config)

     # If provider doesn't support caching, combine and return unchanged
     if not provider_adapter.supports_caching():
@@ -1,14 +1,17 @@
 """Factory for creating provider-specific prompt cache adapters."""

 from onyx.llm.constants import LlmProviderNames
+from onyx.llm.interfaces import LLMConfig
 from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
 from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
 from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
 from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
 from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider

+ANTHROPIC_BEDROCK_TAG = "anthropic."
+

-def get_provider_adapter(provider: str) -> PromptCacheProvider:
+def get_provider_adapter(llm_config: LLMConfig) -> PromptCacheProvider:
     """Get the appropriate prompt cache provider adapter for a given provider.

     Args:
@@ -17,11 +20,14 @@ def get_provider_adapter(provider: str) -> PromptCacheProvider:
     Returns:
         PromptCacheProvider instance for the given provider
     """
-    if provider == LlmProviderNames.OPENAI:
+    if llm_config.model_provider == LlmProviderNames.OPENAI:
         return OpenAIPromptCacheProvider()
-    elif provider in [LlmProviderNames.ANTHROPIC, LlmProviderNames.BEDROCK]:
+    elif llm_config.model_provider == LlmProviderNames.ANTHROPIC or (
+        llm_config.model_provider == LlmProviderNames.BEDROCK
+        and ANTHROPIC_BEDROCK_TAG in llm_config.model_name
+    ):
         return AnthropicPromptCacheProvider()
-    elif provider == LlmProviderNames.VERTEX_AI:
+    elif llm_config.model_provider == LlmProviderNames.VERTEX_AI:
         return VertexAIPromptCacheProvider()
     else:
         # Default to no-op for providers without caching support
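With the new signature, Bedrock only routes to the Anthropic adapter when the deployed model name carries the "anthropic." tag; other Bedrock models fall through to the default no-op adapter. A hedged sketch of that routing decision follows, with a simplified stand-in for LLMConfig and example Bedrock model IDs used purely for illustration:

# Simplified stand-in for LLMConfig; only the two fields the factory reads.
from dataclasses import dataclass


@dataclass
class FakeLLMConfig:
    model_provider: str
    model_name: str


ANTHROPIC_BEDROCK_TAG = "anthropic."


def picks_anthropic_adapter(cfg: FakeLLMConfig) -> bool:
    # Mirrors the branch above: Anthropic directly, or Bedrock with an Anthropic model.
    return cfg.model_provider == "anthropic" or (
        cfg.model_provider == "bedrock" and ANTHROPIC_BEDROCK_TAG in cfg.model_name
    )


assert picks_anthropic_adapter(FakeLLMConfig("bedrock", "anthropic.claude-3-5-sonnet-20241022-v2:0"))
assert not picks_anthropic_adapter(FakeLLMConfig("bedrock", "amazon.titan-text-express-v1"))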
@@ -48,7 +48,7 @@ class VertexAIPromptCacheProvider(PromptCacheProvider):
             cacheable_prefix=cacheable_prefix,
             suffix=suffix,
             continuation=continuation,
-            transform_cacheable=_add_vertex_cache_control,
+            transform_cacheable=None,  # TODO: support explicit caching
         )

     def extract_cache_metadata(
@@ -89,6 +89,10 @@ def _add_vertex_cache_control(
     not at the message level. This function converts string content to the array format
     and adds cache_control to the last content block in each cacheable message.
     """
+    # NOTE: unfortunately we need a much more sophisticated mechanism to support
+    # explicit caching with Vertex in the presence of tools and system messages
+    # (since they're supposed to be stripped out when setting cache_control),
+    # so we're deferring this to a future PR.
     updated: list[ChatCompletionMessage] = []
     for message in messages:
         mutated = dict(message)
@@ -82,7 +82,6 @@ def fetch_llm_recommendations_from_github(

 def sync_llm_models_from_github(
     db_session: Session,
-    config: LLMRecommendations,
     force: bool = False,
 ) -> dict[str, int]:
     """Sync models from GitHub config to database for all Auto mode providers.
@@ -101,19 +100,24 @@ def sync_llm_models_from_github(
     Returns:
         Dict of provider_name -> number of changes made.
     """
-    # Skip if we've already processed this version (unless forced)
-    last_updated_at = _get_cached_last_updated_at()
-    if not force and last_updated_at and config.updated_at <= last_updated_at:
-        logger.debug("GitHub config unchanged, skipping sync")
-        return {}
-
     results: dict[str, int] = {}

     # Get all providers in Auto mode
     auto_providers = fetch_auto_mode_providers(db_session)

     if not auto_providers:
         logger.debug("No providers in Auto mode found")
         return {}

+    # Fetch config from GitHub
+    config = fetch_llm_recommendations_from_github()
+    if not config:
+        logger.warning("Failed to fetch GitHub config")
+        return {}
+
+    # Skip if we've already processed this version (unless forced)
+    last_updated_at = _get_cached_last_updated_at()
+    if not force and last_updated_at and config.updated_at <= last_updated_at:
+        logger.debug("GitHub config unchanged, skipping sync")
+        _set_cached_last_updated_at(config.updated_at)
+        return {}
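The reordering above delays the GitHub fetch until Auto-mode providers are confirmed, and it now records the config's updated_at even when the sync is skipped. A small sketch of that skip gate in isolation; the cache dictionary and helper below are assumptions standing in for the real _get_cached_last_updated_at / _set_cached_last_updated_at helpers:

# Illustration of the "skip if unchanged" gate; the in-memory cache here is a
# stand-in for the real cached-timestamp helpers, not the actual onyx API.
from datetime import datetime, timezone

_cache: dict[str, datetime] = {}


def should_skip(config_updated_at: datetime, force: bool = False) -> bool:
    last = _cache.get("last_updated_at")
    if not force and last and config_updated_at <= last:
        _cache["last_updated_at"] = config_updated_at  # still record what we saw
        return True
    return False


now = datetime.now(timezone.utc)
assert should_skip(now) is False           # nothing cached yet -> run the sync
_cache["last_updated_at"] = now
assert should_skip(now) is True            # same version -> skip
assert should_skip(now, force=True) is False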
@@ -45,6 +45,7 @@ from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.natural_language_processing.utils import tokenizer_trim_content
 from onyx.utils.logger import setup_logger
 from onyx.utils.search_nlp_models_utils import pass_aws_key
+from onyx.utils.text_processing import remove_invalid_unicode_chars
 from onyx.utils.timing import log_function_time
 from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
 from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -984,6 +985,10 @@ class EmbeddingModel:
             for text in texts
         ]

+        # Remove invalid Unicode characters (e.g., unpaired surrogates from malformed documents)
+        # that would cause UTF-8 encoding errors when sent to embedding providers
+        texts = [remove_invalid_unicode_chars(text) or "<>" for text in texts]
+
         batch_size = (
             api_embedding_batch_size
             if self.provider_type
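As an illustration of the kind of cleanup remove_invalid_unicode_chars is used for here, the sketch below drops characters that cannot survive UTF-8 encoding (unpaired surrogates from malformed documents). This is an assumption for illustration, not the actual onyx implementation; the `or "<>"` fallback in the diff additionally keeps a fully invalid text from collapsing to an empty string.

# Drop characters that cannot round-trip through UTF-8 (e.g. unpaired surrogates).
def strip_invalid_unicode(text: str) -> str:
    return text.encode("utf-8", errors="ignore").decode("utf-8", errors="ignore")


broken = "hello \ud800 world"  # contains an unpaired surrogate
assert strip_invalid_unicode(broken) == "hello  world"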
@@ -35,6 +35,7 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
 from onyx.onyxbot.slack.utils import SlackRateLimiter
 from onyx.onyxbot.slack.utils import update_emote_react
 from onyx.server.query_and_chat.models import CreateChatMessageRequest
+from onyx.server.query_and_chat.models import MessageOrigin
 from onyx.utils.logger import OnyxLoggingAdapter

 srl = SlackRateLimiter()
@@ -236,6 +237,7 @@ def handle_regular_answer(
         retrieval_details=retrieval_details,
         rerank_settings=None,  # Rerank customization supported in Slack flow
         db_session=db_session,
+        origin=MessageOrigin.SLACKBOT,
     )

     # if it's a DM or ephemeral message, answer based on private documents.
@@ -42,6 +42,8 @@ Your job is now to organize the findings to return a comprehensive report that p
 The report will be seen by another agent instead of a user so keep it free of formatting or commentary and instead focus on the facts only. \
 Do not give it a title, do not break it down into sections, and do not provide any of your own conclusions/analysis.

+You may see a list of tool calls in the history but you do not have access to tools anymore. You should only use the information in the history to create the report.
+
 CRITICAL - This report should be as long as necessary to return ALL of the information that the researcher has gathered. It should be several pages long so as to capture as much detail as possible from the research. \
 It cannot be stressed enough that this report must be EXTREMELY THOROUGH and COMPREHENSIVE. Only this report is going to be returned, so it's CRUCIAL that you don't lose any details from the raw messages.
@@ -1,30 +1,39 @@
 from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS

 SLACK_QUERY_EXPANSION_PROMPT = f"""
-Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
-keyword-only queries, so that Slack's keyword search yields the best matches.
+Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.

-Keep in mind the Slack's search behavior:
-- Pure keyword AND search (no semantics).
-- Word order matters.
-- More words = fewer matches, so keep each query concise.
-- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.
+Slack search behavior:
+- Pure keyword AND search (no semantics)
+- More words = fewer matches, so keep queries concise (1-3 words)

+Critical: Extract ONLY keywords that would actually appear in Slack message content.
+ALWAYS include:
+- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
+- Project/product names, technical terms, proper nouns
+- Actual content words: "performance", "bug", "deployment", "API", "error"
+
+DO NOT include:
+- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
+- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
+- Channels/Users: "general", "eng-general", "engineering", "@username"

-DO include:
-- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"
-- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
-- Temporal: "today", "yesterday", "week", "month", "recent", "last"
-- Channel names: "general", "eng-general", "random"

 Examples:

 Query: "what are the big topics in eng-general this week?"
 Output:

+Query: "messages with Sarah about the deployment"
+Output:
+Sarah deployment
+Sarah
+deployment
+
+Query: "what did Mike say about the budget?"
+Output:
+Mike budget
+Mike
+budget

 Query: "performance issues in eng-general"
 Output:
 performance issues
@@ -41,7 +50,7 @@ Now process this query:

 {{query}}

-Output:
+Output (keywords only, one per line, NO explanations or commentary):
 """

 SLACK_DATE_EXTRACTION_PROMPT = """
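Since the revised prompt asks for keywords only, one query per line, the model output presumably gets split into individual queries and capped at MAX_SLACK_QUERY_EXPANSIONS. That parsing code is not part of this diff, so the following is only an assumed sketch of the step:

# Hypothetical parse step for the expansion output; assumes one keyword query
# per line and a hard cap, which is what the prompt above asks the model for.
MAX_SLACK_QUERY_EXPANSIONS = 5  # illustrative value


def parse_expansions(llm_output: str, cap: int = MAX_SLACK_QUERY_EXPANSIONS) -> list[str]:
    queries = [line.strip() for line in llm_output.splitlines() if line.strip()]
    return queries[:cap]


assert parse_expansions("Sarah deployment\nSarah\ndeployment\n") == [
    "Sarah deployment",
    "Sarah",
    "deployment",
]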
@@ -33,9 +33,13 @@ WEB_SEARCH_GUIDANCE = """
 Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
 - Freshness: if up-to-date information on a topic could change or enhance the answer. Very important for topics that are changing or evolving.
 - Niche Information: detailed info not widely known or understood (but that is likely found on the internet).
-- Accuracy: if the cost of outdated information is high, use web sources directly.
+- Accuracy: if the cost of outdated information is high, use web sources directly.{site_colon_disabled}
 """

+WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
+Do not use the "site:" operator in your web search queries.
+""".rstrip()
+
+
 OPEN_URLS_GUIDANCE = """
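The {site_colon_disabled} placeholder suggests the guidance is assembled with str.format, inserting WEB_SEARCH_SITE_DISABLED_GUIDANCE only when the "site:" operator should be avoided. The assembly code is not shown in this diff, so the sketch below is an assumption about how it might be wired:

# Hedged sketch of filling the placeholder; the real prompt-assembly code is
# not part of this diff, so the function below is illustrative only.
WEB_SEARCH_GUIDANCE = """
Use the `web_search` tool to access up-to-date information from the web.
- Accuracy: if the cost of outdated information is high, use web sources directly.{site_colon_disabled}
"""

WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
Do not use the "site:" operator in your web search queries.
""".rstrip()


def build_web_search_guidance(site_operator_supported: bool) -> str:
    extra = "" if site_operator_supported else WEB_SEARCH_SITE_DISABLED_GUIDANCE
    return WEB_SEARCH_GUIDANCE.format(site_colon_disabled=extra)


print(build_web_search_guidance(site_operator_supported=False))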