mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 22:25:47 +00:00
Compare commits
6 Commits
v2.9.8
...
thread_sen
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5848975679 | ||
|
|
dcc330010e | ||
|
|
d0f5f1f5ae | ||
|
|
3e475993ff | ||
|
|
7c2b5fa822 | ||
|
|
409cfdc788 |
389
.github/workflows/deployment.yml
vendored
389
.github/workflows/deployment.yml
vendored
@@ -8,9 +8,7 @@ on:
|
||||
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions:
|
||||
# Required for OIDC authentication with AWS
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
permissions: {}
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
@@ -152,30 +150,16 @@ jobs:
|
||||
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• check-version-tag"
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -184,7 +168,6 @@ jobs:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
@@ -202,33 +185,12 @@ jobs:
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
APPLE_ID, deploy/apple-id
|
||||
APPLE_PASSWORD, deploy/apple-password
|
||||
APPLE_CERTIFICATE, deploy/apple-certificate
|
||||
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
|
||||
KEYCHAIN_PASSWORD, deploy/keychain-password
|
||||
APPLE_TEAM_ID, deploy/apple-team-id
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
@@ -323,40 +285,15 @@ jobs:
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- name: Import Apple Developer Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
|
||||
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security set-keychain-settings -t 3600 -u build.keychain
|
||||
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security find-identity -v -p codesigning build.keychain
|
||||
|
||||
- name: Verify Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
|
||||
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
|
||||
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
|
||||
echo "Certificate imported."
|
||||
|
||||
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
|
||||
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
APPLE_ID: ${{ env.APPLE_ID }}
|
||||
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
@@ -368,7 +305,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -381,20 +317,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -409,8 +331,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -441,7 +363,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -454,20 +375,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -482,8 +389,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -516,34 +423,19 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -579,7 +471,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -592,20 +483,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -620,8 +497,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -660,7 +537,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -673,20 +549,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -701,8 +563,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -743,34 +605,19 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -803,7 +650,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -816,20 +662,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -844,8 +676,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -875,7 +707,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -888,20 +719,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -916,8 +733,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -949,34 +766,19 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -1013,7 +815,6 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -1026,20 +827,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -1056,8 +843,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -1092,7 +879,6 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -1105,20 +891,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -1135,8 +907,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -1172,34 +944,19 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -1237,26 +994,11 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1272,8 +1014,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1292,26 +1034,11 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1327,8 +1054,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1347,7 +1074,6 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
@@ -1358,20 +1084,6 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1388,8 +1100,8 @@ jobs:
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1409,26 +1121,11 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1444,8 +1141,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1473,26 +1170,12 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Determine failed jobs
|
||||
id: failed-jobs
|
||||
shell: bash
|
||||
@@ -1558,7 +1241,7 @@ jobs:
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
|
||||
title: "🚨 Deployment Workflow Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
|
||||
- name: Upload Docker logs
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
|
||||
with:
|
||||
name: docker-logs-${{ matrix.test-dir }}
|
||||
path: docker-logs/
|
||||
|
||||
7
.github/workflows/pr-integration-tests.yml
vendored
7
.github/workflows/pr-integration-tests.yml
vendored
@@ -310,9 +310,8 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
MCP_SERVER_ENABLED=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -439,7 +438,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -568,7 +567,7 @@ jobs:
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
|
||||
@@ -301,7 +301,7 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -424,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -435,7 +435,7 @@ jobs:
|
||||
fi
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
if: always()
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
@@ -455,7 +455,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,9 +50,8 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -144,7 +144,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,7 +21,6 @@ backend/tests/regression/search_quality/*.json
|
||||
backend/onyx/evals/data/
|
||||
backend/onyx/evals/one_off/*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# secret files
|
||||
.env
|
||||
|
||||
@@ -11,6 +11,7 @@ repos:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args:
|
||||
|
||||
@@ -225,6 +225,7 @@ def do_run_migrations(
|
||||
) -> None:
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
@@ -308,7 +309,6 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
@@ -346,7 +346,6 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
|
||||
@@ -85,122 +85,103 @@ class UserRow(NamedTuple):
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
try:
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
@@ -209,148 +190,191 @@ def upgrade() -> None:
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
if web_search_tool:
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
"""notifications constraint, sort index, and cleanup old notifications
|
||||
|
||||
Revision ID: 8405ca81cc83
|
||||
Revises: a3c1a7904cd0
|
||||
Create Date: 2026-01-07 16:43:44.855156
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8405ca81cc83"
|
||||
down_revision = "a3c1a7904cd0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create unique index for notification deduplication.
|
||||
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
|
||||
#
|
||||
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
|
||||
# in unique constraints, but we want NULL == NULL for deduplication).
|
||||
# The '{}' represents an empty JSONB object as the NULL replacement.
|
||||
|
||||
# Clean up legacy notifications first
|
||||
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
|
||||
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
|
||||
"""
|
||||
)
|
||||
|
||||
# Create index for efficient notification sorting by user
|
||||
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
|
||||
ON notification (user_id, dismissed, first_shown DESC)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")
|
||||
@@ -42,13 +42,20 @@ TOOL_DESCRIPTIONS = {
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -7,6 +7,7 @@ Create Date: 2025-12-18 16:00:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -18,7 +19,7 @@ depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": "ResearchAgent",
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
|
||||
@@ -70,66 +70,80 @@ BUILT_IN_TOOLS = [
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
try:
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
|
||||
)
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
"""sync_exa_api_key_to_content_provider
|
||||
|
||||
Revision ID: d1b637d7050a
|
||||
Revises: d25168c2beee
|
||||
Create Date: 2026-01-09 15:54:15.646249
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d1b637d7050a"
|
||||
down_revision = "d25168c2beee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Exa uses a shared API key between search and content providers.
|
||||
# For existing Exa search providers with API keys, create the corresponding
|
||||
# content provider if it doesn't exist yet.
|
||||
connection = op.get_bind()
|
||||
|
||||
# Check if Exa search provider exists with an API key
|
||||
result = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT api_key FROM internet_search_provider
|
||||
WHERE provider_type = 'exa' AND api_key IS NOT NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
api_key = row[0]
|
||||
# Create Exa content provider with the shared key
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO internet_content_provider
|
||||
(name, provider_type, api_key, is_active)
|
||||
VALUES ('Exa', 'exa', :api_key, false)
|
||||
ON CONFLICT (name) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"api_key": api_key},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the Exa content provider that was created by this migration
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM internet_content_provider
|
||||
WHERE provider_type = 'exa'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -1,86 +0,0 @@
|
||||
"""tool_name_consistency
|
||||
|
||||
Revision ID: d25168c2beee
|
||||
Revises: 8405ca81cc83
|
||||
Create Date: 2026-01-11 17:54:40.135777
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d25168c2beee"
|
||||
down_revision = "8405ca81cc83"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Currently the seeded tools have the in_code_tool_id == name
|
||||
CURRENT_TOOL_NAME_MAPPING = [
|
||||
"SearchTool",
|
||||
"WebSearchTool",
|
||||
"ImageGenerationTool",
|
||||
"PythonTool",
|
||||
"OpenURLTool",
|
||||
"KnowledgeGraphTool",
|
||||
"ResearchAgent",
|
||||
]
|
||||
|
||||
# Mapping of in_code_tool_id -> name
|
||||
# These are the expected names that we want in the database
|
||||
EXPECTED_TOOL_NAME_MAPPING = {
|
||||
"SearchTool": "internal_search",
|
||||
"WebSearchTool": "web_search",
|
||||
"ImageGenerationTool": "generate_image",
|
||||
"PythonTool": "python",
|
||||
"OpenURLTool": "open_url",
|
||||
"KnowledgeGraphTool": "run_kg_search",
|
||||
"ResearchAgent": "research_agent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Mapping of in_code_tool_id to the NAME constant from each tool class
|
||||
# These match the .name property of each tool implementation
|
||||
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
|
||||
|
||||
# Update the name column for each tool based on its in_code_tool_id
|
||||
for in_code_tool_id, expected_name in tool_name_mapping.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :expected_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"expected_name": expected_name,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Reverse the migration by setting name back to in_code_tool_id
|
||||
# This matches the original pattern where name was the class name
|
||||
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :current_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"current_name": in_code_tool_id,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
@@ -109,6 +109,7 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
@@ -3,42 +3,30 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
def update_persona_access(
|
||||
def make_persona_private(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Updates the access settings for a persona including public status, user shares,
|
||||
and group shares.
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
|
||||
NOTE: This function batches all updates. If we don't dedupe the inputs,
|
||||
the commit will exception.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
@@ -53,13 +41,11 @@ def update_persona_access(
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if group_ids:
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -101,7 +100,6 @@ def handle_simplified_chat_message(
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
origin=MessageOrigin.API,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
@@ -205,7 +203,6 @@ def handle_send_message_simple_with_history(
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
origin=MessageOrigin.API,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,21 +16,15 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -78,24 +72,22 @@ def fetch_billing_information(
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -10,7 +10,6 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
@@ -105,18 +104,15 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create subscription session")
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -74,12 +73,6 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""EE Usage limits - trial detection via billing information."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
@@ -28,7 +31,13 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
|
||||
return True
|
||||
|
||||
if isinstance(billing_info, BillingInformation):
|
||||
return billing_info.status == "trialing"
|
||||
# Check if trial is active
|
||||
if billing_info.trial_end is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Trial active if trial_end is in the future
|
||||
# and subscription status indicates trialing
|
||||
if billing_info.trial_end > now and billing_info.status == "trialing":
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -105,8 +105,6 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
|
||||
@@ -124,7 +124,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -174,7 +174,7 @@ if AUTO_LLM_CONFIG_URL:
|
||||
"schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -5,9 +5,6 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -29,9 +26,24 @@ def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
# Fetch config from GitHub
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
|
||||
if not config:
|
||||
task_logger.warning("Failed to fetch GitHub config")
|
||||
return None
|
||||
|
||||
# Sync to database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
results = sync_llm_models_from_github(db_session)
|
||||
results = sync_llm_models_from_github(db_session, config)
|
||||
|
||||
if results:
|
||||
task_logger.info(f"Auto mode sync results: {results}")
|
||||
|
||||
@@ -12,7 +12,6 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -20,14 +19,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -56,17 +53,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -130,24 +116,7 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -162,21 +131,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -189,35 +144,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -225,8 +157,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -241,12 +172,6 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
) -> None:
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
@@ -94,7 +94,6 @@ class ChatStateContainer:
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
@@ -197,12 +196,3 @@ def run_chat_loop_with_state_containers(
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -55,7 +55,6 @@ from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
@@ -118,7 +117,6 @@ def prepare_chat_message_request(
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
origin: MessageOrigin | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
@@ -146,7 +144,6 @@ def prepare_chat_message_request(
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
origin=origin or MessageOrigin.UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -505,7 +505,7 @@ def run_llm_loop(
|
||||
# in-flight citations
|
||||
# It can be cleaned up but not super trivial or worthwhile right now
|
||||
just_ran_web_search = False
|
||||
parallel_tool_call_results = run_tool_calls(
|
||||
tool_responses, citation_mapping = run_tool_calls(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
@@ -516,8 +516,6 @@ def run_llm_loop(
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
|
||||
# Failure case, give something reasonable to the LLM to try again
|
||||
if tool_calls and not tool_responses:
|
||||
|
||||
@@ -5,13 +5,10 @@ An overview can be found in the README.md file in this directory.
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
@@ -48,8 +45,6 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import User
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
@@ -83,16 +78,20 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
def __init__(self, message: str, tool_name: str | None = None):
|
||||
super().__init__(message)
|
||||
self.tool_name = tool_name
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
project_id: int | None,
|
||||
user_id: UUID | None,
|
||||
@@ -295,8 +294,6 @@ def handle_stream_message_objects(
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
@@ -342,24 +339,6 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
distinct_id=user.email if user else tenant_id,
|
||||
event="user_message_sent",
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
@@ -401,10 +380,7 @@ def handle_stream_message_objects(
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
# Auto-place after the latest message in the chain
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
elif new_msg_req.parent_message_id is None:
|
||||
# None = regeneration from root
|
||||
parent_message = root_message
|
||||
# Truncate history since we're starting from root
|
||||
@@ -560,27 +536,10 @@ def handle_stream_message_objects(
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Use external state container if provided, otherwise create internal one
|
||||
# External container allows non-streaming callers to access accumulated state
|
||||
state_container = external_state_container or ChatStateContainer()
|
||||
|
||||
def llm_loop_completion_callback(
|
||||
state_container: ChatStateContainer,
|
||||
) -> None:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_container,
|
||||
db_session=db_session,
|
||||
chat_session_id=str(chat_session.id),
|
||||
is_connected=check_is_connected,
|
||||
assistant_message=assistant_response,
|
||||
)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
@@ -596,7 +555,6 @@ def handle_stream_message_objects(
|
||||
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -613,7 +571,6 @@ def handle_stream_message_objects(
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -631,6 +588,51 @@ def handle_stream_message_objects(
|
||||
chat_session_id=str(chat_session.id),
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
if not completed_normally:
|
||||
logger.debug(f"Chat session {chat_session.id} stopped by user")
|
||||
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... The generation was stopped by the user here."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -648,7 +650,15 @@ def handle_stream_message_objects(
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if llm:
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="TOOL_CALL_FAILED",
|
||||
is_retryable=True,
|
||||
details={"tool_name": e.tool_name} if e.tool_name else None,
|
||||
)
|
||||
elif llm:
|
||||
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
|
||||
e, llm
|
||||
)
|
||||
@@ -680,67 +690,7 @@ def handle_stream_message_objects(
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
finally:
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
db_session: Session,
|
||||
chat_session_id: str,
|
||||
assistant_message: ChatMessage,
|
||||
) -> None:
|
||||
# Determine if stopped by user
|
||||
completed_normally = is_connected()
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
logger.debug(f"Chat session {chat_session_id} stopped by user")
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... \n\nGeneration was stopped by the user."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_message,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
@@ -789,7 +739,6 @@ def stream_chat_message_objects(
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
|
||||
@@ -568,7 +568,6 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SLIM_PAGE_SIZE = int(os.environ.get("JIRA_SLIM_PAGE_SIZE", 500))
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -996,9 +995,3 @@ COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
|
||||
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
|
||||
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")
|
||||
|
||||
INSTANCE_TYPE = (
|
||||
"managed"
|
||||
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
|
||||
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
|
||||
ONYX_UTM_SOURCE = "onyx_app"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -149,17 +148,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -247,7 +235,6 @@ class NotificationType(str, Enum):
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -430,17 +417,11 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
|
||||
@@ -93,7 +93,7 @@ if __name__ == "__main__":
|
||||
#### Docs Changes
|
||||
|
||||
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
|
||||
connector in Onyx. Then create a Pull Request in [https://github.com/onyx-dot-app/documentation](https://github.com/onyx-dot-app/documentation).
|
||||
connector in Onyx. Then create a Pull Request in https://github.com/onyx-dot-app/onyx-docs.
|
||||
|
||||
### Before opening PR
|
||||
|
||||
|
||||
@@ -901,16 +901,13 @@ class OnyxConfluence:
|
||||
space_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
This is a confluence server/data center specific method that can be used to
|
||||
This is a confluence server specific method that can be used to
|
||||
fetch the permissions of a space.
|
||||
|
||||
NOTE: This uses the JSON-RPC API which is the ONLY way to get space permissions
|
||||
on Confluence Server/Data Center. The REST API equivalent (expand=permissions)
|
||||
is Cloud-only and not available on Data Center as of version 8.9.x.
|
||||
|
||||
If this fails with 401 Unauthorized, the customer needs to enable JSON-RPC:
|
||||
Confluence Admin -> General Configuration -> Further Configuration
|
||||
-> Enable "Remote API (XML-RPC & SOAP)"
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
TODO: Make this call these endpoints for newer confluence versions:
|
||||
- /rest/api/space/{spaceKey}/permissions
|
||||
- /rest/api/space/{spaceKey}/permissions/anonymous
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
@@ -919,18 +916,7 @@ class OnyxConfluence:
|
||||
"id": 7,
|
||||
"params": [space_key],
|
||||
}
|
||||
try:
|
||||
response = self.post(url, data=data)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 401:
|
||||
raise HTTPError(
|
||||
"Unauthorized (401) when calling JSON-RPC API for space permissions. "
|
||||
"This is likely because the Remote API is disabled. "
|
||||
"To fix: Confluence Admin -> General Configuration -> Further Configuration "
|
||||
"-> Enable 'Remote API (XML-RPC & SOAP)'",
|
||||
response=e.response,
|
||||
) from e
|
||||
raise
|
||||
response = self.post(url, data=data)
|
||||
logger.debug(f"jsonrpc response: {response}")
|
||||
if not response.get("result"):
|
||||
logger.warning(
|
||||
|
||||
@@ -97,17 +97,10 @@ def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
"""Gets string representations of experts supplied.
|
||||
|
||||
If an expert cannot be represented as a string, it is omitted from the
|
||||
result.
|
||||
"""
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps: list[str | None] = [
|
||||
basic_expert_info_representation(owner) for owner in experts
|
||||
]
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from typing_extensions import override
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from onyx.configs.app_configs import JIRA_SLIM_PAGE_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
is_atlassian_date_error,
|
||||
@@ -58,6 +57,7 @@ logger = setup_logger()
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
# Constants for Jira field names
|
||||
@@ -683,7 +683,7 @@ class JiraConnector(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=JIRA_SLIM_PAGE_SIZE,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
nextPageToken=checkpoint.cursor,
|
||||
@@ -703,11 +703,11 @@ class JiraConnector(
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
self.update_checkpoint_for_next_run(
|
||||
checkpoint, current_offset, prev_offset, JIRA_SLIM_PAGE_SIZE
|
||||
checkpoint, current_offset, prev_offset, _JIRA_SLIM_PAGE_SIZE
|
||||
)
|
||||
prev_offset = current_offset
|
||||
|
||||
|
||||
@@ -566,23 +566,6 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -603,18 +586,10 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
raw_queries = [
|
||||
rephrased_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
@@ -182,11 +181,7 @@ def get_chat_sessions_by_user(
|
||||
.correlate(ChatSession)
|
||||
)
|
||||
|
||||
# Leeway for newly created chats that don't have messages yet
|
||||
time = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
recently_created = ChatSession.time_created >= time
|
||||
|
||||
stmt = stmt.where(or_(non_system_message_exists_subq, recently_created))
|
||||
stmt = stmt.where(non_system_message_exists_subq)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
@@ -444,8 +444,6 @@ def upsert_documents(
|
||||
logger.info("No documents to upsert. Skipping.")
|
||||
return
|
||||
|
||||
includes_permissions = any(doc.external_access for doc in seen_documents.values())
|
||||
|
||||
insert_stmt = insert(DbDocument).values(
|
||||
[
|
||||
model_to_dict(
|
||||
@@ -481,38 +479,21 @@ def upsert_documents(
|
||||
]
|
||||
)
|
||||
|
||||
update_set = {
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
}
|
||||
if includes_permissions:
|
||||
# Use COALESCE to preserve existing permissions when new values are NULL.
|
||||
# This prevents subsequent indexing runs (which don't fetch permissions)
|
||||
# from overwriting permissions set by permission sync jobs.
|
||||
update_set.update(
|
||||
{
|
||||
"external_user_emails": func.coalesce(
|
||||
insert_stmt.excluded.external_user_emails,
|
||||
DbDocument.external_user_emails,
|
||||
),
|
||||
"external_user_group_ids": func.coalesce(
|
||||
insert_stmt.excluded.external_user_group_ids,
|
||||
DbDocument.external_user_group_ids,
|
||||
),
|
||||
"is_public": func.coalesce(
|
||||
insert_stmt.excluded.is_public,
|
||||
DbDocument.is_public,
|
||||
),
|
||||
}
|
||||
)
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], set_=update_set # Conflict target
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"external_user_emails": insert_stmt.excluded.external_user_emails,
|
||||
"external_user_group_ids": insert_stmt.excluded.external_user_group_ids,
|
||||
"is_public": insert_stmt.excluded.is_public,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
},
|
||||
)
|
||||
db_session.execute(on_conflict_stmt)
|
||||
db_session.commit()
|
||||
|
||||
@@ -374,7 +374,7 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
only_public: bool = False,
|
||||
exclude_image_generation_providers: bool = True,
|
||||
exclude_image_generation_providers: bool = False,
|
||||
) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers with optional filtering.
|
||||
|
||||
@@ -585,12 +585,13 @@ def update_default_vision_provider(
|
||||
|
||||
def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
"""Fetch all LLM providers that are in Auto mode."""
|
||||
query = (
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.is_auto_mode.is_(True))
|
||||
.options(selectinload(LLMProviderModel.model_configurations))
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.is_auto_mode == True) # noqa: E712
|
||||
.options(selectinload(LLMProviderModel.model_configurations))
|
||||
).all()
|
||||
)
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def sync_auto_mode_models(
|
||||
@@ -619,9 +620,7 @@ def sync_auto_mode_models(
|
||||
|
||||
# Build the list of all visible models from the config
|
||||
# All models in the config are visible (default + additional_visible_models)
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(
|
||||
provider.provider
|
||||
)
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(provider.name)
|
||||
recommended_visible_model_names = [
|
||||
model.name for model in recommended_visible_models
|
||||
]
|
||||
@@ -636,12 +635,11 @@ def sync_auto_mode_models(
|
||||
).all()
|
||||
}
|
||||
|
||||
# Mark models that are no longer in GitHub config as not visible
|
||||
# Remove models that are no longer in GitHub config
|
||||
for model_name, model in existing_models.items():
|
||||
if model_name not in recommended_visible_model_names:
|
||||
if model.is_visible:
|
||||
model.is_visible = False
|
||||
changes += 1
|
||||
db_session.delete(model)
|
||||
changes += 1
|
||||
|
||||
# Add or update models from GitHub config
|
||||
for model_config in recommended_visible_models:
|
||||
@@ -671,7 +669,7 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
default_model = llm_recommendations.get_default_model(provider.provider)
|
||||
default_model = llm_recommendations.get_default_model(provider.name)
|
||||
if default_model and provider.default_model_name != default_model.name:
|
||||
provider.default_model_name = default_model.name
|
||||
changes += 1
|
||||
|
||||
@@ -377,17 +377,6 @@ class Notification(Base):
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# Unique constraint ix_notification_user_type_data on (user_id, notif_type, additional_data)
|
||||
# ensures notification deduplication for batch inserts. Defined in migration 8405ca81cc83.
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_notification_user_sort",
|
||||
"user_id",
|
||||
"dismissed",
|
||||
desc("first_shown"),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Association Tables
|
||||
@@ -2616,7 +2605,6 @@ class Tool(Base):
|
||||
__tablename__ = "tool"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# The name of the tool that the LLM will see
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
@@ -22,33 +17,23 @@ def create_notification(
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
autocommit: bool = True,
|
||||
) -> Notification:
|
||||
# Previously, we only matched the first identical, undismissed notification
|
||||
# Now, we assume some uniqueness to notifications
|
||||
# If we previously issued a notification that was dismissed, we no longer issue a new one
|
||||
|
||||
# Normalize additional_data to match the unique index behavior
|
||||
# The index uses COALESCE(additional_data, '{}'::jsonb)
|
||||
# We need to match this logic in our query
|
||||
additional_data_normalized = additional_data if additional_data is not None else {}
|
||||
|
||||
# Check if an undismissed notification of the same type and data exists
|
||||
existing_notification = (
|
||||
db_session.query(Notification)
|
||||
.filter_by(user_id=user_id, notif_type=notif_type)
|
||||
.filter(
|
||||
func.coalesce(Notification.additional_data, cast({}, postgresql.JSONB))
|
||||
== additional_data_normalized
|
||||
.filter_by(
|
||||
user_id=user_id,
|
||||
notif_type=notif_type,
|
||||
dismissed=False,
|
||||
)
|
||||
.filter(Notification.additional_data == additional_data)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_notification:
|
||||
# Update the last_shown timestamp if the notification is not dismissed
|
||||
if not existing_notification.dismissed:
|
||||
existing_notification.last_shown = func.now()
|
||||
if autocommit:
|
||||
db_session.commit()
|
||||
# Update the last_shown timestamp
|
||||
existing_notification.last_shown = func.now()
|
||||
db_session.commit()
|
||||
return existing_notification
|
||||
|
||||
# Create a new notification if none exists
|
||||
@@ -63,8 +48,7 @@ def create_notification(
|
||||
additional_data=additional_data,
|
||||
)
|
||||
db_session.add(notification)
|
||||
if autocommit:
|
||||
db_session.commit()
|
||||
db_session.commit()
|
||||
return notification
|
||||
|
||||
|
||||
@@ -97,11 +81,6 @@ def get_notifications(
|
||||
query = query.where(Notification.dismissed.is_(False))
|
||||
if notif_type:
|
||||
query = query.where(Notification.notif_type == notif_type)
|
||||
# Sort: undismissed first, then by date (newest first)
|
||||
query = query.order_by(
|
||||
Notification.dismissed.asc(),
|
||||
Notification.first_shown.desc(),
|
||||
)
|
||||
return list(db_session.execute(query).scalars().all())
|
||||
|
||||
|
||||
@@ -120,63 +99,6 @@ def dismiss_notification(notification: Notification, db_session: Session) -> Non
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def batch_dismiss_notifications(
|
||||
notifications: list[Notification],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for notification in notifications:
|
||||
notification.dismissed = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def batch_create_notifications(
|
||||
user_ids: list[UUID],
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
title: str,
|
||||
description: str | None = None,
|
||||
additional_data: dict | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Create notifications for multiple users in a single batch operation.
|
||||
Uses ON CONFLICT DO NOTHING for atomic idempotent inserts - if a user already
|
||||
has a notification with the same (user_id, notif_type, additional_data), the
|
||||
insert is silently skipped.
|
||||
|
||||
Returns the number of notifications created.
|
||||
|
||||
Relies on unique index on (user_id, notif_type, COALESCE(additional_data, '{}'))
|
||||
"""
|
||||
if not user_ids:
|
||||
return 0
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Use empty dict instead of None to match COALESCE behavior in the unique index
|
||||
additional_data_normalized = additional_data if additional_data is not None else {}
|
||||
|
||||
values = [
|
||||
{
|
||||
"user_id": uid,
|
||||
"notif_type": notif_type.value,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"dismissed": False,
|
||||
"last_shown": now,
|
||||
"first_shown": now,
|
||||
"additional_data": additional_data_normalized,
|
||||
}
|
||||
for uid in user_ids
|
||||
]
|
||||
|
||||
stmt = insert(Notification).values(values).on_conflict_do_nothing()
|
||||
result = db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
# rowcount returns number of rows inserted (excludes conflicts)
|
||||
# CursorResult has rowcount but session.execute type hints are too broad
|
||||
return result.rowcount if result.rowcount >= 0 else 0 # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def update_notification_last_shown(
|
||||
notification: Notification, db_session: Session
|
||||
) -> None:
|
||||
|
||||
@@ -187,25 +187,13 @@ def _get_persona_by_name(
|
||||
return result
|
||||
|
||||
|
||||
def update_persona_access(
|
||||
def make_persona_private(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Updates the access settings for a persona including public status and user shares.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
@@ -224,15 +212,11 @@ def update_persona_access(
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
|
||||
# there shouldn't be any) but raise an error if trying to add actual groups.
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.commit()
|
||||
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support group-based sharing")
|
||||
# May cause error if someone switches down to MIT from EE
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support private Personas")
|
||||
|
||||
|
||||
def create_update_persona(
|
||||
@@ -298,21 +282,20 @@ def create_update_persona(
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
)
|
||||
|
||||
versioned_update_persona_access(
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
persona_id=persona.id,
|
||||
creator_user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
user_ids=create_persona_request.users,
|
||||
group_ids=create_persona_request.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
@@ -321,13 +304,11 @@ def create_update_persona(
|
||||
return FullPersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
def update_persona_shared(
|
||||
def update_persona_shared_users(
|
||||
persona_id: int,
|
||||
user_ids: list[UUID],
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -336,25 +317,22 @@ def update_persona_shared(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
if persona.is_public:
|
||||
raise HTTPException(status_code=400, detail="Cannot share public persona")
|
||||
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
)
|
||||
versioned_update_persona_access(
|
||||
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
persona_id=persona_id,
|
||||
creator_user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
is_public=is_public,
|
||||
user_ids=user_ids,
|
||||
group_ids=group_ids,
|
||||
group_ids=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_public_status(
|
||||
persona_id: int,
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
"""Database functions for release notes functionality."""
|
||||
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
from onyx.server.features.release_notes.models import ReleaseNoteEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_release_notifications_for_versions(
    db_session: Session,
    release_note_entries: list[ReleaseNoteEntry],
) -> int:
    """
    Fan out one RELEASE_NOTES notification per release note entry to every
    eligible user, and report how many notifications were created in total.

    Eligible users are those that are active, whose role is neither
    SLACK_USER nor EXT_PERM_USER, and whose email does not end with the
    API-key dummy email domain.

    Each entry is inserted through batch_create_notifications. Users who
    already have a notification for a given version (dismissed or not) do not
    get a second one; that is handled by the unique constraint on
    additional_data rather than by this function.

    Note: Entries are expected to be filtered by app_version before this
    function is called. The filtering happens in
    _parse_mdx_to_release_note_entries().

    Args:
        db_session: Database session
        release_note_entries: Release note entries to notify about (pre-filtered)

    Returns:
        Total number of notifications created across all versions
        (0 when there are no entries).
    """
    if not release_note_entries:
        logger.debug("No release note entries to notify about")
        return 0

    # Active users only; Slack / external-permission users and API key
    # (dummy email domain) users are excluded.
    eligible_users_stmt = select(User.id).where(  # type: ignore
        User.is_active == True,  # noqa: E712
        User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
        User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False),  # type: ignore[attr-defined]
    )
    eligible_user_ids = list(db_session.scalars(eligible_users_stmt).all())

    notifications_created = 0
    for release_note in release_note_entries:
        release_version = release_note.version

        # UTM parameters for tracking (key order is preserved by urlencode)
        tracking_query = urlencode(
            {
                "utm_source": ONYX_UTM_SOURCE,
                "utm_medium": "notification",
                "utm_campaign": INSTANCE_TYPE,
                "utm_content": f"release_notes-{release_version}",
            }
        )

        # External docs anchors use dashes instead of dots: v2.7.0 -> v2-7-0
        docs_anchor = release_version.replace(".", "-")

        # NOTE(review): the "#anchor" is emitted before the "?query", so a URL
        # parser will treat the UTM parameters as part of the fragment.
        # Confirm whether this is intended before changing it: the link is
        # stored in additional_data, which the duplicate suppression
        # reportedly keys on.
        docs_link = f"{DOCS_CHANGELOG_BASE_URL}#{docs_anchor}?{tracking_query}"

        notification_payload: dict[str, str] = {
            "version": release_version,
            "link": docs_link,
        }

        created_for_version = batch_create_notifications(
            eligible_user_ids,
            NotificationType.RELEASE_NOTES,
            db_session,
            title=release_note.title,
            description=f"Check out what's new in {release_version}",
            additional_data=notification_payload,
        )
        notifications_created += created_for_version

        logger.debug(
            f"Created {created_for_version} release notes notifications "
            f"(version {release_version}, {len(eligible_user_ids)} eligible users)"
        )

    return notifications_created
|
||||
@@ -113,6 +113,7 @@ def upsert_web_search_provider(
|
||||
if activate:
|
||||
set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
@@ -268,6 +269,7 @@ def upsert_web_content_provider(
|
||||
if activate:
|
||||
set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
|
||||
@@ -149,9 +150,6 @@ def generate_final_report(
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
# Save citation mapping to state_container so citations are persisted
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
final_report = llm_step_result.answer
|
||||
if final_report is None:
|
||||
raise ValueError("LLM failed to generate the final deep research report")
|
||||
@@ -219,90 +217,35 @@ def run_deep_research_llm_loop(
|
||||
else ""
|
||||
)
|
||||
if not skip_clarification:
|
||||
with function_span("clarification_step") as span:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
internal_search_clarification_guidance=internal_search_clarification_guidance,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
llm_step_result, _ = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
span.span_data.output = "clarification_required"
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop"),
|
||||
)
|
||||
)
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
with function_span("research_plan_step") as span:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
internal_search_clarification_guidance=internal_search_clarification_guidance,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
llm_step_result, _ = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
@@ -310,90 +253,301 @@ def run_deep_research_llm_loop(
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
# The LLM response from this prompt is the research plan
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, reasoned = e.value
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0), obj=OverallStop(type="stop")
|
||||
)
|
||||
)
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
# The LLM response from this prompt is the research plan
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(
|
||||
turn_index=1 if reasoned else 0,
|
||||
),
|
||||
obj=SectionEnd(),
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
if reasoned:
|
||||
orchestrator_start_turn_index += 1
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, reasoned = e.value
|
||||
emitter.emit(
|
||||
Packet(
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(
|
||||
turn_index=1 if reasoned else 0,
|
||||
),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
if reasoned:
|
||||
orchestrator_start_turn_index += 1
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
research_plan = llm_step_result.answer
|
||||
span.span_data.output = research_plan if research_plan else None
|
||||
research_plan = llm_step_result.answer
|
||||
|
||||
#########################################################
|
||||
# RESEARCH EXECUTION STEP
|
||||
#########################################################
|
||||
with function_span("research_execution_step") as span:
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
max_orchestrator_cycles = (
|
||||
MAX_ORCHESTRATOR_CYCLES
|
||||
if not is_reasoning_model
|
||||
else MAX_ORCHESTRATOR_CYCLES_REASONING
|
||||
)
|
||||
max_orchestrator_cycles = (
|
||||
MAX_ORCHESTRATOR_CYCLES
|
||||
if not is_reasoning_model
|
||||
else MAX_ORCHESTRATOR_CYCLES_REASONING
|
||||
)
|
||||
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT
|
||||
if not is_reasoning_model
|
||||
else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT
|
||||
if not is_reasoning_model
|
||||
else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
|
||||
internal_search_research_task_guidance = (
|
||||
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
|
||||
if include_internal_search_tunings
|
||||
else ""
|
||||
)
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
internal_search_research_task_guidance = (
|
||||
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
|
||||
if include_internal_search_tunings
|
||||
else ""
|
||||
)
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=1,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
reasoning_cycles = 0
|
||||
most_recent_reasoning: str | None = None
|
||||
citation_mapping: CitationMapping = {}
|
||||
final_turn_index: int = (
|
||||
orchestrator_start_turn_index # Track the final turn_index for stop packet
|
||||
)
|
||||
for cycle in range(max_orchestrator_cycles):
|
||||
if cycle == max_orchestrator_cycles - 1:
|
||||
# If it's the last cycle, forcibly generate the final report
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=1,
|
||||
current_cycle_count=cycle,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
reasoning_cycles = 0
|
||||
most_recent_reasoning: str | None = None
|
||||
citation_mapping: CitationMapping = {}
|
||||
final_turn_index: int = (
|
||||
orchestrator_start_turn_index # Track the final turn_index for stop packet
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
for cycle in range(max_orchestrator_cycles):
|
||||
if cycle == max_orchestrator_cycles - 1:
|
||||
# If it's the last cycle, forcibly generate the final report
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
create_think_tool_token_processor() if not is_reasoning_model else None
|
||||
)
|
||||
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_orchestrator_tools(
|
||||
include_think_tool=not is_reasoning_model
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
placement=Placement(
|
||||
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
custom_token_processor=custom_processor,
|
||||
is_deep_research=True,
|
||||
)
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if not tool_calls and cycle == 0:
|
||||
raise RuntimeError(
|
||||
"Deep Research failed to generate any research tasks for the agents."
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurence and hopefully multiple research
|
||||
# cycles have already ran
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
report_turn_index = (
|
||||
special_tool_calls.generate_report_tool_call.placement.turn_index
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
elif special_tool_calls.think_tool_call:
|
||||
think_tool_call = special_tool_calls.think_tool_call
|
||||
# Only process the THINK_TOOL and skip all other tool calls
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
with function_span("think_tool") as span:
|
||||
span.span_data.input = str(think_tool_call.tool_args)
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
|
||||
think_tool_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=token_counter(tool_call_message),
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_msg)
|
||||
|
||||
think_tool_response_msg = ChatMessageSimple(
|
||||
message=THINK_TOOL_RESPONSE_MESSAGE,
|
||||
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
|
||||
logger.warning(f"Unexpected tool call: {tool_call.tool_name}")
|
||||
continue
|
||||
|
||||
research_agent_calls.append(tool_call)
|
||||
|
||||
if not research_agent_calls:
|
||||
logger.warning(
|
||||
"No research agent tool calls found, this should not happen."
|
||||
)
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
@@ -407,267 +561,94 @@ def run_deep_research_llm_loop(
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
if len(research_agent_calls) > 1:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(
|
||||
turn_index=research_agent_calls[0].placement.turn_index
|
||||
),
|
||||
obj=TopLevelBranching(
|
||||
num_parallel_branches=len(research_agent_calls)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
create_think_tool_token_processor()
|
||||
if not is_reasoning_model
|
||||
else None
|
||||
)
|
||||
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
research_results = run_research_agent_calls(
|
||||
# The tool calls here contain the placement information
|
||||
research_agent_calls=research_agent_calls,
|
||||
parent_tool_call_ids=[
|
||||
tool_call.tool_call_id for tool_call in tool_calls
|
||||
],
|
||||
tools=allowed_tools,
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_orchestrator_tools(
|
||||
include_think_tool=not is_reasoning_model
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
state_container=state_container,
|
||||
llm=llm,
|
||||
placement=Placement(
|
||||
is_reasoning_model=is_reasoning_model,
|
||||
token_counter=token_counter,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
citation_mapping = research_results.citation_mapping
|
||||
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
if report is None:
|
||||
# The LLM will not see that this research was even attempted, it may try
|
||||
# something similar again but this is not bad.
|
||||
logger.error(
|
||||
f"Research agent call at tab_index {tab_index} failed, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
current_tool_call = research_agent_calls[tab_index]
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles
|
||||
),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
custom_token_processor=custom_processor,
|
||||
is_deep_research=True,
|
||||
)
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if not tool_calls and cycle == 0:
|
||||
raise RuntimeError(
|
||||
"Deep Research failed to generate any research tasks for the agents."
|
||||
+ reasoning_cycles,
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_DB_NAME, db_session=db_session
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
tool_call_response=report,
|
||||
search_docs=None, # Intermediate docs are not saved/shown
|
||||
generated_images=None,
|
||||
)
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurence and hopefully multiple research
|
||||
# cycles have already ran
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
tool_call_message = current_tool_call.to_msg_str()
|
||||
tool_call_token_count = token_counter(tool_call_message)
|
||||
|
||||
tool_call_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=tool_call_token_count,
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
simple_chat_history.append(tool_call_msg)
|
||||
|
||||
tool_call_response_msg = ChatMessageSimple(
|
||||
message=report,
|
||||
token_count=token_counter(report),
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
simple_chat_history.append(tool_call_response_msg)
|
||||
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
report_turn_index = (
|
||||
special_tool_calls.generate_report_tool_call.placement.turn_index
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
elif special_tool_calls.think_tool_call:
|
||||
think_tool_call = special_tool_calls.think_tool_call
|
||||
# Only process the THINK_TOOL and skip all other tool calls
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
with function_span("think_tool") as span:
|
||||
span.span_data.input = str(think_tool_call.tool_args)
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
|
||||
think_tool_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=token_counter(tool_call_message),
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_msg)
|
||||
|
||||
think_tool_response_msg = ChatMessageSimple(
|
||||
message=THINK_TOOL_RESPONSE_MESSAGE,
|
||||
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
|
||||
logger.warning(
|
||||
f"Unexpected tool call: {tool_call.tool_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
research_agent_calls.append(tool_call)
|
||||
|
||||
if not research_agent_calls:
|
||||
logger.warning(
|
||||
"No research agent tool calls found, this should not happen."
|
||||
)
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
final_turn_index = report_turn_index + (
|
||||
1 if report_reasoned else 0
|
||||
)
|
||||
break
|
||||
|
||||
if len(research_agent_calls) > 1:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(
|
||||
turn_index=research_agent_calls[
|
||||
0
|
||||
].placement.turn_index
|
||||
),
|
||||
obj=TopLevelBranching(
|
||||
num_parallel_branches=len(research_agent_calls)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
research_results = run_research_agent_calls(
|
||||
# The tool calls here contain the placement information
|
||||
research_agent_calls=research_agent_calls,
|
||||
parent_tool_call_ids=[
|
||||
tool_call.tool_call_id for tool_call in tool_calls
|
||||
],
|
||||
tools=allowed_tools,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
llm=llm,
|
||||
is_reasoning_model=is_reasoning_model,
|
||||
token_counter=token_counter,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
citation_mapping = research_results.citation_mapping
|
||||
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
if report is None:
|
||||
# The LLM will not see that this research was even attempted, it may try
|
||||
# something similar again but this is not bad.
|
||||
logger.error(
|
||||
f"Research agent call at tab_index {tab_index} failed, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
current_tool_call = research_agent_calls[tab_index]
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles,
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
tool_call_response=report,
|
||||
search_docs=None, # Intermediate docs are not saved/shown
|
||||
generated_images=None,
|
||||
)
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
tool_call_message = current_tool_call.to_msg_str()
|
||||
tool_call_token_count = token_counter(tool_call_message)
|
||||
|
||||
tool_call_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=tool_call_token_count,
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_msg)
|
||||
|
||||
tool_call_response_msg = ChatMessageSimple(
|
||||
message=report,
|
||||
token_count=token_counter(report),
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_response_msg)
|
||||
|
||||
# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
|
||||
most_recent_reasoning = None
|
||||
# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
|
||||
most_recent_reasoning = None
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
GENERATE_PLAN_TOOL_NAME = "generate_plan"
|
||||
|
||||
RESEARCH_AGENT_IN_CODE_ID = "ResearchAgent"
|
||||
RESEARCH_AGENT_DB_NAME = "ResearchAgent"
|
||||
RESEARCH_AGENT_TOOL_NAME = "research_agent"
|
||||
RESEARCH_AGENT_TASK_KEY = "task"
|
||||
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
# Opensearch Idiosyncrasies
|
||||
|
||||
## How it works at a high level
|
||||
Opensearch has 2 phases, a `Search` phase and a `Fetch` phase. The `Search` phase works by getting the document scores on each
|
||||
shard separately, then typically a fetch phase grabs all of the relevant fields/data for returning to the user. There is also
|
||||
an intermediate phase (seemingly built specifically to handle hybrid search queries) which can run in between as a processor.
|
||||
References:
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/search-processors/
|
||||
https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
## How Hybrid queries work
|
||||
Hybrid queries are basically parallel queries that each run through their own `Search` phase and do not interact in any way.
|
||||
They also run across all the shards. It is not entirely clear what happens if a combination pipeline is not specified for them,
|
||||
perhaps the scores are just summed.
|
||||
|
||||
When the normalization processor is applied to keyword/vector hybrid searches, documents that show up due to keyword match may
|
||||
not also have shown up in the vector search and vice versa. In these situations, the document just receives a 0 score for the missing
|
||||
query component. Opensearch does not run another phase to recapture those missing values. The impact of this is that after
|
||||
normalizing, the missing scores are 0 but this is a higher score than if it actually received a non-zero score.
|
||||
|
||||
This may not be immediately obvious so an explanation is included here. If it got a non-zero score instead, it must be lower
|
||||
than all of the other scores of the list (otherwise it would have shown up). Therefore it would impact the normalization and
|
||||
push the other scores higher so that it's not only the lowest score still, but now it's a differentiated lowest score. This is
|
||||
not strictly the case in a multi-node setup but the high level concept approximately holds. So basically the 0 score is a form
|
||||
of "minimum value clipping".
|
||||
|
||||
## On time decay and boosting
|
||||
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
|
||||
vary between models and even between queries. It is not safe to pre-normalize the scores, so we also cannot apply any
|
||||
additive or multiplicative boost to them. I.e., if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to the 50th percentile, it brings it under 0.6 and it is now the worst match.
|
||||
Same logic applies to additive boosting.
|
||||
|
||||
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
|
||||
and only applies to the results of the completely independent `Search` phase queries. So if a time based boost (a separate
|
||||
query which filters on recently updated documents) is added, it would not be able to introduce any new documents
|
||||
to the set (since the new documents would have no keyword/vector score or already be present) since the 0 scores on keyword
|
||||
and vector would make the docs which only came because of time filter very low scoring. This can however make some of the lower
|
||||
scored documents from the union of all the `Search` phase documents show up higher and potentially not get dropped before
|
||||
being fetched and returned to the user. But there are other issues of including these:
|
||||
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
|
||||
contents. If there are lots of updates, this may miss some of the best docs.
|
||||
- There is not a good way to normalize this field, the best is to clip it on the bottom.
|
||||
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
|
||||
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
|
||||
"unusual-ness" across distributions.
|
||||
|
||||
So while it is possible to apply time based boosting at the normalization stage (or specifically to the keyword score), we have
|
||||
decided it is better to not apply it during the OpenSearch query.
|
||||
|
||||
Because of these limitations, Onyx in code applies further refinements, boostings, etc. based on OpenSearch providing an initial
|
||||
filtering. The impact of time decay and boost should not be so big that we would need orders of magnitude more results back
|
||||
from OpenSearch.
|
||||
|
||||
## Other concepts to be aware of
|
||||
Within the `Search` phase, there are optional steps like Rescore but these are not useful for the combination/normalization
|
||||
work that is relevant for the hybrid search. Since the Rescore happens prior to normalization, it's not able to provide any
|
||||
meaningful operations to the query for our usage.
|
||||
|
||||
Because the Title is included in the Contents for both embedding and keyword searches, the Title scores are very low relative to
|
||||
the actual full contents scoring. It is seen as a boost rather than a core scoring component. Time decay works similarly.
|
||||
@@ -3,9 +3,6 @@ import json
|
||||
import httpx
|
||||
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -47,7 +44,6 @@ from onyx.document_index.opensearch.search import (
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
@@ -62,36 +58,50 @@ def _convert_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
blurb=chunk.blurb,
|
||||
content=chunk.content,
|
||||
source_links=json.loads(chunk.source_links) if chunk.source_links else None,
|
||||
image_file_id=chunk.image_file_id,
|
||||
# Deprecated. Fill in some reasonable default.
|
||||
image_file_id=chunk.image_file_name,
|
||||
# TODO(andrei) Yuhong says he doesn't think we need that anymore. Used
|
||||
# if a section needed to be split into diff chunks. A section is a part
|
||||
# of a doc that a link will take you to. But don't chunks have their own
|
||||
# links? Look at this in a followup.
|
||||
section_continuation=False,
|
||||
document_id=chunk.document_id,
|
||||
source_type=DocumentSource(chunk.source_type),
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
title=chunk.title,
|
||||
boost=chunk.global_boost,
|
||||
# TODO(andrei): Do in a followup. We should be able to get this from
|
||||
# OpenSearch.
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document. Yuhong thinks OpenSearch
|
||||
# has some thing out of the box for this. Just need to look at it in a
|
||||
# followup.
|
||||
boost=1,
|
||||
# TODO(andrei): Do in a followup.
|
||||
recency_bias=1.0,
|
||||
# TODO(andrei): This is how good the match is, we need this, key insight
|
||||
# is we can order chunks by this. Should not be hard to plumb this from
|
||||
# a search result, do that in a followup.
|
||||
score=None,
|
||||
hidden=chunk.hidden,
|
||||
metadata=json.loads(chunk.metadata),
|
||||
# TODO(andrei): Don't worry about these for now.
|
||||
# is_relevant
|
||||
# relevance_explanation
|
||||
# metadata
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document.
|
||||
metadata={},
|
||||
# TODO(andrei): The vector DB needs to supply this. I vaguely know
|
||||
# OpenSearch can from the documentation I've seen till now, look at this
|
||||
# in a followup.
|
||||
match_highlights=[],
|
||||
# TODO(andrei) Consider storing a chunk content index instead of a full
|
||||
# string when working on chunk content augmentation.
|
||||
doc_summary=chunk.doc_summary,
|
||||
# TODO(andrei) This content is not queried on, it is only used to clean
|
||||
# appended content to chunks. Consider storing a chunk content index
|
||||
# instead of a full string when working on chunk content augmentation.
|
||||
doc_summary="",
|
||||
# TODO(andrei) Same thing as contx ret above, LLM gens context for each
|
||||
# chunk.
|
||||
chunk_context=chunk.chunk_context,
|
||||
chunk_context="",
|
||||
updated_at=chunk.last_updated,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
# primary_owners TODO(andrei)
|
||||
# secondary_owners TODO(andrei)
|
||||
# large_chunk_reference_ids TODO(andrei): Don't worry about this one.
|
||||
# TODO(andrei): This is the suffix appended to the end of the chunk
|
||||
# content to assist querying. There are better ways we can do this, for
|
||||
# ex. keeping an index of where to string split from.
|
||||
@@ -116,31 +126,44 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
title_vector=chunk.title_embedding,
|
||||
content=chunk.content,
|
||||
content_vector=chunk.embeddings.full_embedding,
|
||||
# TODO(andrei): We should know this. Reason to have this is convenience,
|
||||
# but it could also change when you change your embedding model, maybe
|
||||
# we can remove it, Yuhong to look at this. Hardcoded to some nonsense
|
||||
# value for now.
|
||||
num_tokens=0,
|
||||
source_type=chunk.source_document.source.value,
|
||||
metadata=json.dumps(chunk.source_document.metadata),
|
||||
# TODO(andrei): This is just represented a bit differently in
|
||||
# DocumentBase than how we expect it in the schema currently. Look at
|
||||
# this closer in a followup. Always defaults to None for now.
|
||||
# metadata=chunk.source_document.metadata,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
# TODO(andrei): Don't currently see an easy way of porting this, and
|
||||
# besides some connectors genuinely don't have this data. Look at this
|
||||
# closer in a followup. Always defaults to None for now.
|
||||
# created_at=None,
|
||||
public=chunk.access.is_public,
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
global_boost=chunk.boost,
|
||||
# TODO(andrei): Implement ACL in a followup, currently none of the
|
||||
# methods in OpenSearchDocumentIndex support it anyway. Always defaults
|
||||
# to None for now.
|
||||
# access_control_list=chunk.access.to_acl(),
|
||||
# TODO(andrei): This doesn't work bc global_boost is float, presumably
|
||||
# between 0.0 and inf (check this) and chunk.boost is an int from -inf
|
||||
# to +inf. Look at how the scaling compares between these in a followup.
|
||||
# Always defaults to 1.0 for now.
|
||||
# global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
image_file_id=chunk.image_file_id,
|
||||
# TODO(andrei): Ask Chris more about this later. Always defaults to None
|
||||
# for now.
|
||||
# image_file_name=None,
|
||||
source_links=json.dumps(chunk.source_links) if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
doc_summary=chunk.doc_summary,
|
||||
chunk_context=chunk.chunk_context,
|
||||
document_sets=list(chunk.document_sets) if chunk.document_sets else None,
|
||||
project_ids=list(chunk.user_project) if chunk.user_project else None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
secondary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.secondary_owners
|
||||
),
|
||||
# TODO(andrei): Consider not even getting this from
|
||||
# DocMetadataAwareIndexChunk and instead using OpenSearchDocumentIndex's
|
||||
# instance variable. One source of truth -> less chance of a very bad
|
||||
# bug in prod.
|
||||
tenant_id=TenantState(tenant_id=chunk.tenant_id, multitenant=MULTI_TENANT),
|
||||
tenant_id=chunk.tenant_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,35 +4,30 @@ from typing import Any
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_serializer
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_serializer
|
||||
from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
from onyx.document_index.opensearch.constants import EF_SEARCH
|
||||
from onyx.document_index.opensearch.constants import M
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
TITLE_FIELD_NAME = "title"
|
||||
TITLE_VECTOR_FIELD_NAME = "title_vector"
|
||||
CONTENT_FIELD_NAME = "content"
|
||||
CONTENT_VECTOR_FIELD_NAME = "content_vector"
|
||||
NUM_TOKENS_FIELD_NAME = "num_tokens"
|
||||
SOURCE_TYPE_FIELD_NAME = "source_type"
|
||||
METADATA_FIELD_NAME = "metadata"
|
||||
LAST_UPDATED_FIELD_NAME = "last_updated"
|
||||
CREATED_AT_FIELD_NAME = "created_at"
|
||||
PUBLIC_FIELD_NAME = "public"
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME = "access_control_list"
|
||||
HIDDEN_FIELD_NAME = "hidden"
|
||||
GLOBAL_BOOST_FIELD_NAME = "global_boost"
|
||||
SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier"
|
||||
IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
IMAGE_FILE_NAME_FIELD_NAME = "image_file_name"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
PROJECT_IDS_FIELD_NAME = "project_ids"
|
||||
@@ -41,10 +36,6 @@ CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
TENANT_ID_FIELD_NAME = "tenant_id"
|
||||
BLURB_FIELD_NAME = "blurb"
|
||||
DOC_SUMMARY_FIELD_NAME = "doc_summary"
|
||||
CHUNK_CONTEXT_FIELD_NAME = "chunk_context"
|
||||
PRIMARY_OWNERS_FIELD_NAME = "primary_owners"
|
||||
SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
|
||||
|
||||
|
||||
def get_opensearch_doc_chunk_id(
|
||||
@@ -61,27 +52,12 @@ def get_opensearch_doc_chunk_id(
|
||||
return f"{document_id}__{max_chunk_size}__{chunk_index}"
|
||||
|
||||
|
||||
def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
|
||||
if value.tzinfo is None:
|
||||
# Naive datetimes are taken to be UTC; astimezone would instead assume local time.
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Does appropriate time conversion if value was set in a different
|
||||
# timezone.
|
||||
value = value.astimezone(timezone.utc)
|
||||
return value
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
"""
|
||||
Represents a chunk of a document in the OpenSearch index.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
|
||||
WARNING: Relies on MULTI_TENANT which is global state. Also uses
|
||||
get_current_tenant_id. Generally relying on global state is bad, in this
|
||||
case we accept it because of the importance of validating tenant logic.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
@@ -99,44 +75,41 @@ class DocumentChunk(BaseModel):
|
||||
title_vector: list[float] | None = None
|
||||
content: str
|
||||
content_vector: list[float]
|
||||
# The actual number of tokens in the chunk.
|
||||
num_tokens: int
|
||||
|
||||
source_type: str
|
||||
# Contains a string representation of a dict which maps string key to either
|
||||
# string value or list of string values.
|
||||
# TODO(andrei): When we augment content with metadata this can just be an
|
||||
# index pointer, and when we support metadata list that will just be a list
|
||||
# of strings.
|
||||
metadata: str
|
||||
# If it exists, time zone should always be UTC.
|
||||
# Application logic should store these strings in the format key:::value.
|
||||
metadata: list[str] | None = None
|
||||
last_updated: datetime | None = None
|
||||
created_at: datetime | None = None
|
||||
|
||||
public: bool
|
||||
access_control_list: list[str]
|
||||
access_control_list: list[str] | None = None
|
||||
# Defaults to False, currently gets written during update not index.
|
||||
hidden: bool = False
|
||||
|
||||
global_boost: int
|
||||
global_boost: float = 1.0
|
||||
|
||||
semantic_identifier: str
|
||||
image_file_id: str | None = None
|
||||
image_file_name: str | None = None
|
||||
# Contains a string representation of a dict which maps offset into the raw
|
||||
# chunk text to the link corresponding to that point.
|
||||
source_links: str | None = None
|
||||
blurb: str
|
||||
doc_summary: str
|
||||
chunk_context: str
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
# User projects.
|
||||
project_ids: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
tenant_id: TenantState = Field(
|
||||
default_factory=lambda: TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
)
|
||||
tenant_id: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_num_tokens_fits_within_max_chunk_size(self) -> Self:
|
||||
if self.num_tokens > self.max_chunk_size:
|
||||
raise ValueError(
|
||||
"Bug: Num tokens must be less than or equal to max chunk size."
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
@@ -147,116 +120,25 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(
|
||||
self, handler: SerializerFunctionWrapHandler
|
||||
) -> dict[str, object]:
|
||||
"""Invokes pydantic's serialization logic, then excludes Nones.
|
||||
|
||||
We do this because .model_dump(exclude_none=True) does not work after
|
||||
@field_serializer logic, so for some field serializers which return None
|
||||
and which we would like to exclude from the final dump, they would be
|
||||
included without this.
|
||||
|
||||
Args:
|
||||
handler: Callable from pydantic which takes the instance of the
|
||||
model as an argument and performs standard serialization.
|
||||
|
||||
Returns:
|
||||
The return of handler but with None items excluded.
|
||||
"""
|
||||
serialized: dict[str, object] = handler(self)
|
||||
serialized_exclude_none = {k: v for k, v in serialized.items() if v is not None}
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
@field_serializer("last_updated", "created_at", mode="plain")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
self, value: datetime | None
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
if value.tzinfo is None:
|
||||
# Naive datetimes are taken to be UTC; astimezone would instead assume local time.
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Does appropriate time conversion if value was set in a different
|
||||
# timezone.
|
||||
value = value.astimezone(timezone.utc)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
The datetime returned will be in UTC.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
return value
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
|
||||
self, value: TenantState, handler: SerializerFunctionWrapHandler
|
||||
) -> str | None:
|
||||
"""
|
||||
Serializes tenant_state to the tenant str if multitenant, or None if
|
||||
not.
|
||||
|
||||
The idea is that in single tenant mode, the schema does not have a
|
||||
tenant_id field, so we don't want to supply it in our serialized
|
||||
DocumentChunk. This assumes the final serialized model excludes None
|
||||
fields, which serialize_model should enforce.
|
||||
"""
|
||||
if not value.multitenant:
|
||||
return None
|
||||
else:
|
||||
return value.tenant_id
|
||||
|
||||
@field_validator("tenant_id", mode="before")
|
||||
@classmethod
|
||||
def parse_tenant_id(cls, value: Any) -> TenantState:
|
||||
"""
|
||||
Generates a TenantState from OpenSearch's tenant_id if it exists, or
|
||||
generates a default state if it does not (implies we are in single
|
||||
tenant mode).
|
||||
"""
|
||||
if value is None:
|
||||
if MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: No tenant_id was supplied but multi-tenant mode is enabled."
|
||||
)
|
||||
return TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
elif isinstance(value, TenantState):
|
||||
if MULTI_TENANT != value.multitenant:
|
||||
raise ValueError(
|
||||
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
|
||||
f"({value.multitenant}) does not match the program's current global tenancy state."
|
||||
)
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
else:
|
||||
if not MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
|
||||
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
|
||||
)
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
|
||||
class DocumentSchema:
|
||||
"""
|
||||
@@ -294,19 +176,13 @@ class DocumentSchema:
|
||||
OpenSearch client. The structure of this dictionary is
|
||||
determined by OpenSearch documentation.
|
||||
"""
|
||||
schema: dict[str, Any] = {
|
||||
# By default OpenSearch allows dynamically adding new properties
|
||||
# based on indexed documents. This is awful and we disable it here.
|
||||
# An exception will be raised if you try to index a new doc which
|
||||
# contains unexpected fields.
|
||||
"dynamic": "strict",
|
||||
schema = {
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
# values longer than 256 chars.
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
"keyword": {"type": "keyword", "ignore_above": 256}
|
||||
},
|
||||
},
|
||||
@@ -324,8 +200,6 @@ class DocumentSchema:
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
# TODO(andrei): This is a tensor in Vespa. Also look at feature
|
||||
# parity for these other method fields.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"type": "knn_vector",
|
||||
"dimension": vector_dimension,
|
||||
@@ -336,10 +210,14 @@ class DocumentSchema:
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
# See TODO in _convert_onyx_chunk_to_opensearch_document. I
|
||||
# don't want to actually add this to the schema until we know
|
||||
# for sure we need it. If we decide we don't I will remove this.
|
||||
# # Number of tokens in the chunk's content.
|
||||
# NUM_TOKENS_FIELD_NAME: {"type": "integer", "store": True},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
# Application logic should store in the format key:::value.
|
||||
METADATA_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
@@ -347,6 +225,16 @@ class DocumentSchema:
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
},
|
||||
# See TODO in _convert_onyx_chunk_to_opensearch_document. I
|
||||
# don't want to actually add this to the schema until we know
|
||||
# for sure we need it. If we decide we don't I will remove this.
|
||||
# CREATED_AT_FIELD_NAME: {
|
||||
# "type": "date",
|
||||
# "format": "epoch_millis",
|
||||
# # For some reason date defaults to False, even though it
|
||||
# # would make sense to sort by date.
|
||||
# "doc_values": True,
|
||||
# },
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
@@ -359,7 +247,7 @@ class DocumentSchema:
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "float"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
# doc in the UI and is not used for searching. Disabling these
|
||||
# features to increase perf.
|
||||
@@ -370,7 +258,7 @@ class DocumentSchema:
|
||||
"store": False,
|
||||
},
|
||||
# Same as above; used to display an image along with the doc.
|
||||
IMAGE_FILE_ID_FIELD_NAME: {
|
||||
IMAGE_FILE_NAME_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
@@ -390,36 +278,15 @@ class DocumentSchema:
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above.
|
||||
# TODO(andrei): If we want to search on this this needs to be
|
||||
# changed.
|
||||
DOC_SUMMARY_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above.
|
||||
# TODO(andrei): If we want to search on this this needs to be
|
||||
# changed.
|
||||
CHUNK_CONTEXT_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
PROJECT_IDS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if multitenant:
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
"description": "Normalization for keyword and vector scores using min-max",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
@@ -49,7 +49,7 @@ MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
}
|
||||
|
||||
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
"description": "Normalization for keyword and vector scores using z-score",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
@@ -140,7 +140,7 @@ class DocumentQuery:
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
if tenant_state.tenant_id is not None:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
@@ -199,7 +199,7 @@ class DocumentQuery:
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
if tenant_state.tenant_id is not None:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
@@ -316,7 +316,6 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
"type": "best_fields",
|
||||
}
|
||||
@@ -341,7 +340,7 @@ class DocumentQuery:
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
if tenant_state.multitenant:
|
||||
if tenant_state.tenant_id is not None:
|
||||
hybrid_search_filters.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
@@ -164,7 +164,7 @@ def format_document_soup(
|
||||
|
||||
|
||||
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
|
||||
soup = bs4.BeautifulSoup(text, "lxml")
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ def web_html_cleanup(
|
||||
additional_element_types_to_discard: list[str] | None = None,
|
||||
) -> ParsedHTML:
|
||||
if isinstance(page_content, str):
|
||||
soup = bs4.BeautifulSoup(page_content, "lxml")
|
||||
soup = bs4.BeautifulSoup(page_content, "html.parser")
|
||||
else:
|
||||
soup = page_content
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_client.models import operations
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -55,19 +55,19 @@ def _sdk_partition_request(
|
||||
|
||||
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
from unstructured.staging.base import dict_to_elements
|
||||
from unstructured_client import UnstructuredClient
|
||||
from unstructured_client import UnstructuredClient # type: ignore
|
||||
|
||||
logger.debug(f"Starting to read file: {file_name}")
|
||||
req = _sdk_partition_request(file, file_name, strategy="fast")
|
||||
|
||||
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
|
||||
|
||||
response = unstructured_client.general.partition(request=req)
|
||||
response = unstructured_client.general.partition(req)
|
||||
elements = dict_to_elements(response.elements)
|
||||
|
||||
if response.status_code != 200:
|
||||
err = f"Received unexpected status code {response.status_code} from Unstructured API."
|
||||
logger.error(err)
|
||||
raise ValueError(err)
|
||||
|
||||
elements = dict_to_elements(response.elements or [])
|
||||
return "\n\n".join(str(el) for el in elements)
|
||||
|
||||
@@ -6,19 +6,15 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.session import TransactionalContext
|
||||
|
||||
from onyx.access.access import get_access_for_user_files
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.user_file import fetch_chunk_counts_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
@@ -198,42 +194,6 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_token_count=user_file_id_to_token_count,
|
||||
)
|
||||
|
||||
def _notify_assistant_owners_if_files_ready(
|
||||
self, user_files: list[UserFile]
|
||||
) -> None:
|
||||
"""
|
||||
Check if all files for associated assistants are processed and notify owners.
|
||||
Only sends notification when all files for an assistant are COMPLETED.
|
||||
"""
|
||||
for user_file in user_files:
|
||||
if user_file.status == UserFileStatus.COMPLETED:
|
||||
for assistant in user_file.assistants:
|
||||
# Skip assistants without owners
|
||||
if assistant.user_id is None:
|
||||
continue
|
||||
|
||||
# Check if all OTHER files for this assistant are completed
|
||||
# (we already know current file is completed from the outer check)
|
||||
all_files_completed = all(
|
||||
f.status == UserFileStatus.COMPLETED
|
||||
for f in assistant.user_files
|
||||
if f.id != user_file.id
|
||||
)
|
||||
|
||||
if all_files_completed:
|
||||
create_notification(
|
||||
user_id=assistant.user_id,
|
||||
notif_type=NotificationType.ASSISTANT_FILES_READY,
|
||||
db_session=self.db_session,
|
||||
title="Your files are ready!",
|
||||
description=f"All files for agent {assistant.name} have been processed and are now available.",
|
||||
additional_data={
|
||||
"persona_id": assistant.id,
|
||||
"link": f"/assistants/{assistant.id}",
|
||||
},
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
def post_index(
|
||||
self,
|
||||
context: DocumentBatchPrepareContext,
|
||||
@@ -244,10 +204,7 @@ class UserFileIndexingAdapter:
|
||||
user_file_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
user_files = (
|
||||
self.db_session.query(UserFile)
|
||||
.options(selectinload(UserFile.assistants).selectinload(Persona.user_files))
|
||||
.filter(UserFile.id.in_(user_file_ids))
|
||||
.all()
|
||||
self.db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
|
||||
)
|
||||
for user_file in user_files:
|
||||
# don't update the status if the user file is being deleted
|
||||
@@ -260,10 +217,6 @@ class UserFileIndexingAdapter:
|
||||
user_file.token_count = result.user_file_id_to_token_count[
|
||||
str(user_file.id)
|
||||
]
|
||||
|
||||
# Notify assistant owners if all their files are now processed
|
||||
self._notify_assistant_owners_if_files_ready(user_files)
|
||||
|
||||
self.db_session.commit()
|
||||
|
||||
# Store the plaintext in the file store for faster retrieval
|
||||
|
||||
@@ -40,7 +40,6 @@ class BaseChunk(BaseModel):
|
||||
source_links: dict[int, str] | None
|
||||
image_file_id: str | None
|
||||
# True if this Chunk's start is not at the start of a Section
|
||||
# TODO(andrei): This is deprecated as of the OpenSearch migration. Remove.
|
||||
section_continuation: bool
|
||||
|
||||
|
||||
|
||||
@@ -369,8 +369,6 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -396,8 +394,6 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -495,71 +491,21 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
@@ -685,40 +631,6 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -728,13 +640,12 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -63,7 +63,7 @@ def process_with_prompt_cache(
|
||||
return suffix, None
|
||||
|
||||
# Get provider adapter
|
||||
provider_adapter = get_provider_adapter(llm_config)
|
||||
provider_adapter = get_provider_adapter(llm_config.model_provider)
|
||||
|
||||
# If provider doesn't support caching, combine and return unchanged
|
||||
if not provider_adapter.supports_caching():
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
"""Factory for creating provider-specific prompt cache adapters."""
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
|
||||
|
||||
ANTHROPIC_BEDROCK_TAG = "anthropic."
|
||||
|
||||
|
||||
def get_provider_adapter(llm_config: LLMConfig) -> PromptCacheProvider:
|
||||
def get_provider_adapter(provider: str) -> PromptCacheProvider:
|
||||
"""Get the appropriate prompt cache provider adapter for a given provider.
|
||||
|
||||
Args:
|
||||
@@ -20,14 +17,11 @@ def get_provider_adapter(llm_config: LLMConfig) -> PromptCacheProvider:
|
||||
Returns:
|
||||
PromptCacheProvider instance for the given provider
|
||||
"""
|
||||
if llm_config.model_provider == LlmProviderNames.OPENAI:
|
||||
if provider == LlmProviderNames.OPENAI:
|
||||
return OpenAIPromptCacheProvider()
|
||||
elif llm_config.model_provider == LlmProviderNames.ANTHROPIC or (
|
||||
llm_config.model_provider == LlmProviderNames.BEDROCK
|
||||
and ANTHROPIC_BEDROCK_TAG in llm_config.model_name
|
||||
):
|
||||
elif provider in [LlmProviderNames.ANTHROPIC, LlmProviderNames.BEDROCK]:
|
||||
return AnthropicPromptCacheProvider()
|
||||
elif llm_config.model_provider == LlmProviderNames.VERTEX_AI:
|
||||
elif provider == LlmProviderNames.VERTEX_AI:
|
||||
return VertexAIPromptCacheProvider()
|
||||
else:
|
||||
# Default to no-op for providers without caching support
|
||||
|
||||
@@ -48,7 +48,7 @@ class VertexAIPromptCacheProvider(PromptCacheProvider):
|
||||
cacheable_prefix=cacheable_prefix,
|
||||
suffix=suffix,
|
||||
continuation=continuation,
|
||||
transform_cacheable=None, # TODO: support explicit caching
|
||||
transform_cacheable=_add_vertex_cache_control,
|
||||
)
|
||||
|
||||
def extract_cache_metadata(
|
||||
@@ -89,10 +89,6 @@ def _add_vertex_cache_control(
|
||||
not at the message level. This function converts string content to the array format
|
||||
and adds cache_control to the last content block in each cacheable message.
|
||||
"""
|
||||
# NOTE: unfortunately we need a much more sophisticated mechnism to support
|
||||
# explict caching with vertex in the presence of tools and system messages
|
||||
# (since they're supposed to be stripped out when setting cache_control)
|
||||
# so we're deferring this to a future PR.
|
||||
updated: list[ChatCompletionMessage] = []
|
||||
for message in messages:
|
||||
mutated = dict(message)
|
||||
|
||||
@@ -82,6 +82,7 @@ def fetch_llm_recommendations_from_github(
|
||||
|
||||
def sync_llm_models_from_github(
|
||||
db_session: Session,
|
||||
config: LLMRecommendations,
|
||||
force: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Sync models from GitHub config to database for all Auto mode providers.
|
||||
@@ -100,24 +101,19 @@ def sync_llm_models_from_github(
|
||||
Returns:
|
||||
Dict of provider_name -> number of changes made.
|
||||
"""
|
||||
results: dict[str, int] = {}
|
||||
|
||||
# Get all providers in Auto mode
|
||||
auto_providers = fetch_auto_mode_providers(db_session)
|
||||
if not auto_providers:
|
||||
logger.debug("No providers in Auto mode found")
|
||||
return {}
|
||||
|
||||
# Fetch config from GitHub
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
if not config:
|
||||
logger.warning("Failed to fetch GitHub config")
|
||||
return {}
|
||||
|
||||
# Skip if we've already processed this version (unless forced)
|
||||
last_updated_at = _get_cached_last_updated_at()
|
||||
if not force and last_updated_at and config.updated_at <= last_updated_at:
|
||||
logger.debug("GitHub config unchanged, skipping sync")
|
||||
return {}
|
||||
|
||||
results: dict[str, int] = {}
|
||||
|
||||
# Get all providers in Auto mode
|
||||
auto_providers = fetch_auto_mode_providers(db_session)
|
||||
|
||||
if not auto_providers:
|
||||
logger.debug("No providers in Auto mode found")
|
||||
_set_cached_last_updated_at(config.updated_at)
|
||||
return {}
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
from onyx.onyxbot.slack.utils import update_emote_react
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.utils.logger import OnyxLoggingAdapter
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
@@ -237,7 +236,6 @@ def handle_regular_answer(
|
||||
retrieval_details=retrieval_details,
|
||||
rerank_settings=None, # Rerank customization supported in Slack flow
|
||||
db_session=db_session,
|
||||
origin=MessageOrigin.SLACKBOT,
|
||||
)
|
||||
|
||||
# if it's a DM or ephemeral message, answer based on private documents.
|
||||
|
||||
@@ -1,39 +1,30 @@
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
|
||||
SLACK_QUERY_EXPANSION_PROMPT = f"""
|
||||
Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.
|
||||
Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
|
||||
keyword-only queries, so that Slack's keyword search yields the best matches.
|
||||
|
||||
Slack search behavior:
|
||||
- Pure keyword AND search (no semantics)
|
||||
- More words = fewer matches, so keep queries concise (1-3 words)
|
||||
Keep in mind the Slack's search behavior:
|
||||
- Pure keyword AND search (no semantics).
|
||||
- Word order matters.
|
||||
- More words = fewer matches, so keep each query concise.
|
||||
- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.
|
||||
|
||||
ALWAYS include:
|
||||
- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
|
||||
- Project/product names, technical terms, proper nouns
|
||||
- Actual content words: "performance", "bug", "deployment", "API", "error"
|
||||
Critical: Extract ONLY keywords that would actually appear in Slack message content.
|
||||
|
||||
DO NOT include:
|
||||
- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
|
||||
- Temporal: "today", "yesterday", "week", "month", "recent", "last"
|
||||
- Channel names: "general", "eng-general", "random"
|
||||
- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
|
||||
- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
|
||||
- Channels/Users: "general", "eng-general", "engineering", "@username"
|
||||
|
||||
DO include:
|
||||
- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"
|
||||
|
||||
Examples:
|
||||
|
||||
Query: "what are the big topics in eng-general this week?"
|
||||
Output:
|
||||
|
||||
Query: "messages with Sarah about the deployment"
|
||||
Output:
|
||||
Sarah deployment
|
||||
Sarah
|
||||
deployment
|
||||
|
||||
Query: "what did Mike say about the budget?"
|
||||
Output:
|
||||
Mike budget
|
||||
Mike
|
||||
budget
|
||||
|
||||
Query: "performance issues in eng-general"
|
||||
Output:
|
||||
performance issues
|
||||
@@ -50,7 +41,7 @@ Now process this query:
|
||||
|
||||
{{query}}
|
||||
|
||||
Output (keywords only, one per line, NO explanations or commentary):
|
||||
Output:
|
||||
"""
|
||||
|
||||
SLACK_DATE_EXTRACTION_PROMPT = """
|
||||
|
||||
@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -697,7 +697,7 @@ def save_user_credentials(
|
||||
# TODO: fix and/or type correctly w/base model
|
||||
config_data = MCPConnectionData(
|
||||
headers=auth_template.config.get("headers", {}),
|
||||
header_substitutions=request.credentials,
|
||||
header_substitutions=auth_template.config.get(HEADER_SUBSTITUTIONS, {}),
|
||||
)
|
||||
for oauth_field_key in MCPOAuthKeys:
|
||||
field_key: Literal["client_info", "tokens", "metadata"] = (
|
||||
|
||||
@@ -9,13 +9,11 @@ from onyx.db.models import User
|
||||
from onyx.db.notification import dismiss_notification
|
||||
from onyx.db.notification import get_notification_by_id
|
||||
from onyx.db.notification import get_notifications
|
||||
from onyx.server.features.release_notes.utils import (
|
||||
ensure_release_notes_fresh_and_notify,
|
||||
)
|
||||
from onyx.server.settings.models import Notification as NotificationModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/notifications")
|
||||
|
||||
|
||||
@@ -24,27 +22,9 @@ def get_notifications_api(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[NotificationModel]:
|
||||
"""
|
||||
Get all undismissed notifications for the current user.
|
||||
|
||||
Note: also executes background checks that should create notifications.
|
||||
|
||||
Examples of checks that create new notifications:
|
||||
- Checking for new release notes the user hasn't seen
|
||||
- Checking for misconfigurations due to version changes
|
||||
- Explicitly announcing breaking changes
|
||||
"""
|
||||
# If more background checks are added, this should be moved to a helper function
|
||||
try:
|
||||
ensure_release_notes_fresh_and_notify(db_session)
|
||||
except Exception:
|
||||
# Log exception but don't fail the entire endpoint
|
||||
# Users can still see their existing notifications
|
||||
logger.exception("Failed to check for release notes in notifications endpoint")
|
||||
|
||||
notifications = [
|
||||
NotificationModel.from_model(notif)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=True)
|
||||
for notif in get_notifications(user, db_session, include_dismissed=False)
|
||||
]
|
||||
return notifications
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
from onyx.db.persona import update_persona_shared_users
|
||||
from onyx.db.persona import update_persona_visibility
|
||||
from onyx.db.persona import update_personas_display_priority
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -366,9 +366,7 @@ def delete_label(
|
||||
|
||||
|
||||
class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
user_ids: list[UUID]
|
||||
|
||||
|
||||
# We notify each user when a user is shared with them
|
||||
@@ -379,13 +377,11 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared(
|
||||
update_persona_shared_users(
|
||||
persona_id=persona_id,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
"""Constants for release notes functionality."""
|
||||
|
||||
# GitHub source
|
||||
GITHUB_RAW_BASE_URL = (
|
||||
"https://raw.githubusercontent.com/onyx-dot-app/documentation/main"
|
||||
)
|
||||
GITHUB_CHANGELOG_RAW_URL = f"{GITHUB_RAW_BASE_URL}/changelog.mdx"
|
||||
|
||||
# Base URL for changelog documentation (used for notification links)
|
||||
DOCS_CHANGELOG_BASE_URL = "https://docs.onyx.app/changelog"
|
||||
|
||||
FETCH_TIMEOUT = 60.0
|
||||
|
||||
# Redis keys (in shared namespace)
|
||||
REDIS_KEY_PREFIX = "release_notes:"
|
||||
REDIS_KEY_FETCHED_AT = f"{REDIS_KEY_PREFIX}fetched_at"
|
||||
REDIS_KEY_ETAG = f"{REDIS_KEY_PREFIX}etag"
|
||||
|
||||
# Cache TTL: 24 hours
|
||||
REDIS_CACHE_TTL = 60 * 60 * 24
|
||||
|
||||
# Auto-refresh threshold: 1 hour
|
||||
AUTO_REFRESH_THRESHOLD_SECONDS = 60 * 60
|
||||
@@ -1,11 +0,0 @@
|
||||
"""Pydantic models for release notes."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ReleaseNoteEntry(BaseModel):
|
||||
"""A single version's release note entry."""
|
||||
|
||||
version: str # e.g., "v2.7.0"
|
||||
date: str # e.g., "January 7th, 2026"
|
||||
title: str # Display title for notifications: "Onyx v2.7.0 is available!"
|
||||
@@ -1,247 +0,0 @@
|
||||
"""Utility functions for release notes parsing and caching."""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.server.features.release_notes.constants import AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
from onyx.server.features.release_notes.constants import FETCH_TIMEOUT
|
||||
from onyx.server.features.release_notes.constants import GITHUB_CHANGELOG_RAW_URL
|
||||
from onyx.server.features.release_notes.constants import REDIS_CACHE_TTL
|
||||
from onyx.server.features.release_notes.constants import REDIS_KEY_ETAG
|
||||
from onyx.server.features.release_notes.constants import REDIS_KEY_FETCHED_AT
|
||||
from onyx.server.features.release_notes.models import ReleaseNoteEntry
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Version Utilities
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def is_valid_version(version: str) -> bool:
|
||||
"""Check if version matches vX.Y.Z or vX.Y.Z-suffix.N pattern exactly."""
|
||||
return bool(re.match(r"^v\d+\.\d+\.\d+(-[a-zA-Z]+\.\d+)?$", version))
|
||||
|
||||
|
||||
def parse_version_tuple(version: str) -> tuple[int, int, int]:
|
||||
"""Parse version string to tuple for semantic sorting."""
|
||||
clean = re.sub(r"^v", "", version)
|
||||
clean = re.sub(r"-.*$", "", clean)
|
||||
parts = clean.split(".")
|
||||
return (
|
||||
int(parts[0]) if len(parts) > 0 else 0,
|
||||
int(parts[1]) if len(parts) > 1 else 0,
|
||||
int(parts[2]) if len(parts) > 2 else 0,
|
||||
)
|
||||
|
||||
|
||||
def is_version_gte(v1: str, v2: str) -> bool:
|
||||
"""Check if v1 >= v2. Strips suffixes like -cloud.X or -beta.X."""
|
||||
return parse_version_tuple(v1) >= parse_version_tuple(v2)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MDX Parsing
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry]:
|
||||
"""Parse MDX content into ReleaseNoteEntry objects for versions >= __version__."""
|
||||
all_entries = []
|
||||
|
||||
update_pattern = (
|
||||
r'<Update\s+label="([^"]+)"\s+description="([^"]+)"'
|
||||
r"(?:\s+tags=\{([^}]+)\})?[^>]*>"
|
||||
r".*?"
|
||||
r"</Update>"
|
||||
)
|
||||
|
||||
for match in re.finditer(update_pattern, mdx_content, re.DOTALL):
|
||||
version = match.group(1)
|
||||
date = match.group(2)
|
||||
|
||||
if is_valid_version(version):
|
||||
all_entries.append(
|
||||
ReleaseNoteEntry(
|
||||
version=version,
|
||||
date=date,
|
||||
title=f"Onyx {version} is available!",
|
||||
)
|
||||
)
|
||||
|
||||
if not all_entries:
|
||||
raise ValueError("Could not parse any release note entries from MDX.")
|
||||
|
||||
# Filter to valid versions >= __version__
|
||||
if __version__ and is_valid_version(__version__):
|
||||
entries = [
|
||||
entry for entry in all_entries if is_version_gte(entry.version, __version__)
|
||||
]
|
||||
elif "nightly" in __version__:
|
||||
# Just show the latest entry for nightly versions
|
||||
entries = sorted(
|
||||
all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
|
||||
)[:1]
|
||||
else:
|
||||
# If not recognized version
|
||||
# likely `development` and we should show all entries
|
||||
entries = all_entries
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cache Helpers (ETag + timestamp only)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_cached_etag() -> str | None:
|
||||
"""Get the cached GitHub ETag from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
try:
|
||||
etag = redis_client.get(REDIS_KEY_ETAG)
|
||||
if etag:
|
||||
return etag.decode("utf-8") if isinstance(etag, bytes) else str(etag)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get cached etag from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_last_fetch_time() -> datetime | None:
|
||||
"""Get the last fetch timestamp from Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
try:
|
||||
fetched_at_str = redis_client.get(REDIS_KEY_FETCHED_AT)
|
||||
if not fetched_at_str:
|
||||
return None
|
||||
|
||||
decoded = (
|
||||
fetched_at_str.decode("utf-8")
|
||||
if isinstance(fetched_at_str, bytes)
|
||||
else str(fetched_at_str)
|
||||
)
|
||||
|
||||
last_fetch = datetime.fromisoformat(decoded)
|
||||
|
||||
# Defensively ensure timezone awareness
|
||||
# fromisoformat() returns naive datetime if input lacks timezone
|
||||
if last_fetch.tzinfo is None:
|
||||
# Assume UTC for naive datetimes
|
||||
last_fetch = last_fetch.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Convert to UTC if timezone-aware
|
||||
last_fetch = last_fetch.astimezone(timezone.utc)
|
||||
|
||||
return last_fetch
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get last fetch time from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def save_fetch_metadata(etag: str | None) -> None:
|
||||
"""Save ETag and fetch timestamp to Redis."""
|
||||
redis_client = get_shared_redis_client()
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
try:
|
||||
redis_client.set(REDIS_KEY_FETCHED_AT, now.isoformat(), ex=REDIS_CACHE_TTL)
|
||||
if etag:
|
||||
redis_client.set(REDIS_KEY_ETAG, etag, ex=REDIS_CACHE_TTL)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save fetch metadata to Redis: {e}")
|
||||
|
||||
|
||||
def is_cache_stale() -> bool:
|
||||
"""Check if we should fetch from GitHub."""
|
||||
last_fetch = get_last_fetch_time()
|
||||
if last_fetch is None:
|
||||
return True
|
||||
age = datetime.now(timezone.utc) - last_fetch
|
||||
return age.total_seconds() > AUTO_REFRESH_THRESHOLD_SECONDS
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def ensure_release_notes_fresh_and_notify(db_session: Session) -> None:
|
||||
"""
|
||||
Check for new release notes and create notifications if needed.
|
||||
|
||||
Called from /api/notifications endpoint. Uses ETag for efficient
|
||||
GitHub requests. Database handles notification deduplication.
|
||||
|
||||
Since all users will trigger this via notification fetch,
|
||||
uses Redis lock to prevent concurrent GitHub requests when cache is stale.
|
||||
"""
|
||||
if not is_cache_stale():
|
||||
return
|
||||
|
||||
# Acquire lock to prevent concurrent fetches
|
||||
redis_client = get_shared_redis_client()
|
||||
lock = redis_client.lock(
|
||||
OnyxRedisLocks.RELEASE_NOTES_FETCH_LOCK,
|
||||
timeout=90, # 90 second timeout for the lock
|
||||
)
|
||||
|
||||
# Non-blocking acquire - if we can't get the lock, another request is handling it
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
logger.debug("Another request is already fetching release notes, skipping.")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.debug("Checking GitHub for release notes updates.")
|
||||
|
||||
# Use ETag for conditional request
|
||||
headers: dict[str, str] = {}
|
||||
etag = get_cached_etag()
|
||||
if etag:
|
||||
headers["If-None-Match"] = etag
|
||||
|
||||
try:
|
||||
response = httpx.get(
|
||||
GITHUB_CHANGELOG_RAW_URL,
|
||||
headers=headers,
|
||||
timeout=FETCH_TIMEOUT,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
if response.status_code == 304:
|
||||
# Content unchanged, just update timestamp
|
||||
logger.debug("Release notes unchanged (304).")
|
||||
save_fetch_metadata(etag)
|
||||
return
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
# Parse and create notifications
|
||||
entries = parse_mdx_to_release_note_entries(response.text)
|
||||
new_etag = response.headers.get("ETag")
|
||||
save_fetch_metadata(new_etag)
|
||||
|
||||
# Create notifications, sorted semantically to create them in chronological order
|
||||
entries = sorted(entries, key=lambda x: parse_version_tuple(x.version))
|
||||
create_release_notifications_for_versions(db_session, entries)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check release notes: {e}")
|
||||
# Update timestamp even on failure to prevent retry storms
|
||||
# We don't save etag on failure to allow retry with conditional request
|
||||
save_fetch_metadata(None)
|
||||
finally:
|
||||
# Always release the lock
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
@@ -22,9 +22,6 @@ from onyx.tools.tool_implementations.open_url.models import WebContentProvider
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
OnyxWebCrawler,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.utils import (
|
||||
filter_web_contents_with_no_title_or_content,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchProvider
|
||||
from onyx.tools.tool_implementations.web_search.providers import (
|
||||
@@ -33,9 +30,6 @@ from onyx.tools.tool_implementations.web_search.providers import (
|
||||
from onyx.tools.tool_implementations.web_search.providers import (
|
||||
build_search_provider_from_config,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.utils import (
|
||||
filter_web_search_results_with_no_title_or_snippet,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.utils import (
|
||||
truncate_search_result_content,
|
||||
)
|
||||
@@ -162,10 +156,7 @@ def _run_web_search(
|
||||
status_code=502, detail="Web search provider failed to execute query."
|
||||
) from exc
|
||||
|
||||
filtered_results = filter_web_search_results_with_no_title_or_snippet(
|
||||
list(search_results)
|
||||
)
|
||||
trimmed_results = list(filtered_results)[: request.max_results]
|
||||
trimmed_results = list(search_results)[: request.max_results]
|
||||
for search_result in trimmed_results:
|
||||
results.append(
|
||||
LlmWebSearchResult(
|
||||
@@ -189,9 +180,7 @@ def _open_urls(
|
||||
provider_view, provider = _get_active_content_provider(db_session)
|
||||
|
||||
try:
|
||||
docs = filter_web_contents_with_no_title_or_content(
|
||||
list(provider.contents(urls))
|
||||
)
|
||||
docs = provider.contents(urls)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
|
||||
@@ -410,20 +410,26 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -4,13 +4,10 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import InternetContentProvider
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.models import User
|
||||
from onyx.db.web_search import deactivate_web_content_provider
|
||||
from onyx.db.web_search import deactivate_web_search_provider
|
||||
@@ -32,9 +29,6 @@ from onyx.server.manage.web_search.models import WebContentProviderView
|
||||
from onyx.server.manage.web_search.models import WebSearchProviderTestRequest
|
||||
from onyx.server.manage.web_search.models import WebSearchProviderUpsertRequest
|
||||
from onyx.server.manage.web_search.models import WebSearchProviderView
|
||||
from onyx.tools.tool_implementations.open_url.utils import (
|
||||
filter_web_contents_with_no_title_or_content,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.providers import (
|
||||
build_content_provider_from_config,
|
||||
)
|
||||
@@ -97,28 +91,6 @@ def upsert_search_provider_endpoint(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Sync Exa key of search engine to content provider
|
||||
if (
|
||||
request.provider_type == WebSearchProviderType.EXA
|
||||
and request.api_key_changed
|
||||
and request.api_key
|
||||
):
|
||||
stmt = (
|
||||
insert(InternetContentProvider)
|
||||
.values(
|
||||
name="Exa",
|
||||
provider_type=WebContentProviderType.EXA.value,
|
||||
api_key=request.api_key,
|
||||
is_active=False,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["name"],
|
||||
set_={"api_key": request.api_key},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.flush()
|
||||
|
||||
db_session.commit()
|
||||
return WebSearchProviderView(
|
||||
id=provider.id,
|
||||
@@ -270,28 +242,6 @@ def upsert_content_provider_endpoint(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Sync Exa key of content provider to search provider
|
||||
if (
|
||||
request.provider_type == WebContentProviderType.EXA
|
||||
and request.api_key_changed
|
||||
and request.api_key
|
||||
):
|
||||
stmt = (
|
||||
insert(InternetSearchProvider)
|
||||
.values(
|
||||
name="Exa",
|
||||
provider_type=WebSearchProviderType.EXA.value,
|
||||
api_key=request.api_key,
|
||||
is_active=False,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["name"],
|
||||
set_={"api_key": request.api_key},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.flush()
|
||||
|
||||
db_session.commit()
|
||||
return WebContentProviderView(
|
||||
id=provider.id,
|
||||
@@ -403,9 +353,7 @@ def test_content_provider(
|
||||
# Actually test the API key by making a real content fetch call
|
||||
try:
|
||||
test_url = "https://example.com"
|
||||
test_results = filter_web_contents_with_no_title_or_content(
|
||||
list(provider.contents([test_url]))
|
||||
)
|
||||
test_results = provider.contents([test_url])
|
||||
if not test_results or not any(
|
||||
result.scrape_successful for result in test_results
|
||||
):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import Generator
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
@@ -16,11 +18,8 @@ from pydantic import BaseModel
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_processing_checker import is_chat_session_processing
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.chat_utils import create_chat_session_from_request
|
||||
@@ -88,7 +87,6 @@ from onyx.server.query_and_chat.models import ChatSessionSummary
|
||||
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import PromptOverride
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SearchFeedbackRequest
|
||||
@@ -107,6 +105,7 @@ from onyx.server.utils import PUBLIC_API_TAGS
|
||||
from onyx.utils.headers import get_custom_tool_additional_request_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -293,18 +292,6 @@ def get_chat_session(
|
||||
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
|
||||
]
|
||||
|
||||
try:
|
||||
is_processing = is_chat_session_processing(session_id, get_redis_client())
|
||||
# Edit the last message to indicate loading (Overriding default message value)
|
||||
if is_processing and chat_message_details:
|
||||
last_msg = chat_message_details[-1]
|
||||
if last_msg.message_type == MessageType.ASSISTANT:
|
||||
last_msg.message = "Message is loading... Please refresh the page soon."
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"An error occurred while checking if the chat session is processing"
|
||||
)
|
||||
|
||||
# Every assistant message might have a set of tool calls associated with it, these need to be replayed back for the frontend
|
||||
# Each list is the set of tool calls for the given assistant message.
|
||||
replay_packet_lists: list[list[Packet]] = []
|
||||
@@ -523,7 +510,7 @@ def handle_new_chat_message(
|
||||
|
||||
|
||||
@router.post("/send-chat-message", response_model=None, tags=PUBLIC_API_TAGS)
|
||||
def handle_send_chat_message(
|
||||
async def handle_send_chat_message(
|
||||
chat_message_req: SendMessageRequest,
|
||||
request: Request,
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
@@ -553,11 +540,6 @@ def handle_send_chat_message(
|
||||
event=MilestoneRecordType.RAN_QUERY,
|
||||
)
|
||||
|
||||
# Override origin to API when authenticated via API key or PAT
|
||||
# to prevent clients from polluting telemetry data
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -593,34 +575,63 @@ def handle_send_chat_message(
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
return result
|
||||
|
||||
# Streaming path, normal Onyx UI behavior
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
# Use prod-cons pattern to continue processing even if request stops yielding
|
||||
buffer: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Capture headers before spawning thread
|
||||
litellm_headers = extract_headers(request.headers, LITELLM_PASS_THROUGH_HEADERS)
|
||||
custom_tool_headers = get_custom_tool_additional_request_headers(request.headers)
|
||||
|
||||
def producer() -> None:
|
||||
"""
|
||||
Producer function that runs handle_stream_message_objects in a loop
|
||||
and writes results to the buffer.
|
||||
"""
|
||||
state_container = ChatStateContainer()
|
||||
try:
|
||||
logger.debug("Producer started")
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_stream_message_objects(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
litellm_additional_headers=litellm_headers,
|
||||
custom_tool_additional_headers=custom_tool_headers,
|
||||
external_state_container=state_container,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
# Thread-safe put into the asyncio queue
|
||||
loop.call_soon_threadsafe(
|
||||
buffer.put_nowait, get_json_line(obj.model_dump())
|
||||
)
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat message streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
loop.call_soon_threadsafe(buffer.put_nowait, json.dumps({"error": str(e)}))
|
||||
finally:
|
||||
logger.debug("Stream generator finished")
|
||||
# Signal end of stream
|
||||
loop.call_soon_threadsafe(buffer.put_nowait, None)
|
||||
logger.debug("Producer finished")
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
async def stream_from_buffer() -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Async generator that reads from the buffer and yields to the client.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
item = await buffer.get()
|
||||
if item is None:
|
||||
# End of stream signal
|
||||
break
|
||||
yield item
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Stream cancelled (Consumer disconnected)")
|
||||
finally:
|
||||
logger.debug("Stream consumer finished")
|
||||
|
||||
run_in_background(producer)
|
||||
|
||||
return StreamingResponse(stream_from_buffer(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.put("/set-message-as-latest")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
@@ -37,17 +36,6 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
AUTO_PLACE_AFTER_LATEST_MESSAGE = -1
|
||||
|
||||
|
||||
class MessageOrigin(str, Enum):
|
||||
"""Origin of a chat message for telemetry tracking."""
|
||||
|
||||
WEBAPP = "webapp"
|
||||
CHROME_EXTENSION = "chrome_extension"
|
||||
API = "api"
|
||||
SLACKBOT = "slackbot"
|
||||
UNKNOWN = "unknown"
|
||||
UNSET = "unset"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
@@ -105,9 +93,6 @@ class SendMessageRequest(BaseModel):
|
||||
|
||||
deep_research: bool = False
|
||||
|
||||
# Origin of the message for telemetry tracking
|
||||
origin: MessageOrigin = MessageOrigin.UNSET
|
||||
|
||||
# Placement information for the message in the conversation tree:
|
||||
# - -1: auto-place after latest message in chain
|
||||
# - null: regeneration from root (first message)
|
||||
@@ -199,9 +184,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
|
||||
deep_research: bool = False
|
||||
|
||||
# Origin of the message for telemetry tracking
|
||||
origin: MessageOrigin = MessageOrigin.UNKNOWN
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
|
||||
if self.search_doc_ids is None and self.retrieval_options is None:
|
||||
|
||||
@@ -60,7 +60,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
from onyx.server.query_and_chat.models import DocumentSearchPagination
|
||||
from onyx.server.query_and_chat.models import DocumentSearchRequest
|
||||
from onyx.server.query_and_chat.models import DocumentSearchResponse
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import OneShotQARequest
|
||||
from onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.server.query_and_chat.models import SearchSessionDetailResponse
|
||||
@@ -252,7 +251,6 @@ def get_answer_stream(
|
||||
)
|
||||
|
||||
# Also creates a new chat session
|
||||
# Origin is hardcoded to API since this endpoint is only accessible via API calls
|
||||
request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
user=user,
|
||||
@@ -263,7 +261,6 @@ def get_answer_stream(
|
||||
rerank_settings=query_request.rerank_settings,
|
||||
db_session=db_session,
|
||||
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
|
||||
origin=MessageOrigin.API,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
|
||||
@@ -11,7 +11,7 @@ from onyx.db.chat import get_db_search_doc_by_id
|
||||
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.tools import get_tool_by_id
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_IN_CODE_ID
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TASK_KEY
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
@@ -23,7 +23,6 @@ from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import IntermediateReportDelta
|
||||
from onyx.server.query_and_chat.streaming_models import IntermediateReportStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
|
||||
@@ -36,7 +35,6 @@ from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import TopLevelBranching
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
@@ -209,7 +207,6 @@ def create_research_agent_packets(
|
||||
"""Create packets for research agent tool calls.
|
||||
This recreates the packet structure that ResearchAgentRenderer expects:
|
||||
- ResearchAgentStart with the research task
|
||||
- IntermediateReportStart to signal report begins
|
||||
- IntermediateReportDelta with the report content (if available)
|
||||
- SectionEnd to mark completion
|
||||
"""
|
||||
@@ -225,14 +222,6 @@ def create_research_agent_packets(
|
||||
|
||||
# Emit report content if available
|
||||
if report_content:
|
||||
# Emit IntermediateReportStart before delta
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
obj=IntermediateReportStart(),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index, tab_index=tab_index),
|
||||
@@ -392,17 +381,10 @@ def translate_assistant_message_to_packets(
|
||||
)
|
||||
)
|
||||
|
||||
# Process each tool call in this turn (single pass).
|
||||
# We buffer packets for the turn so we can conditionally prepend a TopLevelBranching
|
||||
# packet (which must appear before any tool output in the turn).
|
||||
research_agent_count = 0
|
||||
turn_tool_packets: list[Packet] = []
|
||||
# Process each tool call in this turn
|
||||
for tool_call in tool_calls_in_turn:
|
||||
# Here we do a try because some tools may get deleted before the session is reloaded.
|
||||
try:
|
||||
tool = get_tool_by_id(tool_call.tool_id, db_session)
|
||||
if tool.in_code_tool_id == RESEARCH_AGENT_IN_CODE_ID:
|
||||
research_agent_count += 1
|
||||
|
||||
# Handle different tool types
|
||||
if tool.in_code_tool_id in [
|
||||
@@ -416,7 +398,7 @@ def translate_assistant_message_to_packets(
|
||||
translate_db_search_doc_to_saved_search_doc(doc)
|
||||
for doc in tool_call.search_docs
|
||||
]
|
||||
turn_tool_packets.extend(
|
||||
packet_list.extend(
|
||||
create_search_packets(
|
||||
search_queries=queries,
|
||||
search_docs=search_docs,
|
||||
@@ -436,7 +418,7 @@ def translate_assistant_message_to_packets(
|
||||
urls = cast(
|
||||
list[str], tool_call.tool_call_arguments.get("urls", [])
|
||||
)
|
||||
turn_tool_packets.extend(
|
||||
packet_list.extend(
|
||||
create_fetch_packets(
|
||||
fetch_docs,
|
||||
urls,
|
||||
@@ -451,20 +433,20 @@ def translate_assistant_message_to_packets(
|
||||
GeneratedImage(**img)
|
||||
for img in tool_call.generated_images
|
||||
]
|
||||
turn_tool_packets.extend(
|
||||
packet_list.extend(
|
||||
create_image_generation_packets(
|
||||
images, turn_num, tab_index=tool_call.tab_index
|
||||
)
|
||||
)
|
||||
|
||||
elif tool.in_code_tool_id == RESEARCH_AGENT_IN_CODE_ID:
|
||||
elif tool.in_code_tool_id == RESEARCH_AGENT_DB_NAME:
|
||||
# Not ideal but not a huge issue if the research task is lost.
|
||||
research_task = cast(
|
||||
str,
|
||||
tool_call.tool_call_arguments.get(RESEARCH_AGENT_TASK_KEY)
|
||||
or "Could not fetch saved research task.",
|
||||
)
|
||||
turn_tool_packets.extend(
|
||||
packet_list.extend(
|
||||
create_research_agent_packets(
|
||||
research_task=research_task,
|
||||
report_content=tool_call.tool_call_response,
|
||||
@@ -475,7 +457,7 @@ def translate_assistant_message_to_packets(
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
turn_tool_packets.extend(
|
||||
packet_list.extend(
|
||||
create_custom_tool_packets(
|
||||
tool_name=tool.display_name or tool.name,
|
||||
response_type="text",
|
||||
@@ -489,18 +471,6 @@ def translate_assistant_message_to_packets(
|
||||
logger.warning(f"Error processing tool call {tool_call.id}: {e}")
|
||||
continue
|
||||
|
||||
if research_agent_count > 1:
|
||||
# Emit TopLevelBranching before processing any tool output in the turn.
|
||||
packet_list.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_num),
|
||||
obj=TopLevelBranching(
|
||||
num_parallel_branches=research_agent_count
|
||||
),
|
||||
)
|
||||
)
|
||||
packet_list.extend(turn_tool_packets)
|
||||
|
||||
# Determine the next turn_index for the final message
|
||||
# It should come after all tool calls
|
||||
max_tool_turn = 0
|
||||
@@ -569,18 +539,9 @@ def translate_assistant_message_to_packets(
|
||||
if citation_info_list:
|
||||
final_turn_index = max(final_turn_index, citation_turn_index)
|
||||
|
||||
# Determine stop reason - check if message indicates user cancelled
|
||||
stop_reason: str | None = None
|
||||
if chat_message.message:
|
||||
if "Generation was stopped" in chat_message.message:
|
||||
stop_reason = "user_cancelled"
|
||||
|
||||
# Add overall stop packet at the end
|
||||
packet_list.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=final_turn_index),
|
||||
obj=OverallStop(stop_reason=stop_reason),
|
||||
)
|
||||
Packet(placement=Placement(turn_index=final_turn_index), obj=OverallStop())
|
||||
)
|
||||
|
||||
return packet_list
|
||||
|
||||
@@ -410,7 +410,7 @@ def run_research_agent_call(
|
||||
most_recent_reasoning = llm_step_result.reasoning
|
||||
continue
|
||||
else:
|
||||
parallel_tool_call_results = run_tool_calls(
|
||||
tool_responses, citation_mapping = run_tool_calls(
|
||||
tool_calls=tool_calls,
|
||||
tools=current_tools,
|
||||
message_history=msg_history,
|
||||
@@ -424,10 +424,6 @@ def run_research_agent_call(
|
||||
# May be better to not do this step, hard to say, needs to be tested
|
||||
skip_search_query_expansion=False,
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = (
|
||||
parallel_tool_call_results.updated_citation_mapping
|
||||
)
|
||||
|
||||
if tool_calls and not tool_responses:
|
||||
failure_messages = create_tool_call_failure_messages(
|
||||
|
||||
@@ -25,17 +25,6 @@ TOOL_CALL_MSG_FUNC_NAME = "function_name"
|
||||
TOOL_CALL_MSG_ARGUMENTS = "arguments"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
|
||||
def __init__(self, message: str, llm_facing_message: str):
|
||||
# This is the full error message which is used for tracing
|
||||
super().__init__(message)
|
||||
# LLM made tool calls are acceptable and not flow terminating, this is the message
|
||||
# which will populate the tool response.
|
||||
self.llm_facing_message = llm_facing_message
|
||||
|
||||
|
||||
class SearchToolUsage(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ENABLED = "enabled"
|
||||
@@ -88,11 +77,6 @@ class ToolResponse(BaseModel):
|
||||
tool_call: ToolCallKickoff | None = None
|
||||
|
||||
|
||||
class ParallelToolCallResponse(BaseModel):
|
||||
tool_responses: list[ToolResponse]
|
||||
updated_citation_mapping: dict[int, str]
|
||||
|
||||
|
||||
class ToolRunnerResponse(BaseModel):
|
||||
tool_run_kickoff: ToolCallKickoff | None = None
|
||||
tool_response: ToolResponse | None = None
|
||||
|
||||
@@ -34,9 +34,6 @@ from onyx.tools.tool_implementations.open_url.url_normalization import (
|
||||
_default_url_normalizer,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.url_normalization import normalize_url
|
||||
from onyx.tools.tool_implementations.open_url.utils import (
|
||||
filter_web_contents_with_no_title_or_content,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.providers import (
|
||||
get_default_content_provider,
|
||||
)
|
||||
@@ -523,11 +520,6 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
)
|
||||
return ToolResponse(rich_response=None, llm_facing_response=failure_msg)
|
||||
|
||||
for section in inference_sections:
|
||||
chunk = section.center_chunk
|
||||
if not chunk.semantic_identifier and chunk.source_links:
|
||||
chunk.semantic_identifier = chunk.source_links[0]
|
||||
|
||||
# Convert sections to search docs, preserving source information
|
||||
search_docs = convert_inference_sections_to_search_docs(
|
||||
inference_sections, is_internet=False
|
||||
@@ -774,23 +766,15 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
if not urls:
|
||||
return [], []
|
||||
|
||||
raw_web_contents = self._provider.contents(urls)
|
||||
# Treat "no title and no content" as a failure for that URL, but don't
|
||||
# include the empty entry in downstream prompting/sections.
|
||||
failed_urls: list[str] = [
|
||||
content.link
|
||||
for content in raw_web_contents
|
||||
if not content.title.strip() and not content.full_content.strip()
|
||||
]
|
||||
web_contents = filter_web_contents_with_no_title_or_content(raw_web_contents)
|
||||
web_contents = self._provider.contents(urls)
|
||||
sections: list[InferenceSection] = []
|
||||
failed_urls: list[str] = []
|
||||
|
||||
for content in web_contents:
|
||||
# Check if content is insufficient (e.g., "Loading..." or too short)
|
||||
text_stripped = content.full_content.strip()
|
||||
is_insufficient = (
|
||||
not text_stripped
|
||||
# TODO: Likely a behavior of our scraper, understand why this special pattern occurs
|
||||
or text_stripped.lower() == "loading..."
|
||||
or len(text_stripped) < 50
|
||||
)
|
||||
@@ -802,9 +786,6 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
):
|
||||
sections.append(inference_section_from_internet_page_scrape(content))
|
||||
else:
|
||||
# TODO: Slight improvement - if failed URL reasons are passed back to the LLM
|
||||
# for example, if it tries to crawl Reddit and fails, it should know (probably) that this error would
|
||||
# happen again if it tried to crawl Reddit again.
|
||||
failed_urls.append(content.link or "")
|
||||
|
||||
return sections, failed_urls
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
from onyx.tools.tool_implementations.open_url.models import WebContent
|
||||
|
||||
|
||||
def filter_web_contents_with_no_title_or_content(
|
||||
contents: list[WebContent],
|
||||
) -> list[WebContent]:
|
||||
"""Filter out content entries that have neither a title nor any extracted text.
|
||||
|
||||
Some content providers can return placeholder/partial entries that only include a URL.
|
||||
Downstream uses these fields for display + prompting; drop empty ones centrally
|
||||
rather than duplicating checks across provider clients.
|
||||
"""
|
||||
filtered: list[WebContent] = []
|
||||
for content in contents:
|
||||
if content.title.strip() or content.full_content.strip():
|
||||
filtered.append(content)
|
||||
return filtered
|
||||
@@ -252,14 +252,14 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
|
||||
# Store session factory instead of session for thread-safety
|
||||
# When tools are called in parallel, each thread needs its own session
|
||||
# TODO ensure this works!!!
|
||||
self._session_bind = db_session.get_bind()
|
||||
self._session_factory = sessionmaker(bind=self._session_bind)
|
||||
|
||||
self._id = tool_id
|
||||
|
||||
def _get_thread_safe_session(self) -> Session:
|
||||
"""Create a new database session for the current thread. Note this is only safe for the ORM caches/identity maps,
|
||||
pending objects, flush state, etc. But it is still using the same underlying database connection.
|
||||
"""Create a new database session for the current thread.
|
||||
|
||||
This ensures thread-safety when the search tool is called in parallel.
|
||||
Each parallel execution gets its own isolated database session with
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
|
||||
from exa_py import Exa
|
||||
@@ -20,21 +19,7 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_site_operators(query: str) -> tuple[str, list[str]]:
|
||||
"""Extract site: operators and return cleaned query + full domains.
|
||||
|
||||
Returns (cleaned_query, full_domains) where full_domains contains the full
|
||||
values after site: (e.g., ["reddit.com/r/leagueoflegends"]).
|
||||
"""
|
||||
full_domains = re.findall(r"site:\s*([^\s]+)", query, re.IGNORECASE)
|
||||
cleaned_query = re.sub(r"site:\s*\S+\s*", "", query, flags=re.IGNORECASE).strip()
|
||||
|
||||
if not cleaned_query and full_domains:
|
||||
cleaned_query = full_domains[0]
|
||||
|
||||
return cleaned_query, full_domains
|
||||
|
||||
|
||||
# TODO can probably break this up
|
||||
class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
def __init__(self, api_key: str, num_results: int = 10) -> None:
|
||||
self.exa = Exa(api_key=api_key)
|
||||
@@ -44,9 +29,8 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
def supports_site_filter(self) -> bool:
|
||||
return False
|
||||
|
||||
def _search_exa(
|
||||
self, query: str, include_domains: list[str] | None = None
|
||||
) -> list[WebSearchResult]:
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
response = self.exa.search_and_contents(
|
||||
query,
|
||||
type="auto",
|
||||
@@ -55,43 +39,22 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
highlights_per_url=1,
|
||||
),
|
||||
num_results=self._num_results,
|
||||
include_domains=include_domains,
|
||||
)
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
for result in response.results:
|
||||
title = (result.title or "").strip()
|
||||
snippet = (result.highlights[0] if result.highlights else "").strip()
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
link=result.url,
|
||||
snippet=snippet,
|
||||
author=result.author,
|
||||
published_date=(
|
||||
time_str_to_utc(result.published_date)
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
)
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
snippet=result.highlights[0] if result.highlights else "",
|
||||
author=result.author,
|
||||
published_date=(
|
||||
time_str_to_utc(result.published_date)
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
cleaned_query, full_domains = _extract_site_operators(query)
|
||||
|
||||
if full_domains:
|
||||
# Try with include_domains using base domains (e.g., ["reddit.com"])
|
||||
base_domains = [d.split("/")[0].removeprefix("www.") for d in full_domains]
|
||||
results = self._search_exa(cleaned_query, include_domains=base_domains)
|
||||
if results:
|
||||
return results
|
||||
|
||||
# Fallback: add full domains as keywords
|
||||
query_with_domains = f"{cleaned_query} {' '.join(full_domains)}".strip()
|
||||
return self._search_exa(query_with_domains)
|
||||
for result in response.results
|
||||
]
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
@@ -130,24 +93,16 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
livecrawl="preferred",
|
||||
)
|
||||
|
||||
# Exa can return partial/empty content entries; skip those to avoid
|
||||
# downstream prompt + UI pollution.
|
||||
contents: list[WebContent] = []
|
||||
for result in response.results:
|
||||
title = (result.title or "").strip()
|
||||
full_content = (result.text or "").strip()
|
||||
contents.append(
|
||||
WebContent(
|
||||
title=title,
|
||||
link=result.url,
|
||||
full_content=full_content,
|
||||
published_date=(
|
||||
time_str_to_utc(result.published_date)
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
scrape_successful=bool(full_content),
|
||||
)
|
||||
return [
|
||||
WebContent(
|
||||
title=result.title or "",
|
||||
link=result.url,
|
||||
full_content=result.text or "",
|
||||
published_date=(
|
||||
time_str_to_utc(result.published_date)
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
return contents
|
||||
for result in response.results
|
||||
]
|
||||
|
||||
@@ -47,28 +47,20 @@ class SerperClient(WebSearchProvider, WebContentProvider):
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
organic_results = results.get("organic") or []
|
||||
organic_results = results["organic"]
|
||||
|
||||
validated_results: list[WebSearchResult] = []
|
||||
for result in organic_results:
|
||||
link = (result.get("link") or "").strip()
|
||||
if not link:
|
||||
continue
|
||||
organic_results = filter(lambda result: "link" in result, organic_results)
|
||||
|
||||
title = (result.get("title") or "").strip()
|
||||
snippet = (result.get("snippet") or "").strip()
|
||||
|
||||
validated_results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
link=link,
|
||||
snippet=snippet,
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
return [
|
||||
WebSearchResult(
|
||||
title=result.get("title", ""),
|
||||
link=result.get("link"),
|
||||
snippet=result.get("snippet", ""),
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
|
||||
return validated_results
|
||||
for result in organic_results
|
||||
]
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
|
||||
@@ -98,9 +98,6 @@ def build_content_provider_from_config(
|
||||
timeout_seconds=config.timeout_seconds,
|
||||
)
|
||||
|
||||
if provider_type == WebContentProviderType.EXA:
|
||||
return ExaClient(api_key=api_key)
|
||||
|
||||
|
||||
def get_default_provider() -> WebSearchProvider | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
@@ -6,22 +6,6 @@ from onyx.tools.tool_implementations.web_search.models import WEB_SEARCH_PREFIX
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
|
||||
|
||||
def filter_web_search_results_with_no_title_or_snippet(
|
||||
results: list[WebSearchResult],
|
||||
) -> list[WebSearchResult]:
|
||||
"""Filter out results that have neither a title nor a snippet.
|
||||
|
||||
Some providers can return entries that only include a URL. Downstream uses
|
||||
titles/snippets for display and prompting, so we drop those empty entries
|
||||
centrally (rather than duplicating the check in each client).
|
||||
"""
|
||||
filtered: list[WebSearchResult] = []
|
||||
for result in results:
|
||||
if result.title.strip() or result.snippet.strip():
|
||||
filtered.append(result)
|
||||
return filtered
|
||||
|
||||
|
||||
def truncate_search_result_content(content: str, max_chars: int = 15000) -> str:
|
||||
"""Truncate search result content to a maximum number of characters"""
|
||||
if len(content) <= max_chars:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -16,7 +15,6 @@ from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.utils import (
|
||||
@@ -27,9 +25,6 @@ from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
from onyx.tools.tool_implementations.web_search.providers import (
|
||||
build_search_provider_from_config,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.utils import (
|
||||
filter_web_search_results_with_no_title_or_snippet,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.utils import (
|
||||
inference_section_from_internet_search_result,
|
||||
)
|
||||
@@ -129,28 +124,13 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
def _safe_execute_single_search(
|
||||
def _execute_single_search(
|
||||
self,
|
||||
query: str,
|
||||
provider: Any,
|
||||
) -> tuple[list[WebSearchResult] | None, str | None]:
|
||||
"""Execute a single search query and return results with error capture.
|
||||
|
||||
Returns:
|
||||
A tuple of (results, error_message). If successful, error_message is None.
|
||||
If failed, results is None and error_message contains the error.
|
||||
"""
|
||||
try:
|
||||
raw_results = list(provider.search(query))
|
||||
filtered_results = filter_web_search_results_with_no_title_or_snippet(
|
||||
raw_results
|
||||
)
|
||||
results = filtered_results[:DEFAULT_MAX_RESULTS]
|
||||
return (results, None)
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.warning(f"Web search query '{query}' failed: {error_msg}")
|
||||
return (None, error_msg)
|
||||
) -> list[WebSearchResult]:
|
||||
"""Execute a single search query and return results."""
|
||||
return list(provider.search(query))[:DEFAULT_MAX_RESULTS]
|
||||
|
||||
def run(
|
||||
self,
|
||||
@@ -169,46 +149,22 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
)
|
||||
)
|
||||
|
||||
# Perform searches in parallel with error capture
|
||||
# Perform searches in parallel
|
||||
functions_with_args = [
|
||||
(self._safe_execute_single_search, (query, self._provider))
|
||||
for query in queries
|
||||
(self._execute_single_search, (query, self._provider)) for query in queries
|
||||
]
|
||||
search_results_with_errors: list[
|
||||
tuple[list[WebSearchResult] | None, str | None]
|
||||
] = run_functions_tuples_in_parallel(
|
||||
functions_with_args,
|
||||
allow_failures=False, # Our wrapper handles errors internally
|
||||
search_results_per_query: list[list[WebSearchResult]] = (
|
||||
run_functions_tuples_in_parallel(
|
||||
functions_with_args,
|
||||
allow_failures=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Separate successful results from failures
|
||||
valid_results: list[list[WebSearchResult]] = []
|
||||
failed_queries: dict[str, str] = {}
|
||||
|
||||
for query, (results, error) in zip(queries, search_results_with_errors):
|
||||
if error is not None:
|
||||
failed_queries[query] = error
|
||||
elif results is not None:
|
||||
valid_results.append(results)
|
||||
|
||||
# Log partial failures but continue if we have at least one success
|
||||
if failed_queries and valid_results:
|
||||
logger.warning(
|
||||
f"Web search partial failure: {len(failed_queries)}/{len(queries)} "
|
||||
f"queries failed. Failed queries: {json.dumps(failed_queries)}"
|
||||
)
|
||||
|
||||
# If all queries failed, raise ToolCallException with details
|
||||
if not valid_results:
|
||||
error_details = json.dumps(failed_queries, indent=2)
|
||||
raise ToolCallException(
|
||||
message=f"All web search queries failed: {error_details}",
|
||||
llm_facing_message=(
|
||||
f"All web search queries failed. Query failures:\n{error_details}"
|
||||
),
|
||||
)
|
||||
|
||||
# Interweave top results from each query in round-robin fashion
|
||||
# Filter out None results from failures
|
||||
valid_results = [
|
||||
results for results in search_results_per_query if results is not None
|
||||
]
|
||||
all_search_results: list[WebSearchResult] = []
|
||||
|
||||
if valid_results:
|
||||
@@ -235,15 +191,8 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
if not added_any:
|
||||
break
|
||||
|
||||
# This should be a very rare case and is due to not failing loudly enough in the search provider implementation.
|
||||
if not all_search_results:
|
||||
raise ToolCallException(
|
||||
message="Web search queries succeeded but returned no results",
|
||||
llm_facing_message=(
|
||||
"Web search completed but found no results for the given queries. "
|
||||
"Try rephrasing or using different search terms."
|
||||
),
|
||||
)
|
||||
raise RuntimeError("No search results found.")
|
||||
|
||||
# Convert search results to InferenceSections with rank-based scoring
|
||||
inference_sections = [
|
||||
@@ -265,22 +214,13 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
# Format for LLM
|
||||
if not all_search_results:
|
||||
docs_str = json.dumps(
|
||||
{
|
||||
"results": [],
|
||||
"message": "The web search completed but returned no results for any of the queries. Do not search again.",
|
||||
}
|
||||
)
|
||||
citation_mapping: dict[int, str] = {}
|
||||
else:
|
||||
docs_str, citation_mapping = convert_inference_sections_to_llm_string(
|
||||
top_sections=inference_sections,
|
||||
citation_start=override_kwargs.starting_citation_num,
|
||||
limit=None, # Already truncated
|
||||
include_source_type=False,
|
||||
include_link=True,
|
||||
)
|
||||
docs_str, citation_mapping = convert_inference_sections_to_llm_string(
|
||||
top_sections=inference_sections,
|
||||
citation_start=override_kwargs.starting_citation_num,
|
||||
limit=None, # Already truncated
|
||||
include_source_type=False,
|
||||
include_link=True,
|
||||
)
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=SearchDocsResponse(
|
||||
|
||||
@@ -11,9 +11,7 @@ from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ChatMinimalTextMessage
|
||||
from onyx.tools.models import OpenURLToolOverrideKwargs
|
||||
from onyx.tools.models import ParallelToolCallResponse
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.models import WebSearchToolOverrideKwargs
|
||||
@@ -29,7 +27,6 @@ logger = setup_logger()
|
||||
|
||||
QUERIES_FIELD = "queries"
|
||||
URLS_FIELD = "urls"
|
||||
GENERIC_TOOL_ERROR_MESSAGE = "Tool failed with error: {error}"
|
||||
|
||||
# Mapping of tool name to the field that should be merged when multiple calls exist
|
||||
MERGEABLE_TOOL_FIELDS: dict[str, str] = {
|
||||
@@ -94,7 +91,7 @@ def _merge_tool_calls(tool_calls: list[ToolCallKickoff]) -> list[ToolCallKickoff
|
||||
return merged_calls
|
||||
|
||||
|
||||
def _safe_run_single_tool(
|
||||
def _run_single_tool(
|
||||
tool: Tool,
|
||||
tool_call: ToolCallKickoff,
|
||||
override_kwargs: Any,
|
||||
@@ -102,18 +99,7 @@ def _safe_run_single_tool(
|
||||
"""Execute a single tool and return its response.
|
||||
|
||||
This function is designed to be run in parallel via run_functions_tuples_in_parallel.
|
||||
|
||||
Exception handling:
|
||||
- ToolCallException: Expected errors from tool execution (e.g., invalid input,
|
||||
API failures). Uses the exception's llm_facing_message for LLM consumption.
|
||||
- Other exceptions: Unexpected errors. Uses a generic error message.
|
||||
|
||||
In all cases (success or failure):
|
||||
- SectionEnd packet is emitted to signal tool completion
|
||||
- tool_call is set on the response for downstream processing
|
||||
"""
|
||||
tool_response: ToolResponse | None = None
|
||||
|
||||
with function_span(tool.name) as span_fn:
|
||||
span_fn.span_data.input = str(tool_call.tool_args)
|
||||
try:
|
||||
@@ -123,47 +109,19 @@ def _safe_run_single_tool(
|
||||
**tool_call.tool_args,
|
||||
)
|
||||
span_fn.span_data.output = tool_response.llm_facing_response
|
||||
except ToolCallException as e:
|
||||
# ToolCallException is an expected error from tool execution
|
||||
# Use llm_facing_message which is specifically designed for LLM consumption
|
||||
logger.error(f"Tool call error for {tool.name}: {e}")
|
||||
tool_response = ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=GENERIC_TOOL_ERROR_MESSAGE.format(
|
||||
error=e.llm_facing_message
|
||||
),
|
||||
)
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Tool call error (expected)",
|
||||
data={
|
||||
"tool_name": tool.name,
|
||||
"tool_call_id": tool_call.tool_call_id,
|
||||
"tool_args": tool_call.tool_args,
|
||||
"error": str(e),
|
||||
"llm_facing_message": e.llm_facing_message,
|
||||
"stack_trace": traceback.format_exc(),
|
||||
"error_type": "ToolCallException",
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Unexpected error during tool execution
|
||||
logger.error(f"Unexpected error running tool {tool.name}: {e}")
|
||||
logger.error(f"Error running tool {tool.name}: {e}")
|
||||
tool_response = ToolResponse(
|
||||
rich_response=None,
|
||||
llm_facing_response=GENERIC_TOOL_ERROR_MESSAGE.format(error=str(e)),
|
||||
llm_facing_response="Tool execution failed with: " + str(e),
|
||||
)
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Tool execution error (unexpected)",
|
||||
message="Error running tool",
|
||||
data={
|
||||
"tool_name": tool.name,
|
||||
"tool_call_id": tool_call.tool_call_id,
|
||||
"tool_args": tool_call.tool_args,
|
||||
"error": str(e),
|
||||
"stack_trace": traceback.format_exc(),
|
||||
"error_type": type(e).__name__,
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -195,52 +153,35 @@ def run_tool_calls(
|
||||
max_concurrent_tools: int | None = None,
|
||||
# Skip query expansion for repeat search tool calls
|
||||
skip_search_query_expansion: bool = False,
|
||||
) -> ParallelToolCallResponse:
|
||||
"""Run (optionally merged) tool calls in parallel and update citation mappings.
|
||||
) -> tuple[list[ToolResponse], dict[int, str]]:
|
||||
"""Run multiple tool calls in parallel and update citation mappings.
|
||||
|
||||
Before execution, tool calls for `SearchTool`, `WebSearchTool`, and `OpenURLTool`
|
||||
are merged so repeated calls are collapsed into a single call per tool:
|
||||
- `SearchTool` / `WebSearchTool`: merge the `queries` list
|
||||
- `OpenURLTool`: merge the `urls` list
|
||||
|
||||
Tools are executed in parallel (threadpool). For tools that generate citations,
|
||||
each tool call is assigned a **distinct** `starting_citation_num` range to avoid
|
||||
citation number collisions when running concurrently (the range is advanced by
|
||||
100 per tool call).
|
||||
|
||||
The provided `citation_mapping` may be mutated in-place: any new
|
||||
`SearchDocsResponse.citation_mapping` entries are merged into it.
|
||||
Merges tool calls for SearchTool, WebSearchTool, and OpenURLTool before execution.
|
||||
All tools are executed in parallel, and citation mappings are updated
|
||||
from search tool responses.
|
||||
|
||||
Args:
|
||||
tool_calls: List of tool calls to execute.
|
||||
tools: List of available tool instances.
|
||||
message_history: Chat message history (used to find the most recent user query
|
||||
for `SearchTool` override kwargs).
|
||||
memories: User memories, if available (passed through to `SearchTool`).
|
||||
user_info: User information string, if available (passed through to `SearchTool`).
|
||||
citation_mapping: Current citation number to URL mapping. May be updated with
|
||||
new citations produced by search tools.
|
||||
next_citation_num: The next citation number to allocate from.
|
||||
tool_calls: List of tool calls to execute
|
||||
tools: List of available tools
|
||||
message_history: Chat message history for context
|
||||
memories: User memories, if available
|
||||
user_info: User information string, if available
|
||||
citation_mapping: Current citation number to URL mapping
|
||||
next_citation_num: Next citation number to use
|
||||
max_concurrent_tools: Max number of tools to run in this batch. If set, any
|
||||
tool calls after this limit are dropped (not queued).
|
||||
skip_search_query_expansion: Whether to skip query expansion for `SearchTool`
|
||||
(intended for repeated search calls within the same chat turn).
|
||||
skip_search_query_expansion: Whether to skip query expansion for search tools
|
||||
|
||||
Returns:
|
||||
A `ParallelToolCallResponse` containing:
|
||||
- `tool_responses`: `ToolResponse` objects for successfully dispatched tool calls
|
||||
(each has `tool_call` set). If a tool execution fails at the threadpool layer,
|
||||
its entry will be omitted.
|
||||
- `updated_citation_mapping`: The updated citation mapping dictionary.
|
||||
A tuple containing:
|
||||
- List of ToolResponse objects (each with tool_call set)
|
||||
- Updated citation mapping dictionary
|
||||
"""
|
||||
# Merge tool calls for SearchTool, WebSearchTool, and OpenURLTool
|
||||
# Merge tool calls for SearchTool and WebSearchTool
|
||||
merged_tool_calls = _merge_tool_calls(tool_calls)
|
||||
|
||||
if not merged_tool_calls:
|
||||
return ParallelToolCallResponse(
|
||||
tool_responses=[],
|
||||
updated_citation_mapping=citation_mapping,
|
||||
)
|
||||
return [], citation_mapping
|
||||
|
||||
tools_by_name = {tool.name: tool for tool in tools}
|
||||
|
||||
@@ -255,10 +196,7 @@ def run_tool_calls(
|
||||
# Apply safety cap (drop tool calls beyond the cap)
|
||||
if max_concurrent_tools is not None:
|
||||
if max_concurrent_tools <= 0:
|
||||
return ParallelToolCallResponse(
|
||||
tool_responses=[],
|
||||
updated_citation_mapping=citation_mapping,
|
||||
)
|
||||
return [], citation_mapping
|
||||
filtered_tool_calls = filtered_tool_calls[:max_concurrent_tools]
|
||||
|
||||
# Get starting citation number from citation processor to avoid conflicts with project files
|
||||
@@ -331,29 +269,24 @@ def run_tool_calls(
|
||||
|
||||
# Run all tools in parallel
|
||||
functions_with_args = [
|
||||
(_safe_run_single_tool, (tool, tool_call, override_kwargs))
|
||||
(_run_single_tool, (tool, tool_call, override_kwargs))
|
||||
for tool, tool_call, override_kwargs in tool_run_params
|
||||
]
|
||||
|
||||
tool_run_results: list[ToolResponse | None] = run_functions_tuples_in_parallel(
|
||||
tool_responses: list[ToolResponse] = run_functions_tuples_in_parallel(
|
||||
functions_with_args,
|
||||
allow_failures=True, # Continue even if some tools fail
|
||||
max_workers=max_concurrent_tools,
|
||||
)
|
||||
|
||||
# Process results and update citation_mapping
|
||||
for result in tool_run_results:
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
if result and isinstance(result.rich_response, SearchDocsResponse):
|
||||
new_citations = result.rich_response.citation_mapping
|
||||
for tool_response in tool_responses:
|
||||
if tool_response and isinstance(
|
||||
tool_response.rich_response, SearchDocsResponse
|
||||
):
|
||||
new_citations = tool_response.rich_response.citation_mapping
|
||||
if new_citations:
|
||||
# Merge new citations into the existing mapping
|
||||
citation_mapping.update(new_citations)
|
||||
|
||||
tool_responses = [result for result in tool_run_results if result is not None]
|
||||
return ParallelToolCallResponse(
|
||||
tool_responses=tool_responses,
|
||||
updated_citation_mapping=citation_mapping,
|
||||
)
|
||||
return tool_responses, citation_mapping
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "onyx-backend"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.11"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
dependencies = [
|
||||
"onyx[backend,dev,ee]",
|
||||
]
|
||||
|
||||
@@ -5,9 +5,7 @@ aioboto3==15.1.0
|
||||
aiobotocore==2.24.0
|
||||
# via aioboto3
|
||||
aiofiles==25.1.0
|
||||
# via
|
||||
# aioboto3
|
||||
# unstructured-client
|
||||
# via aioboto3
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.3
|
||||
@@ -117,6 +115,7 @@ certifi==2025.11.12
|
||||
# requests
|
||||
# sentry-sdk
|
||||
# trafilatura
|
||||
# unstructured-client
|
||||
cffi==2.0.0
|
||||
# via
|
||||
# argon2-cffi-bindings
|
||||
@@ -124,7 +123,9 @@ cffi==2.0.0
|
||||
# pynacl
|
||||
# zstandard
|
||||
chardet==5.2.0
|
||||
# via onyx
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
charset-normalizer==3.4.4
|
||||
# via
|
||||
# htmldate
|
||||
@@ -132,7 +133,7 @@ charset-normalizer==3.4.4
|
||||
# pdfminer-six
|
||||
# requests
|
||||
# trafilatura
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
chevron==0.14.0
|
||||
# via braintrust
|
||||
chonkie==1.0.10
|
||||
@@ -148,7 +149,6 @@ click==8.3.1
|
||||
# litellm
|
||||
# magika
|
||||
# nltk
|
||||
# python-oxmsg
|
||||
# typer
|
||||
# uvicorn
|
||||
# zulip
|
||||
@@ -185,7 +185,6 @@ cryptography==46.0.3
|
||||
# pyjwt
|
||||
# secretstorage
|
||||
# sendgrid
|
||||
# unstructured-client
|
||||
cyclopts==4.2.4
|
||||
# via fastmcp
|
||||
dask==2023.8.1
|
||||
@@ -193,13 +192,17 @@ dask==2023.8.1
|
||||
# distributed
|
||||
# onyx
|
||||
dataclasses-json==0.6.7
|
||||
# via unstructured
|
||||
# via
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
dateparser==1.2.2
|
||||
# via htmldate
|
||||
ddtrace==3.10.0
|
||||
# via onyx
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
deepdiff==8.6.1
|
||||
# via unstructured-client
|
||||
defusedxml==0.7.1
|
||||
# via
|
||||
# jira
|
||||
@@ -351,7 +354,7 @@ greenlet==3.2.4
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -359,17 +362,7 @@ grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -381,15 +374,12 @@ hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or
|
||||
# via huggingface-hub
|
||||
hpack==4.1.0
|
||||
# via h2
|
||||
html5lib==1.1
|
||||
# via unstructured
|
||||
htmldate==1.9.1
|
||||
# via trafilatura
|
||||
httpcore==1.0.9
|
||||
# via
|
||||
# httpx
|
||||
# onyx
|
||||
# unstructured-client
|
||||
httplib2==0.31.0
|
||||
# via
|
||||
# google-api-python-client
|
||||
@@ -430,6 +420,7 @@ idna==3.11
|
||||
# email-validator
|
||||
# httpx
|
||||
# requests
|
||||
# unstructured-client
|
||||
# yarl
|
||||
importlib-metadata==8.7.0
|
||||
# via
|
||||
@@ -475,6 +466,8 @@ joblib==1.5.2
|
||||
# via nltk
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpath-python==1.0.6
|
||||
# via unstructured-client
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
@@ -516,8 +509,6 @@ langsmith==0.3.45
|
||||
# langchain-core
|
||||
lazy-imports==1.0.1
|
||||
# via onyx
|
||||
legacy-cgi==2.6.4 ; python_full_version >= '3.13'
|
||||
# via ddtrace
|
||||
litellm==1.80.11
|
||||
# via onyx
|
||||
locket==1.0.0
|
||||
@@ -564,7 +555,9 @@ markupsafe==3.0.3
|
||||
# mako
|
||||
# werkzeug
|
||||
marshmallow==3.26.2
|
||||
# via dataclasses-json
|
||||
# via
|
||||
# dataclasses-json
|
||||
# unstructured-client
|
||||
matrix-client==0.3.2
|
||||
# via zulip
|
||||
mcp==1.25.0
|
||||
@@ -605,13 +598,16 @@ mypy-extensions==1.0.0
|
||||
# via
|
||||
# mypy
|
||||
# typing-inspect
|
||||
# unstructured-client
|
||||
nest-asyncio==1.6.0
|
||||
# via onyx
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
nltk==3.9.1
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
numpy==2.4.1
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# magika
|
||||
# onnxruntime
|
||||
@@ -627,9 +623,7 @@ oauthlib==3.2.2
|
||||
office365-rest-python-client==2.5.9
|
||||
# via onyx
|
||||
olefile==0.47
|
||||
# via
|
||||
# msoffcrypto-tool
|
||||
# python-oxmsg
|
||||
# via msoffcrypto-tool
|
||||
onnxruntime==1.20.1
|
||||
# via magika
|
||||
openai==2.14.0
|
||||
@@ -684,6 +678,8 @@ opentelemetry-semantic-conventions==0.60b1
|
||||
# via
|
||||
# opentelemetry-instrumentation
|
||||
# opentelemetry-sdk
|
||||
orderly-set==5.5.0
|
||||
# via deepdiff
|
||||
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
|
||||
# via langsmith
|
||||
packaging==24.2
|
||||
@@ -704,6 +700,7 @@ packaging==24.2
|
||||
# opentelemetry-instrumentation
|
||||
# pytest
|
||||
# pywikibot
|
||||
# unstructured-client
|
||||
pandas==2.2.3
|
||||
# via markitdown
|
||||
parameterized==0.9.0
|
||||
@@ -751,19 +748,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
@@ -825,7 +810,6 @@ pydantic==2.11.7
|
||||
# openapi-pydantic
|
||||
# pyairtable
|
||||
# pydantic-settings
|
||||
# unstructured-client
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
@@ -851,7 +835,7 @@ pynacl==1.6.2
|
||||
# via pygithub
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.6.0
|
||||
pypdf==6.1.3
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -883,6 +867,7 @@ python-dateutil==2.8.2
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# pandas
|
||||
# unstructured-client
|
||||
python-docx==1.1.2
|
||||
# via onyx
|
||||
python-dotenv==1.1.1
|
||||
@@ -909,8 +894,6 @@ python-multipart==0.0.20
|
||||
# fastapi-users
|
||||
# mcp
|
||||
# onyx
|
||||
python-oxmsg==0.0.2
|
||||
# via unstructured
|
||||
python-pptx==0.6.23
|
||||
# via
|
||||
# markitdown
|
||||
@@ -1002,6 +985,7 @@ requests==2.32.5
|
||||
# stripe
|
||||
# tiktoken
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
# voyageai
|
||||
# zeep
|
||||
# zulip
|
||||
@@ -1061,12 +1045,12 @@ six==1.17.0
|
||||
# atlassian-python-api
|
||||
# dropbox
|
||||
# google-auth-httplib2
|
||||
# html5lib
|
||||
# hubspot-api-client
|
||||
# langdetect
|
||||
# markdownify
|
||||
# python-dateutil
|
||||
# stone
|
||||
# unstructured-client
|
||||
slack-sdk==3.20.2
|
||||
# via onyx
|
||||
smmap==5.0.2
|
||||
@@ -1105,6 +1089,8 @@ supervisor==4.3.0
|
||||
# via onyx
|
||||
sympy==1.13.1
|
||||
# via onnxruntime
|
||||
tabulate==0.9.0
|
||||
# via unstructured
|
||||
tblib==3.2.2
|
||||
# via distributed
|
||||
tenacity==9.1.2
|
||||
@@ -1172,7 +1158,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
@@ -1193,7 +1178,6 @@ typing-extensions==4.15.0
|
||||
# pyee
|
||||
# pygithub
|
||||
# python-docx
|
||||
# python-oxmsg
|
||||
# referencing
|
||||
# simple-salesforce
|
||||
# sqlalchemy
|
||||
@@ -1203,9 +1187,12 @@ typing-extensions==4.15.0
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
# zulip
|
||||
typing-inspect==0.9.0
|
||||
# via dataclasses-json
|
||||
# via
|
||||
# dataclasses-json
|
||||
# unstructured-client
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# mcp
|
||||
@@ -1218,9 +1205,9 @@ tzdata==2025.2
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via dateparser
|
||||
unstructured==0.18.27
|
||||
unstructured==0.15.1
|
||||
# via onyx
|
||||
unstructured-client==0.42.6
|
||||
unstructured-client==0.25.4
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
@@ -1242,6 +1229,7 @@ urllib3==2.6.3
|
||||
# sentry-sdk
|
||||
# trafilatura
|
||||
# types-requests
|
||||
# unstructured-client
|
||||
uvicorn==0.35.0
|
||||
# via
|
||||
# fastmcp
|
||||
@@ -1256,8 +1244,6 @@ voyageai==0.2.3
|
||||
# via onyx
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
webencodings==0.5.1
|
||||
# via html5lib
|
||||
websockets==15.0.1
|
||||
# via
|
||||
# fastmcp
|
||||
|
||||
@@ -175,7 +175,7 @@ greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -183,17 +183,7 @@ grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -288,14 +278,14 @@ nest-asyncio==1.6.0
|
||||
# via ipykernel
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
numpy==2.4.1
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# contourpy
|
||||
# matplotlib
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.6.2
|
||||
onyx-devtools==0.2.0
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -357,16 +347,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -565,7 +546,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mypy
|
||||
|
||||
@@ -132,7 +132,7 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -140,17 +140,7 @@ grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -202,7 +192,7 @@ multidict==6.7.0
|
||||
# aiobotocore
|
||||
# aiohttp
|
||||
# yarl
|
||||
numpy==2.4.1
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# shapely
|
||||
# voyageai
|
||||
@@ -234,16 +224,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -348,7 +329,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pydantic
|
||||
|
||||
@@ -157,7 +157,7 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.67.1
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -165,17 +165,7 @@ grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.67.1
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -239,7 +229,7 @@ multidict==6.7.0
|
||||
# yarl
|
||||
networkx==3.5
|
||||
# via torch
|
||||
numpy==2.4.1
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
@@ -316,16 +306,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==5.29.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -469,7 +450,6 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pydantic
|
||||
|
||||
@@ -34,7 +34,6 @@ from scripts.tenant_cleanup.cleanup_utils import execute_control_plane_query
|
||||
from scripts.tenant_cleanup.cleanup_utils import find_worker_pod
|
||||
from scripts.tenant_cleanup.cleanup_utils import get_tenant_status
|
||||
from scripts.tenant_cleanup.cleanup_utils import read_tenant_ids_from_csv
|
||||
from scripts.tenant_cleanup.cleanup_utils import TenantNotFoundInControlPlaneError
|
||||
|
||||
|
||||
def signal_handler(signum: int, frame: object) -> None:
|
||||
@@ -419,9 +418,6 @@ def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
|
||||
"""
|
||||
print(f"Starting cleanup for tenant: {tenant_id}")
|
||||
|
||||
# Track if tenant was not found in control plane (for force mode)
|
||||
tenant_not_found_in_control_plane = False
|
||||
|
||||
# Check tenant status first
|
||||
print(f"\n{'=' * 80}")
|
||||
try:
|
||||
@@ -461,25 +457,8 @@ def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
|
||||
if response.lower() != "yes":
|
||||
print("Cleanup aborted - could not verify tenant status")
|
||||
return False
|
||||
except TenantNotFoundInControlPlaneError as e:
|
||||
# Tenant/table not found in control plane
|
||||
error_str = str(e)
|
||||
print(f"⚠️ WARNING: Tenant not found in control plane: {error_str}")
|
||||
tenant_not_found_in_control_plane = True
|
||||
|
||||
if force:
|
||||
print(
|
||||
"[FORCE MODE] Tenant not found in control plane - continuing with dataplane cleanup only"
|
||||
)
|
||||
else:
|
||||
response = input("Continue anyway? Type 'yes' to confirm: ")
|
||||
if response.lower() != "yes":
|
||||
print("Cleanup aborted - tenant not found in control plane")
|
||||
return False
|
||||
except Exception as e:
|
||||
# Other errors (not "not found")
|
||||
error_str = str(e)
|
||||
print(f"⚠️ WARNING: Failed to check tenant status: {error_str}")
|
||||
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
|
||||
|
||||
if force:
|
||||
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
|
||||
@@ -537,14 +516,8 @@ def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
|
||||
else:
|
||||
print("Step 2 skipped by user")
|
||||
|
||||
# Step 3: Clean up control plane (skip if tenant not found in control plane with --force)
|
||||
if tenant_not_found_in_control_plane:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(
|
||||
"Step 3/3: Skipping control plane cleanup (tenant not found in control plane)"
|
||||
)
|
||||
print(f"{'=' * 80}\n")
|
||||
elif confirm_step(
|
||||
# Step 3: Clean up control plane
|
||||
if confirm_step(
|
||||
"Step 3/3: Delete control plane records (tenant_notification, tenant_config, subscription, tenant)",
|
||||
force,
|
||||
):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user