mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 21:55:46 +00:00
Compare commits
63 Commits
fix/chat-h
...
v2.9.8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1c30974f5 | ||
|
|
81bf07fb15 | ||
|
|
b565bf8291 | ||
|
|
b4da99cbdd | ||
|
|
f910feea0f | ||
|
|
e3af8c6c8a | ||
|
|
d6e46ed792 | ||
|
|
4ce1f4ecdd | ||
|
|
a4678884d7 | ||
|
|
c861ba68f1 | ||
|
|
b1d0e0bb0b | ||
|
|
0d78bf52e3 | ||
|
|
bd743282e6 | ||
|
|
d44d1d92b3 | ||
|
|
4cedcfee59 | ||
|
|
90a721a76e | ||
|
|
3ccd99e931 | ||
|
|
9076bf603f | ||
|
|
8c6e0a70c3 | ||
|
|
bebe9555d4 | ||
|
|
c530722c9f | ||
|
|
68380b4ddb | ||
|
|
b3380746ab | ||
|
|
56be114c87 | ||
|
|
54f467da5c | ||
|
|
8726b112fe | ||
|
|
92181d07b2 | ||
|
|
3a73f7fab2 | ||
|
|
7dabaca7cd | ||
|
|
dec4748825 | ||
|
|
072836cd86 | ||
|
|
2705b5fb0e | ||
|
|
37dcde4226 | ||
|
|
a765b5f622 | ||
|
|
5e093368d1 | ||
|
|
f945ab6b05 | ||
|
|
11b7a22404 | ||
|
|
8e34f944cc | ||
|
|
32606dc752 | ||
|
|
1f6c4b40bf | ||
|
|
1943f1c745 | ||
|
|
82460729a6 | ||
|
|
c445e6a8c0 | ||
|
|
8d30a03d7f | ||
|
|
277428f579 | ||
|
|
9f8c0d4237 | ||
|
|
9ccbb6a04b | ||
|
|
58a943f782 | ||
|
|
9021c607f2 | ||
|
|
c03b0d80fd | ||
|
|
fcf0b316a4 | ||
|
|
157f672b4b | ||
|
|
51b9484b96 | ||
|
|
0c8f55c049 | ||
|
|
c7be2571d1 | ||
|
|
4948b6cca9 | ||
|
|
638ea5f316 | ||
|
|
6e3268ca75 | ||
|
|
d8921df60c | ||
|
|
693d9f5f69 | ||
|
|
02e17871cc | ||
|
|
209cfd00b0 | ||
|
|
cd36baa484 |
389
.github/workflows/deployment.yml
vendored
389
.github/workflows/deployment.yml
vendored
@@ -8,7 +8,9 @@ on:
|
||||
|
||||
# Set restrictive default permissions for all jobs. Jobs that need more permissions
|
||||
# should explicitly declare them.
|
||||
permissions: {}
|
||||
permissions:
|
||||
# Required for OIDC authentication with AWS
|
||||
id-token: write # zizmor: ignore[excessive-permissions]
|
||||
|
||||
env:
|
||||
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
|
||||
@@ -150,16 +152,30 @@ jobs:
|
||||
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 10
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: "• check-version-tag"
|
||||
title: "🚨 Version Tag Check Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
@@ -168,6 +184,7 @@ jobs:
|
||||
needs: determine-builds
|
||||
if: needs.determine-builds.outputs.build-desktop == 'true'
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: write
|
||||
actions: read
|
||||
strategy:
|
||||
@@ -185,12 +202,33 @@ jobs:
|
||||
|
||||
runs-on: ${{ matrix.platform }}
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
|
||||
with:
|
||||
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
|
||||
persist-credentials: true # zizmor: ignore[artipacked]
|
||||
|
||||
- name: Configure AWS credentials
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
APPLE_ID, deploy/apple-id
|
||||
APPLE_PASSWORD, deploy/apple-password
|
||||
APPLE_CERTIFICATE, deploy/apple-certificate
|
||||
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
|
||||
KEYCHAIN_PASSWORD, deploy/keychain-password
|
||||
APPLE_TEAM_ID, deploy/apple-team-id
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: install dependencies (ubuntu only)
|
||||
if: startsWith(matrix.platform, 'ubuntu-')
|
||||
run: |
|
||||
@@ -285,15 +323,40 @@ jobs:
|
||||
|
||||
Write-Host "Versions set to: $VERSION"
|
||||
|
||||
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
|
||||
- name: Import Apple Developer Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
|
||||
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security set-keychain-settings -t 3600 -u build.keychain
|
||||
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
|
||||
security find-identity -v -p codesigning build.keychain
|
||||
|
||||
- name: Verify Certificate
|
||||
if: startsWith(matrix.platform, 'macos-')
|
||||
run: |
|
||||
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
|
||||
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
|
||||
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
|
||||
echo "Certificate imported."
|
||||
|
||||
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
APPLE_ID: ${{ env.APPLE_ID }}
|
||||
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
|
||||
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
|
||||
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
|
||||
with:
|
||||
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
|
||||
releaseBody: "See the assets to download this version and install."
|
||||
releaseDraft: true
|
||||
prerelease: false
|
||||
assetNamePattern: "[name]_[arch][ext]"
|
||||
args: ${{ matrix.args }}
|
||||
|
||||
build-web-amd64:
|
||||
@@ -305,6 +368,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -317,6 +381,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -331,8 +409,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -363,6 +441,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -375,6 +454,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -389,8 +482,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -423,19 +516,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -471,6 +579,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -483,6 +592,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -497,8 +620,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -537,6 +660,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-web-cloud-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -549,6 +673,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -563,8 +701,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -605,19 +743,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -650,6 +803,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-amd64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -662,6 +816,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -676,8 +844,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -707,6 +875,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-backend-arm64
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -719,6 +888,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -733,8 +916,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -766,19 +949,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -815,6 +1013,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -827,6 +1026,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -843,8 +1056,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
@@ -879,6 +1092,7 @@ jobs:
|
||||
- volume=40gb
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
@@ -891,6 +1105,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
@@ -907,8 +1135,8 @@ jobs:
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
@@ -944,19 +1172,34 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-merge-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
@@ -994,11 +1237,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1014,8 +1272,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1034,11 +1292,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1054,8 +1327,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1074,6 +1347,7 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-backend
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
|
||||
steps:
|
||||
@@ -1084,6 +1358,20 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1100,8 +1388,8 @@ jobs:
|
||||
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1121,11 +1409,26 @@ jobs:
|
||||
- run-id=${{ github.run_id }}-trivy-scan-model-server
|
||||
- extras=ecr-cache
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
|
||||
with:
|
||||
@@ -1141,8 +1444,8 @@ jobs:
|
||||
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
|
||||
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
|
||||
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
|
||||
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
|
||||
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
|
||||
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
|
||||
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
|
||||
image \
|
||||
--skip-version-check \
|
||||
@@ -1170,12 +1473,26 @@ jobs:
|
||||
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
|
||||
runs-on: ubuntu-slim
|
||||
timeout-minutes: 90
|
||||
environment: release
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
|
||||
with:
|
||||
secret-ids: |
|
||||
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Determine failed jobs
|
||||
id: failed-jobs
|
||||
shell: bash
|
||||
@@ -1241,7 +1558,7 @@ jobs:
|
||||
- name: Send Slack notification
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
|
||||
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
|
||||
title: "🚨 Deployment Workflow Failed"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
|
||||
- name: Upload Docker logs
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.test-dir }}
|
||||
path: docker-logs/
|
||||
|
||||
4
.github/workflows/pr-integration-tests.yml
vendored
4
.github/workflows/pr-integration-tests.yml
vendored
@@ -439,7 +439,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -568,7 +568,7 @@ jobs:
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
|
||||
@@ -424,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -435,7 +435,7 @@ jobs:
|
||||
fi
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
@@ -455,7 +455,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -144,7 +144,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,6 +21,7 @@ backend/tests/regression/search_quality/*.json
|
||||
backend/onyx/evals/data/
|
||||
backend/onyx/evals/one_off/*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# secret files
|
||||
.env
|
||||
|
||||
@@ -11,7 +11,6 @@ repos:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args:
|
||||
|
||||
@@ -225,7 +225,6 @@ def do_run_migrations(
|
||||
) -> None:
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
@@ -309,6 +308,7 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
@@ -346,6 +346,7 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
|
||||
@@ -85,103 +85,122 @@ class UserRow(NamedTuple):
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
try:
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# Add tools to the unified assistant
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
@@ -190,191 +209,148 @@ def upgrade() -> None:
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
{"tool_id": web_search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
)
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
@@ -24,6 +24,9 @@ def upgrade() -> None:
|
||||
# in unique constraints, but we want NULL == NULL for deduplication).
|
||||
# The '{}' represents an empty JSONB object as the NULL replacement.
|
||||
|
||||
# Clean up legacy notifications first
|
||||
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
|
||||
@@ -40,9 +43,6 @@ def upgrade() -> None:
|
||||
"""
|
||||
)
|
||||
|
||||
# Clean up legacy 'reindex' notifications that are no longer needed
|
||||
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
|
||||
|
||||
@@ -42,20 +42,13 @@ TOOL_DESCRIPTIONS = {
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -7,7 +7,6 @@ Create Date: 2025-12-18 16:00:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -19,7 +18,7 @@ depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"name": "ResearchAgent",
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
|
||||
@@ -70,80 +70,66 @@ BUILT_IN_TOOLS = [
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
|
||||
try:
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""sync_exa_api_key_to_content_provider
|
||||
|
||||
Revision ID: d1b637d7050a
|
||||
Revises: d25168c2beee
|
||||
Create Date: 2026-01-09 15:54:15.646249
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d1b637d7050a"
|
||||
down_revision = "d25168c2beee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Exa uses a shared API key between search and content providers.
|
||||
# For existing Exa search providers with API keys, create the corresponding
|
||||
# content provider if it doesn't exist yet.
|
||||
connection = op.get_bind()
|
||||
|
||||
# Check if Exa search provider exists with an API key
|
||||
result = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT api_key FROM internet_search_provider
|
||||
WHERE provider_type = 'exa' AND api_key IS NOT NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
api_key = row[0]
|
||||
# Create Exa content provider with the shared key
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO internet_content_provider
|
||||
(name, provider_type, api_key, is_active)
|
||||
VALUES ('Exa', 'exa', :api_key, false)
|
||||
ON CONFLICT (name) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"api_key": api_key},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the Exa content provider that was created by this migration
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM internet_content_provider
|
||||
WHERE provider_type = 'exa'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""tool_name_consistency
|
||||
|
||||
Revision ID: d25168c2beee
|
||||
Revises: 8405ca81cc83
|
||||
Create Date: 2026-01-11 17:54:40.135777
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d25168c2beee"
|
||||
down_revision = "8405ca81cc83"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Mapping of in_code_tool_id -> name.
# The values are the names we want stored in the database after the upgrade.
EXPECTED_TOOL_NAME_MAPPING = dict(
    SearchTool="internal_search",
    WebSearchTool="web_search",
    ImageGenerationTool="generate_image",
    PythonTool="python",
    OpenURLTool="open_url",
    KnowledgeGraphTool="run_kg_search",
    ResearchAgent="research_agent",
)

# Currently the seeded tools have in_code_tool_id == name, so the pre-upgrade
# names are exactly the in_code_tool_id keys above (same order).
CURRENT_TOOL_NAME_MAPPING = list(EXPECTED_TOOL_NAME_MAPPING)
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Rename seeded tools to the names in ``EXPECTED_TOOL_NAME_MAPPING``.

    For every ``in_code_tool_id`` key in the mapping, the ``tool`` rows with
    that id get their ``name`` column set to the mapped value.
    """
    bind = op.get_bind()

    # One parameterized statement, executed once per mapping entry.
    rename_stmt = sa.text(
        """
        UPDATE tool
        SET name = :expected_name
        WHERE in_code_tool_id = :in_code_tool_id
        """
    )

    for tool_id, new_name in EXPECTED_TOOL_NAME_MAPPING.items():
        bind.execute(
            rename_stmt,
            {"expected_name": new_name, "in_code_tool_id": tool_id},
        )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Restore each seeded tool's ``name`` to its ``in_code_tool_id``.

    This reverses ``upgrade``: before it ran, the seeded tools used the
    in-code id (the class name) as their ``name``.
    """
    bind = op.get_bind()

    # One parameterized statement, executed once per seeded tool id.
    restore_stmt = sa.text(
        """
        UPDATE tool
        SET name = :current_name
        WHERE in_code_tool_id = :in_code_tool_id
        """
    )

    for tool_id in CURRENT_TOOL_NAME_MAPPING:
        bind.execute(
            restore_stmt,
            {"current_name": tool_id, "in_code_tool_id": tool_id},
        )
|
||||
@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
|
||||
@@ -3,30 +3,42 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
def update_persona_access(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
"""Updates the access settings for a persona including public status, user shares,
|
||||
and group shares.
|
||||
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
NOTE: This function batches all updates. If we don't dedupe the inputs,
|
||||
the commit will raise an exception.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
@@ -41,11 +53,13 @@ def make_persona_private(
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids:
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -72,22 +78,24 @@ def fetch_billing_information(
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -10,6 +10,7 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
@@ -104,15 +105,18 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
async def create_subscription_session(
    request: CreateSubscriptionSessionRequest | None = None,
    _: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
    """Create a Stripe checkout session for the current tenant.

    Args:
        request: Optional body carrying the billing period. When it is
            omitted, the period defaults to ``"monthly"``.
        _: Admin-user dependency; only used to gate access.

    Returns:
        A ``SubscriptionSessionResponse`` holding the checkout session id.

    Raises:
        HTTPException: 400 when no tenant id is set in the current context,
            500 when creating the checkout session fails for any other reason.
    """
    try:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        if not tenant_id:
            raise HTTPException(status_code=400, detail="Tenant ID not found")

        billing_period = request.billing_period if request else "monthly"
        session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
        return SubscriptionSessionResponse(sessionId=session_id)

    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must keep their status code
        # instead of being re-wrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.exception("Failed to create subscription session")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
@@ -105,6 +105,8 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
|
||||
@@ -12,6 +12,7 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
    """Build the Redis key marking a user file as having a task in the queue.

    The beat generator sets this key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES)
    before enqueuing a process_single_user_file task, and the worker deletes it
    as soon as it picks the task up. While the key exists, the beat skips the
    file, so it does not add duplicate tasks for a file that is already queued.
    """
    prefix = OnyxRedisLocks.USER_FILE_QUEUED_PREFIX
    return f"{prefix}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -401,7 +401,10 @@ def handle_stream_message_objects(
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
# Auto-place after the latest message in the chain
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif new_msg_req.parent_message_id is None:
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
# None = regeneration from root
|
||||
parent_message = root_message
|
||||
# Truncate history since we're starting from root
|
||||
|
||||
@@ -149,6 +149,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -419,6 +430,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -97,10 +97,17 @@ def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
"""Gets string representations of experts supplied.
|
||||
|
||||
If an expert cannot be represented as a string, it is omitted from the
|
||||
result.
|
||||
"""
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
reps: list[str | None] = [
|
||||
basic_expert_info_representation(owner) for owner in experts
|
||||
]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
|
||||
|
||||
|
||||
@@ -566,6 +566,23 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -586,10 +603,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
rephrased_queries = [
|
||||
raw_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
@@ -444,6 +444,8 @@ def upsert_documents(
|
||||
logger.info("No documents to upsert. Skipping.")
|
||||
return
|
||||
|
||||
includes_permissions = any(doc.external_access for doc in seen_documents.values())
|
||||
|
||||
insert_stmt = insert(DbDocument).values(
|
||||
[
|
||||
model_to_dict(
|
||||
@@ -479,21 +481,38 @@ def upsert_documents(
|
||||
]
|
||||
)
|
||||
|
||||
update_set = {
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
}
|
||||
if includes_permissions:
|
||||
# Use COALESCE to preserve existing permissions when new values are NULL.
|
||||
# This prevents subsequent indexing runs (which don't fetch permissions)
|
||||
# from overwriting permissions set by permission sync jobs.
|
||||
update_set.update(
|
||||
{
|
||||
"external_user_emails": func.coalesce(
|
||||
insert_stmt.excluded.external_user_emails,
|
||||
DbDocument.external_user_emails,
|
||||
),
|
||||
"external_user_group_ids": func.coalesce(
|
||||
insert_stmt.excluded.external_user_group_ids,
|
||||
DbDocument.external_user_group_ids,
|
||||
),
|
||||
"is_public": func.coalesce(
|
||||
insert_stmt.excluded.is_public,
|
||||
DbDocument.is_public,
|
||||
),
|
||||
}
|
||||
)
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
"external_user_emails": insert_stmt.excluded.external_user_emails,
|
||||
"external_user_group_ids": insert_stmt.excluded.external_user_group_ids,
|
||||
"is_public": insert_stmt.excluded.is_public,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
},
|
||||
index_elements=["id"], set_=update_set # Conflict target
|
||||
)
|
||||
db_session.execute(on_conflict_stmt)
|
||||
db_session.commit()
|
||||
|
||||
@@ -2616,6 +2616,7 @@ class Tool(Base):
|
||||
__tablename__ = "tool"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# The name of the tool that the LLM will see
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
|
||||
@@ -187,13 +187,25 @@ def _get_persona_by_name(
|
||||
return result
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
def update_persona_access(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""Updates the access settings for a persona including public status and user shares.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
@@ -212,11 +224,15 @@ def make_persona_private(
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
|
||||
# there shouldn't be any) but raise an error if trying to add actual groups.
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
# May cause error if someone switches down to MIT from EE
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support private Personas")
|
||||
if group_ids:
|
||||
raise NotImplementedError("Onyx MIT does not support group-based sharing")
|
||||
|
||||
|
||||
def create_update_persona(
|
||||
@@ -282,20 +298,21 @@ def create_update_persona(
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
)
|
||||
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
versioned_update_persona_access(
|
||||
persona_id=persona.id,
|
||||
creator_user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
user_ids=create_persona_request.users,
|
||||
group_ids=create_persona_request.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
@@ -304,11 +321,13 @@ def create_update_persona(
|
||||
return FullPersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
def update_persona_shared_users(
|
||||
def update_persona_shared(
|
||||
persona_id: int,
|
||||
user_ids: list[UUID],
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
is_public: bool | None = None,
|
||||
) -> None:
|
||||
"""Simplified version of `create_update_persona` which only touches the
|
||||
accessibility rather than any of the logic (e.g. prompt, connected data sources,
|
||||
@@ -317,22 +336,25 @@ def update_persona_shared_users(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if persona.is_public:
|
||||
raise HTTPException(status_code=400, detail="Cannot share public persona")
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You don't have permission to modify this persona"
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "make_persona_private"
|
||||
versioned_update_persona_access = fetch_versioned_implementation(
|
||||
"onyx.db.persona", "update_persona_access"
|
||||
)
|
||||
|
||||
# Privatize Persona
|
||||
versioned_make_persona_private(
|
||||
versioned_update_persona_access(
|
||||
persona_id=persona_id,
|
||||
creator_user_id=user.id if user else None,
|
||||
user_ids=user_ids,
|
||||
group_ids=None,
|
||||
db_session=db_session,
|
||||
is_public=is_public,
|
||||
user_ids=user_ids,
|
||||
group_ids=group_ids,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_public_status(
|
||||
persona_id: int,
|
||||
|
||||
@@ -113,7 +113,6 @@ def upsert_web_search_provider(
|
||||
if activate:
|
||||
set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
@@ -269,7 +268,6 @@ def upsert_web_content_provider(
|
||||
if activate:
|
||||
set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
|
||||
|
||||
db_session.commit()
|
||||
db_session.refresh(provider)
|
||||
return provider
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.tools import get_tool_by_name
|
||||
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
|
||||
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
|
||||
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
|
||||
@@ -220,35 +219,90 @@ def run_deep_research_llm_loop(
|
||||
else ""
|
||||
)
|
||||
if not skip_clarification:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
internal_search_clarification_guidance=internal_search_clarification_guidance,
|
||||
)
|
||||
with function_span("clarification_step") as span:
|
||||
clarification_prompt = CLARIFICATION_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
internal_search_clarification_guidance=internal_search_clarification_guidance,
|
||||
)
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
llm_step_result, _ = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
span.span_data.output = "clarification_required"
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop"),
|
||||
)
|
||||
)
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
with function_span("research_plan_step") as span:
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=clarification_prompt,
|
||||
token_count=300, # Skips the exact token count but has enough leeway
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
)
|
||||
|
||||
llm_step_result, _ = run_llm_step(
|
||||
emitter=emitter,
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_clarification_tool_definitions(),
|
||||
tool_choice=ToolChoiceOptions.AUTO,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
@@ -256,301 +310,177 @@ def run_deep_research_llm_loop(
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
if not llm_step_result.tool_calls:
|
||||
# Mark this turn as a clarification question
|
||||
state_container.set_is_clarification(True)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0), obj=OverallStop(type="stop")
|
||||
)
|
||||
)
|
||||
|
||||
# If a clarification is asked, we need to end this turn and wait on user input
|
||||
return
|
||||
|
||||
#########################################################
|
||||
# RESEARCH PLAN STEP
|
||||
#########################################################
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_PROMPT.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False)
|
||||
),
|
||||
token_count=300,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
reminder_message = ChatMessageSimple(
|
||||
message=RESEARCH_PLAN_REMINDER,
|
||||
token_count=100,
|
||||
message_type=MessageType.USER,
|
||||
)
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history + [reminder_message],
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
|
||||
)
|
||||
|
||||
research_plan_generator = run_llm_step_pkt_generator(
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[],
|
||||
tool_choice=ToolChoiceOptions.NONE,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=0),
|
||||
citation_processor=None,
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
is_deep_research=True,
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
# The LLM response from this prompt is the research plan
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
while True:
|
||||
try:
|
||||
packet = next(research_plan_generator)
|
||||
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
|
||||
# The LLM response from this prompt is the research plan
|
||||
if isinstance(packet.obj, AgentResponseStart):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanStart(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, reasoned = e.value
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanStart(),
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(
|
||||
turn_index=1 if reasoned else 0,
|
||||
),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
elif isinstance(packet.obj, AgentResponseDelta):
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=packet.placement,
|
||||
obj=DeepResearchPlanDelta(content=packet.obj.content),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
|
||||
emitter.emit(packet)
|
||||
except StopIteration as e:
|
||||
llm_step_result, reasoned = e.value
|
||||
emitter.emit(
|
||||
Packet(
|
||||
# Marks the last turn end which should be the plan generation
|
||||
placement=Placement(
|
||||
turn_index=1 if reasoned else 0,
|
||||
),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
if reasoned:
|
||||
orchestrator_start_turn_index += 1
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
if reasoned:
|
||||
orchestrator_start_turn_index += 1
|
||||
break
|
||||
llm_step_result = cast(LlmStepResult, llm_step_result)
|
||||
|
||||
research_plan = llm_step_result.answer
|
||||
research_plan = llm_step_result.answer
|
||||
span.span_data.output = research_plan if research_plan else None
|
||||
|
||||
#########################################################
|
||||
# RESEARCH EXECUTION STEP
|
||||
#########################################################
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
with function_span("research_execution_step") as span:
|
||||
is_reasoning_model = model_is_reasoning_model(
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
max_orchestrator_cycles = (
|
||||
MAX_ORCHESTRATOR_CYCLES
|
||||
if not is_reasoning_model
|
||||
else MAX_ORCHESTRATOR_CYCLES_REASONING
|
||||
)
|
||||
max_orchestrator_cycles = (
|
||||
MAX_ORCHESTRATOR_CYCLES
|
||||
if not is_reasoning_model
|
||||
else MAX_ORCHESTRATOR_CYCLES_REASONING
|
||||
)
|
||||
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT
|
||||
if not is_reasoning_model
|
||||
else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
orchestrator_prompt_template = (
|
||||
ORCHESTRATOR_PROMPT
|
||||
if not is_reasoning_model
|
||||
else ORCHESTRATOR_PROMPT_REASONING
|
||||
)
|
||||
|
||||
internal_search_research_task_guidance = (
|
||||
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
|
||||
if include_internal_search_tunings
|
||||
else ""
|
||||
)
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=1,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
reasoning_cycles = 0
|
||||
most_recent_reasoning: str | None = None
|
||||
citation_mapping: CitationMapping = {}
|
||||
final_turn_index: int = (
|
||||
orchestrator_start_turn_index # Track the final turn_index for stop packet
|
||||
)
|
||||
for cycle in range(max_orchestrator_cycles):
|
||||
if cycle == max_orchestrator_cycles - 1:
|
||||
# If it's the last cycle, forcibly generate the final report
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
internal_search_research_task_guidance = (
|
||||
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
|
||||
if include_internal_search_tunings
|
||||
else ""
|
||||
)
|
||||
token_count_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
current_cycle_count=1,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
orchestration_tokens = token_counter(token_count_prompt)
|
||||
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
reasoning_cycles = 0
|
||||
most_recent_reasoning: str | None = None
|
||||
citation_mapping: CitationMapping = {}
|
||||
final_turn_index: int = (
|
||||
orchestrator_start_turn_index # Track the final turn_index for stop packet
|
||||
)
|
||||
for cycle in range(max_orchestrator_cycles):
|
||||
if cycle == max_orchestrator_cycles - 1:
|
||||
# If it's the last cycle, forcibly generate the final report
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
research_agent_calls: list[ToolCallKickoff] = []
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
create_think_tool_token_processor() if not is_reasoning_model else None
|
||||
)
|
||||
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_orchestrator_tools(
|
||||
include_think_tool=not is_reasoning_model
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
placement=Placement(
|
||||
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
custom_token_processor=custom_processor,
|
||||
is_deep_research=True,
|
||||
)
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if not tool_calls and cycle == 0:
|
||||
raise RuntimeError(
|
||||
"Deep Research failed to generate any research tasks for the agents."
|
||||
orchestrator_prompt = orchestrator_prompt_template.format(
|
||||
current_datetime=get_current_llm_day_time(full_sentence=False),
|
||||
current_cycle_count=cycle,
|
||||
max_cycles=max_orchestrator_cycles,
|
||||
research_plan=research_plan,
|
||||
internal_search_research_task_guidance=internal_search_research_task_guidance,
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurence and hopefully multiple research
|
||||
# cycles have already ran
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
system_prompt = ChatMessageSimple(
|
||||
message=orchestrator_prompt,
|
||||
token_count=orchestration_tokens,
|
||||
message_type=MessageType.SYSTEM,
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
|
||||
truncated_message_history = construct_message_history(
|
||||
system_prompt=system_prompt,
|
||||
custom_agent_prompt=None,
|
||||
simple_chat_history=simple_chat_history,
|
||||
reminder_message=None,
|
||||
project_files=None,
|
||||
available_tokens=available_tokens,
|
||||
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
|
||||
)
|
||||
|
||||
# Use think tool processor for non-reasoning models to convert
|
||||
# think_tool calls to reasoning content
|
||||
custom_processor = (
|
||||
create_think_tool_token_processor()
|
||||
if not is_reasoning_model
|
||||
else None
|
||||
)
|
||||
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
report_turn_index = (
|
||||
special_tool_calls.generate_report_tool_call.placement.turn_index
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=get_orchestrator_tools(
|
||||
include_think_tool=not is_reasoning_model
|
||||
),
|
||||
tool_choice=ToolChoiceOptions.REQUIRED,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
placement=Placement(
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles
|
||||
),
|
||||
# No citations in this step, it should just pass through all
|
||||
# tokens directly so initialized as an empty citation processor
|
||||
citation_processor=DynamicCitationProcessor(),
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
final_documents=None,
|
||||
user_identity=user_identity,
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
custom_token_processor=custom_processor,
|
||||
is_deep_research=True,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
elif special_tool_calls.think_tool_call:
|
||||
think_tool_call = special_tool_calls.think_tool_call
|
||||
# Only process the THINK_TOOL and skip all other tool calls
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
with function_span("think_tool") as span:
|
||||
span.span_data.input = str(think_tool_call.tool_args)
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
think_tool_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=token_counter(tool_call_message),
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if not tool_calls and cycle == 0:
|
||||
raise RuntimeError(
|
||||
"Deep Research failed to generate any research tasks for the agents."
|
||||
)
|
||||
simple_chat_history.append(think_tool_msg)
|
||||
|
||||
think_tool_response_msg = ChatMessageSimple(
|
||||
message=THINK_TOOL_RESPONSE_MESSAGE,
|
||||
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
|
||||
logger.warning(f"Unexpected tool call: {tool_call.tool_name}")
|
||||
continue
|
||||
|
||||
research_agent_calls.append(tool_call)
|
||||
|
||||
if not research_agent_calls:
|
||||
logger.warning(
|
||||
"No research agent tool calls found, this should not happen."
|
||||
)
|
||||
if not tool_calls:
|
||||
# Basically hope that this is an infrequent occurence and hopefully multiple research
|
||||
# cycles have already ran
|
||||
logger.warning("No tool calls found, this should not happen.")
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
@@ -567,91 +497,177 @@ def run_deep_research_llm_loop(
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
|
||||
if len(research_agent_calls) > 1:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(
|
||||
turn_index=research_agent_calls[0].placement.turn_index
|
||||
),
|
||||
obj=TopLevelBranching(
|
||||
num_parallel_branches=len(research_agent_calls)
|
||||
),
|
||||
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
|
||||
|
||||
if special_tool_calls.generate_report_tool_call:
|
||||
report_turn_index = (
|
||||
special_tool_calls.generate_report_tool_call.placement.turn_index
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
saved_reasoning=most_recent_reasoning,
|
||||
)
|
||||
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
|
||||
break
|
||||
elif special_tool_calls.think_tool_call:
|
||||
think_tool_call = special_tool_calls.think_tool_call
|
||||
# Only process the THINK_TOOL and skip all other tool calls
|
||||
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
|
||||
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
|
||||
# we will show it as a separate message.
|
||||
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
|
||||
# the LLM step to handle this
|
||||
with function_span("think_tool") as span:
|
||||
span.span_data.input = str(think_tool_call.tool_args)
|
||||
most_recent_reasoning = state_container.reasoning_tokens
|
||||
tool_call_message = think_tool_call.to_msg_str()
|
||||
|
||||
think_tool_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=token_counter(tool_call_message),
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
)
|
||||
simple_chat_history.append(think_tool_msg)
|
||||
|
||||
research_results = run_research_agent_calls(
|
||||
# The tool calls here contain the placement information
|
||||
research_agent_calls=research_agent_calls,
|
||||
parent_tool_call_ids=[
|
||||
tool_call.tool_call_id for tool_call in tool_calls
|
||||
],
|
||||
tools=allowed_tools,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
llm=llm,
|
||||
is_reasoning_model=is_reasoning_model,
|
||||
token_counter=token_counter,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
|
||||
citation_mapping = research_results.citation_mapping
|
||||
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
if report is None:
|
||||
# The LLM will not see that this research was even attempted, it may try
|
||||
# something similar again but this is not bad.
|
||||
logger.error(
|
||||
f"Research agent call at tab_index {tab_index} failed, skipping"
|
||||
think_tool_response_msg = ChatMessageSimple(
|
||||
message=THINK_TOOL_RESPONSE_MESSAGE,
|
||||
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=think_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
continue
|
||||
simple_chat_history.append(think_tool_response_msg)
|
||||
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
|
||||
continue
|
||||
else:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
|
||||
logger.warning(
|
||||
f"Unexpected tool call: {tool_call.tool_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
current_tool_call = research_agent_calls[tab_index]
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles,
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_DB_NAME, db_session=db_session
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
tool_call_response=report,
|
||||
search_docs=None, # Intermediate docs are not saved/shown
|
||||
generated_images=None,
|
||||
research_agent_calls.append(tool_call)
|
||||
|
||||
if not research_agent_calls:
|
||||
logger.warning(
|
||||
"No research agent tool calls found, this should not happen."
|
||||
)
|
||||
report_turn_index = (
|
||||
orchestrator_start_turn_index + cycle + reasoning_cycles
|
||||
)
|
||||
report_reasoned = generate_final_report(
|
||||
history=simple_chat_history,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
state_container=state_container,
|
||||
emitter=emitter,
|
||||
turn_index=report_turn_index,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
final_turn_index = report_turn_index + (
|
||||
1 if report_reasoned else 0
|
||||
)
|
||||
break
|
||||
|
||||
if len(research_agent_calls) > 1:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(
|
||||
turn_index=research_agent_calls[
|
||||
0
|
||||
].placement.turn_index
|
||||
),
|
||||
obj=TopLevelBranching(
|
||||
num_parallel_branches=len(research_agent_calls)
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
research_results = run_research_agent_calls(
|
||||
# The tool calls here contain the placement information
|
||||
research_agent_calls=research_agent_calls,
|
||||
parent_tool_call_ids=[
|
||||
tool_call.tool_call_id for tool_call in tool_calls
|
||||
],
|
||||
tools=allowed_tools,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
llm=llm,
|
||||
is_reasoning_model=is_reasoning_model,
|
||||
token_counter=token_counter,
|
||||
citation_mapping=citation_mapping,
|
||||
user_identity=user_identity,
|
||||
)
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
tool_call_message = current_tool_call.to_msg_str()
|
||||
tool_call_token_count = token_counter(tool_call_message)
|
||||
citation_mapping = research_results.citation_mapping
|
||||
|
||||
tool_call_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=tool_call_token_count,
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_msg)
|
||||
for tab_index, report in enumerate(
|
||||
research_results.intermediate_reports
|
||||
):
|
||||
if report is None:
|
||||
# The LLM will not see that this research was even attempted, it may try
|
||||
# something similar again but this is not bad.
|
||||
logger.error(
|
||||
f"Research agent call at tab_index {tab_index} failed, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
tool_call_response_msg = ChatMessageSimple(
|
||||
message=report,
|
||||
token_count=token_counter(report),
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_response_msg)
|
||||
current_tool_call = research_agent_calls[tab_index]
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None,
|
||||
turn_index=orchestrator_start_turn_index
|
||||
+ cycle
|
||||
+ reasoning_cycles,
|
||||
tab_index=tab_index,
|
||||
tool_name=current_tool_call.tool_name,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
tool_id=get_tool_by_name(
|
||||
tool_name=RESEARCH_AGENT_TOOL_NAME,
|
||||
db_session=db_session,
|
||||
).id,
|
||||
reasoning_tokens=llm_step_result.reasoning
|
||||
or most_recent_reasoning,
|
||||
tool_call_arguments=current_tool_call.tool_args,
|
||||
tool_call_response=report,
|
||||
search_docs=None, # Intermediate docs are not saved/shown
|
||||
generated_images=None,
|
||||
)
|
||||
state_container.add_tool_call(tool_call_info)
|
||||
|
||||
# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
|
||||
most_recent_reasoning = None
|
||||
tool_call_message = current_tool_call.to_msg_str()
|
||||
tool_call_token_count = token_counter(tool_call_message)
|
||||
|
||||
tool_call_msg = ChatMessageSimple(
|
||||
message=tool_call_message,
|
||||
token_count=tool_call_token_count,
|
||||
message_type=MessageType.TOOL_CALL,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_msg)
|
||||
|
||||
tool_call_response_msg = ChatMessageSimple(
|
||||
message=report,
|
||||
token_count=token_counter(report),
|
||||
message_type=MessageType.TOOL_CALL_RESPONSE,
|
||||
tool_call_id=current_tool_call.tool_call_id,
|
||||
image_files=None,
|
||||
)
|
||||
simple_chat_history.append(tool_call_response_msg)
|
||||
|
||||
# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
|
||||
most_recent_reasoning = None
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
GENERATE_PLAN_TOOL_NAME = "generate_plan"
|
||||
|
||||
RESEARCH_AGENT_DB_NAME = "ResearchAgent"
|
||||
RESEARCH_AGENT_IN_CODE_ID = "ResearchAgent"
|
||||
RESEARCH_AGENT_TOOL_NAME = "research_agent"
|
||||
RESEARCH_AGENT_TASK_KEY = "task"
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@ import json
|
||||
import httpx
|
||||
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.context.search.enums import QueryType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
@@ -44,6 +47,7 @@ from onyx.document_index.opensearch.search import (
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
@@ -58,50 +62,36 @@ def _convert_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
blurb=chunk.blurb,
|
||||
content=chunk.content,
|
||||
source_links=json.loads(chunk.source_links) if chunk.source_links else None,
|
||||
image_file_id=chunk.image_file_name,
|
||||
# TODO(andrei) Yuhong says he doesn't think we need that anymore. Used
|
||||
# if a section needed to be split into diff chunks. A section is a part
|
||||
# of a doc that a link will take you to. But don't chunks have their own
|
||||
# links? Look at this in a followup.
|
||||
image_file_id=chunk.image_file_id,
|
||||
# Deprecated. Fill in some reasonable default.
|
||||
section_continuation=False,
|
||||
document_id=chunk.document_id,
|
||||
source_type=DocumentSource(chunk.source_type),
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
title=chunk.title,
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document. Yuhong thinks OpenSearch
|
||||
# has some thing out of the box for this. Just need to look at it in a
|
||||
# followup.
|
||||
boost=1,
|
||||
# TODO(andrei): Do in a followup.
|
||||
boost=chunk.global_boost,
|
||||
# TODO(andrei): Do in a followup. We should be able to get this from
|
||||
# OpenSearch.
|
||||
recency_bias=1.0,
|
||||
# TODO(andrei): This is how good the match is, we need this, key insight
|
||||
# is we can order chunks by this. Should not be hard to plumb this from
|
||||
# a search result, do that in a followup.
|
||||
score=None,
|
||||
hidden=chunk.hidden,
|
||||
# TODO(andrei): Don't worry about these for now.
|
||||
# is_relevant
|
||||
# relevance_explanation
|
||||
# metadata
|
||||
# TODO(andrei): Same comment as in
|
||||
# _convert_onyx_chunk_to_opensearch_document.
|
||||
metadata={},
|
||||
metadata=json.loads(chunk.metadata),
|
||||
# TODO(andrei): The vector DB needs to supply this. I vaguely know
|
||||
# OpenSearch can from the documentation I've seen till now, look at this
|
||||
# in a followup.
|
||||
match_highlights=[],
|
||||
# TODO(andrei) This content is not queried on, it is only used to clean
|
||||
# appended content to chunks. Consider storing a chunk content index
|
||||
# instead of a full string when working on chunk content augmentation.
|
||||
doc_summary="",
|
||||
# TODO(andrei) Consider storing a chunk content index instead of a full
|
||||
# string when working on chunk content augmentation.
|
||||
doc_summary=chunk.doc_summary,
|
||||
# TODO(andrei) Same thing as contx ret above, LLM gens context for each
|
||||
# chunk.
|
||||
chunk_context="",
|
||||
chunk_context=chunk.chunk_context,
|
||||
updated_at=chunk.last_updated,
|
||||
# primary_owners TODO(andrei)
|
||||
# secondary_owners TODO(andrei)
|
||||
# large_chunk_reference_ids TODO(andrei): Don't worry about this one.
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
# TODO(andrei): This is the suffix appended to the end of the chunk
|
||||
# content to assist querying. There are better ways we can do this, for
|
||||
# ex. keeping an index of where to string split from.
|
||||
@@ -126,44 +116,31 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
title_vector=chunk.title_embedding,
|
||||
content=chunk.content,
|
||||
content_vector=chunk.embeddings.full_embedding,
|
||||
# TODO(andrei): We should know this. Reason to have this is convenience,
|
||||
# but it could also change when you change your embedding model, maybe
|
||||
# we can remove it, Yuhong to look at this. Hardcoded to some nonsense
|
||||
# value for now.
|
||||
num_tokens=0,
|
||||
source_type=chunk.source_document.source.value,
|
||||
# TODO(andrei): This is just represented a bit differently in
|
||||
# DocumentBase than how we expect it in the schema currently. Look at
|
||||
# this closer in a followup. Always defaults to None for now.
|
||||
# metadata=chunk.source_document.metadata,
|
||||
metadata=json.dumps(chunk.source_document.metadata),
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
# TODO(andrei): Don't currently see an easy way of porting this, and
|
||||
# besides some connectors genuinely don't have this data. Look at this
|
||||
# closer in a followup. Always defaults to None for now.
|
||||
# created_at=None,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): Implement ACL in a followup, currently none of the
|
||||
# methods in OpenSearchDocumentIndex support it anyway. Always defaults
|
||||
# to None for now.
|
||||
# access_control_list=chunk.access.to_acl(),
|
||||
# TODO(andrei): This doesn't work bc global_boost is float, presumably
|
||||
# between 0.0 and inf (check this) and chunk.boost is an int from -inf
|
||||
# to +inf. Look at how the scaling compares between these in a followup.
|
||||
# Always defaults to 1.0 for now.
|
||||
# global_boost=chunk.boost,
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
# TODO(andrei): Ask Chris more about this later. Always defaults to None
|
||||
# for now.
|
||||
# image_file_name=None,
|
||||
image_file_id=chunk.image_file_id,
|
||||
source_links=json.dumps(chunk.source_links) if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
doc_summary=chunk.doc_summary,
|
||||
chunk_context=chunk.chunk_context,
|
||||
document_sets=list(chunk.document_sets) if chunk.document_sets else None,
|
||||
project_ids=list(chunk.user_project) if chunk.user_project else None,
|
||||
primary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.primary_owners
|
||||
),
|
||||
secondary_owners=get_experts_stores_representations(
|
||||
chunk.source_document.secondary_owners
|
||||
),
|
||||
# TODO(andrei): Consider not even getting this from
|
||||
# DocMetadataAwareIndexChunk and instead using OpenSearchDocumentIndex's
|
||||
# instance variable. One source of truth -> less chance of a very bad
|
||||
# bug in prod.
|
||||
tenant_id=chunk.tenant_id,
|
||||
tenant_id=TenantState(tenant_id=chunk.tenant_id, multitenant=MULTI_TENANT),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,30 +4,35 @@ from typing import Any
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_serializer
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_serializer
|
||||
from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
from onyx.document_index.opensearch.constants import EF_SEARCH
|
||||
from onyx.document_index.opensearch.constants import M
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
TITLE_FIELD_NAME = "title"
|
||||
TITLE_VECTOR_FIELD_NAME = "title_vector"
|
||||
CONTENT_FIELD_NAME = "content"
|
||||
CONTENT_VECTOR_FIELD_NAME = "content_vector"
|
||||
NUM_TOKENS_FIELD_NAME = "num_tokens"
|
||||
SOURCE_TYPE_FIELD_NAME = "source_type"
|
||||
METADATA_FIELD_NAME = "metadata"
|
||||
LAST_UPDATED_FIELD_NAME = "last_updated"
|
||||
CREATED_AT_FIELD_NAME = "created_at"
|
||||
PUBLIC_FIELD_NAME = "public"
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME = "access_control_list"
|
||||
HIDDEN_FIELD_NAME = "hidden"
|
||||
GLOBAL_BOOST_FIELD_NAME = "global_boost"
|
||||
SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier"
|
||||
IMAGE_FILE_NAME_FIELD_NAME = "image_file_name"
|
||||
IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
|
||||
SOURCE_LINKS_FIELD_NAME = "source_links"
|
||||
DOCUMENT_SETS_FIELD_NAME = "document_sets"
|
||||
PROJECT_IDS_FIELD_NAME = "project_ids"
|
||||
@@ -36,6 +41,10 @@ CHUNK_INDEX_FIELD_NAME = "chunk_index"
|
||||
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
|
||||
TENANT_ID_FIELD_NAME = "tenant_id"
|
||||
BLURB_FIELD_NAME = "blurb"
|
||||
DOC_SUMMARY_FIELD_NAME = "doc_summary"
|
||||
CHUNK_CONTEXT_FIELD_NAME = "chunk_context"
|
||||
PRIMARY_OWNERS_FIELD_NAME = "primary_owners"
|
||||
SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
|
||||
|
||||
|
||||
def get_opensearch_doc_chunk_id(
|
||||
@@ -52,12 +61,27 @@ def get_opensearch_doc_chunk_id(
|
||||
return f"{document_id}__{max_chunk_size}__{chunk_index}"
|
||||
|
||||
|
||||
def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
    """Returns value as a timezone-aware datetime expressed in UTC.

    Naive datetimes are assumed to already hold UTC wall-clock time, so they
    are only tagged with the UTC timezone and their fields are not shifted.
    Aware datetimes are converted to the same instant in UTC.

    Args:
        value: The datetime to normalize.

    Returns:
        A datetime whose tzinfo is timezone.utc.
    """
    # A datetime is naive when tzinfo is None *or* when its tzinfo reports no
    # UTC offset. datetime.utcoffset() returns None in both cases, so it is
    # the complete check.
    if value.utcoffset() is None:
        # Do not use astimezone here: on a naive datetime it does not raise,
        # it silently interprets the value as system local time, which would
        # make the result depend on the host's timezone configuration.
        return value.replace(tzinfo=timezone.utc)
    # Does appropriate time conversion if value was set in a different
    # timezone.
    return value.astimezone(timezone.utc)
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
"""
|
||||
Represents a chunk of a document in the OpenSearch index.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
|
||||
WARNING: Relies on MULTI_TENANT which is global state. Also uses
|
||||
get_current_tenant_id. Generally relying on global state is bad, in this
|
||||
case we accept it because of the importance of validating tenant logic.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
@@ -75,41 +99,44 @@ class DocumentChunk(BaseModel):
|
||||
title_vector: list[float] | None = None
|
||||
content: str
|
||||
content_vector: list[float]
|
||||
# The actual number of tokens in the chunk.
|
||||
num_tokens: int
|
||||
|
||||
source_type: str
|
||||
# Application logic should store these strings the format key:::value.
|
||||
metadata: list[str] | None = None
|
||||
# Contains a string representation of a dict which maps string key to either
|
||||
# string value or list of string values.
|
||||
# TODO(andrei): When we augment content with metadata this can just be an
|
||||
# index pointer, and when we support metadata list that will just be a list
|
||||
# of strings.
|
||||
metadata: str
|
||||
# If it exists, time zone should always be UTC.
|
||||
last_updated: datetime | None = None
|
||||
created_at: datetime | None = None
|
||||
|
||||
public: bool
|
||||
access_control_list: list[str] | None = None
|
||||
access_control_list: list[str]
|
||||
# Defaults to False, currently gets written during update not index.
|
||||
hidden: bool = False
|
||||
|
||||
global_boost: float = 1.0
|
||||
global_boost: int
|
||||
|
||||
semantic_identifier: str
|
||||
image_file_name: str | None = None
|
||||
image_file_id: str | None = None
|
||||
# Contains a string representation of a dict which maps offset into the raw
|
||||
# chunk text to the link corresponding to that point.
|
||||
source_links: str | None = None
|
||||
blurb: str
|
||||
doc_summary: str
|
||||
chunk_context: str
|
||||
|
||||
document_sets: list[str] | None = None
|
||||
# User projects.
|
||||
project_ids: list[int] | None = None
|
||||
primary_owners: list[str] | None = None
|
||||
secondary_owners: list[str] | None = None
|
||||
|
||||
tenant_id: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_num_tokens_fits_within_max_chunk_size(self) -> Self:
|
||||
if self.num_tokens > self.max_chunk_size:
|
||||
raise ValueError(
|
||||
"Bug: Num tokens must be less than or equal to max chunk size."
|
||||
)
|
||||
return self
|
||||
tenant_id: TenantState = Field(
|
||||
default_factory=lambda: TenantState(
|
||||
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
|
||||
)
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
@@ -120,25 +147,116 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
@field_serializer("last_updated", "created_at", mode="plain")
|
||||
@model_serializer(mode="wrap")
def serialize_model(
    self, handler: SerializerFunctionWrapHandler
) -> dict[str, object]:
    """Runs pydantic's standard serialization, then drops None-valued items.

    Per the original author's note, .model_dump(exclude_none=True) is not
    applied after @field_serializer logic, so a field serializer that
    returns None would otherwise leave a None entry in the final dump. This
    wrapper removes those entries itself.

    Args:
        handler: Callable from pydantic which takes the instance of the
            model as an argument and performs standard serialization.

    Returns:
        The mapping produced by handler, minus every key whose value is
        None.
    """
    dumped: dict[str, object] = handler(self)
    without_nones: dict[str, object] = {}
    for field_name, field_value in dumped.items():
        if field_value is None:
            # Skip fields that serialized to None.
            continue
        without_nones[field_name] = field_value
    return without_nones
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
self, value: datetime | None
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
# astimezone will raise if value does not have a timezone set.
|
||||
value = value.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
# Does appropriate time conversion if value was set in a different
|
||||
# timezone.
|
||||
value = value.astimezone(timezone.utc)
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
|
||||
@field_validator("last_updated", mode="before")
@classmethod
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
    """Parses milliseconds since the Unix epoch to a datetime object.

    If the input is None, returns None. If the input is already a datetime
    it is normalized to UTC via set_or_convert_timezone_to_utc.

    The datetime returned will be in UTC.

    Args:
        value: None, a datetime, or an int count of milliseconds since the
            Unix epoch.

    Returns:
        None, or a datetime in UTC.

    Raises:
        ValueError: If value is anything other than None, a datetime or an
            int. Bools are rejected even though bool subclasses int.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        return set_or_convert_timezone_to_utc(value)
    # bool is a subclass of int, so it must be excluded explicitly; otherwise
    # True/False would silently be parsed as 1 ms / 0 ms after the epoch.
    if isinstance(value, bool) or not isinstance(value, int):
        raise ValueError(
            f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
        )
    # fromtimestamp expects seconds, the stored value is in milliseconds.
    return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
def serialize_tenant_state(
    self, value: TenantState, handler: SerializerFunctionWrapHandler
) -> str | None:
    """
    Serializes tenant_state to the tenant str if multitenant, or None if
    not.

    In single tenant mode the schema has no tenant_id field, so nothing
    should be supplied for it in the serialized DocumentChunk. Returning
    None here relies on the final serialized model excluding None fields,
    which serialize_model is expected to enforce.

    The handler argument is not invoked.
    """
    return value.tenant_id if value.multitenant else None
|
||||
|
||||
@field_validator("tenant_id", mode="before")
@classmethod
def parse_tenant_id(cls, value: Any) -> TenantState:
    """
    Builds a TenantState from the raw tenant_id input.

    Accepted inputs:
      - None: only valid in single-tenant mode; yields a default state
        built from get_current_tenant_id().
      - TenantState: returned unchanged, provided its multitenant flag
        matches the global MULTI_TENANT setting.
      - str: only valid in multi-tenant mode; wrapped in a TenantState.

    Raises:
        ValueError: If the input type is unsupported or the input is
            inconsistent with the global MULTI_TENANT setting.
    """
    # No tenant_id present: implies single-tenant mode.
    if value is None:
        if MULTI_TENANT:
            raise ValueError(
                "Bug: No tenant_id was supplied but multi-tenant mode is enabled."
            )
        return TenantState(
            tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
        )

    # Already a TenantState: pass through after a consistency check.
    if isinstance(value, TenantState):
        if MULTI_TENANT != value.multitenant:
            raise ValueError(
                f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
                f"({value.multitenant}) does not match the program's current global tenancy state."
            )
        return value

    # Anything else must be the raw tenant string.
    if not isinstance(value, str):
        raise ValueError(
            f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
        )

    if not MULTI_TENANT:
        raise ValueError(
            "Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
            "This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
        )
    return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
|
||||
class DocumentSchema:
|
||||
"""
|
||||
@@ -176,13 +294,19 @@ class DocumentSchema:
|
||||
OpenSearch client. The structure of this dictionary is
|
||||
determined by OpenSearch documentation.
|
||||
"""
|
||||
schema = {
|
||||
schema: dict[str, Any] = {
|
||||
# By default OpenSearch allows dynamically adding new properties
|
||||
# based on indexed documents. This is awful and we disable it here.
|
||||
# An exception will be raised if you try to index a new doc which
|
||||
# contains unexpected fields.
|
||||
"dynamic": "strict",
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
# values longer than 256 chars.
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
"keyword": {"type": "keyword", "ignore_above": 256}
|
||||
},
|
||||
},
|
||||
@@ -200,6 +324,8 @@ class DocumentSchema:
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
# TODO(andrei): This is a tensor in Vespa. Also look at feature
|
||||
# parity for these other method fields.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"type": "knn_vector",
|
||||
"dimension": vector_dimension,
|
||||
@@ -210,14 +336,10 @@ class DocumentSchema:
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
# See TODO in _convert_onyx_chunk_to_opensearch_document. I
|
||||
# don't want to actually add this to the schema until we know
|
||||
# for sure we need it. If we decide we don't I will remove this.
|
||||
# # Number of tokens in the chunk's content.
|
||||
# NUM_TOKENS_FIELD_NAME: {"type": "integer", "store": True},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
# Application logic should store in the format key:::value.
|
||||
METADATA_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
@@ -225,16 +347,6 @@ class DocumentSchema:
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
},
|
||||
# See TODO in _convert_onyx_chunk_to_opensearch_document. I
|
||||
# don't want to actually add this to the schema until we know
|
||||
# for sure we need it. If we decide we don't I will remove this.
|
||||
# CREATED_AT_FIELD_NAME: {
|
||||
# "type": "date",
|
||||
# "format": "epoch_millis",
|
||||
# # For some reason date defaults to False, even though it
|
||||
# # would make sense to sort by date.
|
||||
# "doc_values": True,
|
||||
# },
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
@@ -247,7 +359,7 @@ class DocumentSchema:
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "float"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
# doc in the UI and is not used for searching. Disabling these
|
||||
# features to increase perf.
|
||||
@@ -258,7 +370,7 @@ class DocumentSchema:
|
||||
"store": False,
|
||||
},
|
||||
# Same as above; used to display an image along with the doc.
|
||||
IMAGE_FILE_NAME_FIELD_NAME: {
|
||||
IMAGE_FILE_ID_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
@@ -278,15 +390,36 @@ class DocumentSchema:
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above.
|
||||
# TODO(andrei): If we want to search on this this needs to be
|
||||
# changed.
|
||||
DOC_SUMMARY_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Same as above.
|
||||
# TODO(andrei): If we want to search on this this needs to be
|
||||
# changed.
|
||||
CHUNK_CONTEXT_FIELD_NAME: {
|
||||
"type": "keyword",
|
||||
"index": False,
|
||||
"doc_values": False,
|
||||
"store": False,
|
||||
},
|
||||
# Product-specific fields.
|
||||
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
|
||||
PROJECT_IDS_FIELD_NAME: {"type": "integer"},
|
||||
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
|
||||
# OpenSearch metadata fields.
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
if multitenant:
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using min-max",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
@@ -49,7 +49,7 @@ MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
}
|
||||
|
||||
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG = {
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using z-score",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
@@ -140,7 +140,7 @@ class DocumentQuery:
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.tenant_id is not None:
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
@@ -199,7 +199,7 @@ class DocumentQuery:
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.tenant_id is not None:
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
@@ -316,6 +316,7 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
"type": "best_fields",
|
||||
}
|
||||
@@ -340,7 +341,7 @@ class DocumentQuery:
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
if tenant_state.tenant_id is not None:
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from unstructured_client.models import operations # type: ignore
|
||||
from unstructured_client.models import operations
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -55,19 +55,19 @@ def _sdk_partition_request(
|
||||
|
||||
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
from unstructured.staging.base import dict_to_elements
|
||||
from unstructured_client import UnstructuredClient # type: ignore
|
||||
from unstructured_client import UnstructuredClient
|
||||
|
||||
logger.debug(f"Starting to read file: {file_name}")
|
||||
req = _sdk_partition_request(file, file_name, strategy="fast")
|
||||
|
||||
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
|
||||
|
||||
response = unstructured_client.general.partition(req)
|
||||
elements = dict_to_elements(response.elements)
|
||||
response = unstructured_client.general.partition(request=req)
|
||||
|
||||
if response.status_code != 200:
|
||||
err = f"Received unexpected status code {response.status_code} from Unstructured API."
|
||||
logger.error(err)
|
||||
raise ValueError(err)
|
||||
|
||||
elements = dict_to_elements(response.elements or [])
|
||||
return "\n\n".join(str(el) for el in elements)
|
||||
|
||||
@@ -40,6 +40,7 @@ class BaseChunk(BaseModel):
|
||||
source_links: dict[int, str] | None
|
||||
image_file_id: str | None
|
||||
# True if this Chunk's start is not at the start of a Section
|
||||
# TODO(andrei): This is deprecated as of the OpenSearch migration. Remove.
|
||||
section_continuation: bool
|
||||
|
||||
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
    """
    Replaces AzureOpenAIResponsesAPIConfig.should_fake_stream with a version
    that always returns False.

    Rationale carried over from the original note: LiteLLM falls back to
    "fake streaming" (MockResponsesAPIStreamingIterator) for models not in
    its database, which makes Azure custom model deployments buffer the whole
    response before yielding and hurts time-to-first-token. Azure's Responses
    API supports native streaming, so real streaming
    (SyncResponsesAPIStreamingIterator) is forced instead.

    The patch is applied at most once: if the current attribute already
    carries the replacement's name, the function returns without changes.
    """
    from litellm.llms.azure.responses.transformation import (
        AzureOpenAIResponsesAPIConfig,
    )

    current_impl = AzureOpenAIResponsesAPIConfig.should_fake_stream
    already_patched = (
        getattr(current_impl, "__name__", "") == "_patched_should_fake_stream"
    )
    if already_patched:
        return

    def _patched_should_fake_stream(
        self: Any,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Azure Responses API supports native streaming - never fake it
        return False

    # Explicitly pin the name that the already-patched check above looks for.
    _patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
    AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream  # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -63,7 +63,7 @@ def process_with_prompt_cache(
|
||||
return suffix, None
|
||||
|
||||
# Get provider adapter
|
||||
provider_adapter = get_provider_adapter(llm_config.model_provider)
|
||||
provider_adapter = get_provider_adapter(llm_config)
|
||||
|
||||
# If provider doesn't support caching, combine and return unchanged
|
||||
if not provider_adapter.supports_caching():
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""Factory for creating provider-specific prompt cache adapters."""
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
|
||||
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
|
||||
|
||||
ANTHROPIC_BEDROCK_TAG = "anthropic."
|
||||
|
||||
def get_provider_adapter(provider: str) -> PromptCacheProvider:
|
||||
|
||||
def get_provider_adapter(llm_config: LLMConfig) -> PromptCacheProvider:
|
||||
"""Get the appropriate prompt cache provider adapter for a given provider.
|
||||
|
||||
Args:
|
||||
@@ -17,11 +20,14 @@ def get_provider_adapter(provider: str) -> PromptCacheProvider:
|
||||
Returns:
|
||||
PromptCacheProvider instance for the given provider
|
||||
"""
|
||||
if provider == LlmProviderNames.OPENAI:
|
||||
if llm_config.model_provider == LlmProviderNames.OPENAI:
|
||||
return OpenAIPromptCacheProvider()
|
||||
elif provider in [LlmProviderNames.ANTHROPIC, LlmProviderNames.BEDROCK]:
|
||||
elif llm_config.model_provider == LlmProviderNames.ANTHROPIC or (
|
||||
llm_config.model_provider == LlmProviderNames.BEDROCK
|
||||
and ANTHROPIC_BEDROCK_TAG in llm_config.model_name
|
||||
):
|
||||
return AnthropicPromptCacheProvider()
|
||||
elif provider == LlmProviderNames.VERTEX_AI:
|
||||
elif llm_config.model_provider == LlmProviderNames.VERTEX_AI:
|
||||
return VertexAIPromptCacheProvider()
|
||||
else:
|
||||
# Default to no-op for providers without caching support
|
||||
|
||||
@@ -1,30 +1,39 @@
|
||||
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
|
||||
|
||||
SLACK_QUERY_EXPANSION_PROMPT = f"""
|
||||
Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
|
||||
keyword-only queries, so that Slack's keyword search yields the best matches.
|
||||
Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.
|
||||
|
||||
Keep in mind the Slack's search behavior:
|
||||
- Pure keyword AND search (no semantics).
|
||||
- Word order matters.
|
||||
- More words = fewer matches, so keep each query concise.
|
||||
- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.
|
||||
Slack search behavior:
|
||||
- Pure keyword AND search (no semantics)
|
||||
- More words = fewer matches, so keep queries concise (1-3 words)
|
||||
|
||||
Critical: Extract ONLY keywords that would actually appear in Slack message content.
|
||||
ALWAYS include:
|
||||
- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
|
||||
- Project/product names, technical terms, proper nouns
|
||||
- Actual content words: "performance", "bug", "deployment", "API", "error"
|
||||
|
||||
DO NOT include:
|
||||
- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
|
||||
- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
|
||||
- Channels/Users: "general", "eng-general", "engineering", "@username"
|
||||
|
||||
DO include:
|
||||
- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"
|
||||
- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
|
||||
- Temporal: "today", "yesterday", "week", "month", "recent", "last"
|
||||
- Channel names: "general", "eng-general", "random"
|
||||
|
||||
Examples:
|
||||
|
||||
Query: "what are the big topics in eng-general this week?"
|
||||
Output:
|
||||
|
||||
Query: "messages with Sarah about the deployment"
|
||||
Output:
|
||||
Sarah deployment
|
||||
Sarah
|
||||
deployment
|
||||
|
||||
Query: "what did Mike say about the budget?"
|
||||
Output:
|
||||
Mike budget
|
||||
Mike
|
||||
budget
|
||||
|
||||
Query: "performance issues in eng-general"
|
||||
Output:
|
||||
performance issues
|
||||
@@ -41,7 +50,7 @@ Now process this query:
|
||||
|
||||
{{query}}
|
||||
|
||||
Output:
|
||||
Output (keywords only, one per line, NO explanations or commentary):
|
||||
"""
|
||||
|
||||
SLACK_DATE_EXTRACTION_PROMPT = """
|
||||
|
||||
@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -697,7 +697,7 @@ def save_user_credentials(
|
||||
# TODO: fix and/or type correctly w/base model
|
||||
config_data = MCPConnectionData(
|
||||
headers=auth_template.config.get("headers", {}),
|
||||
header_substitutions=auth_template.config.get(HEADER_SUBSTITUTIONS, {}),
|
||||
header_substitutions=request.credentials,
|
||||
)
|
||||
for oauth_field_key in MCPOAuthKeys:
|
||||
field_key: Literal["client_info", "tokens", "metadata"] = (
|
||||
|
||||
@@ -34,7 +34,7 @@ from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared_users
|
||||
from onyx.db.persona import update_persona_shared
|
||||
from onyx.db.persona import update_persona_visibility
|
||||
from onyx.db.persona import update_personas_display_priority
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -366,7 +366,9 @@ def delete_label(
|
||||
|
||||
|
||||
class PersonaShareRequest(BaseModel):
|
||||
user_ids: list[UUID]
|
||||
user_ids: list[UUID] | None = None
|
||||
group_ids: list[int] | None = None
|
||||
is_public: bool | None = None
|
||||
|
||||
|
||||
# We notify each user when a persona is shared with them
|
||||
@@ -377,11 +379,13 @@ def share_persona(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_persona_shared_users(
|
||||
update_persona_shared(
|
||||
persona_id=persona_id,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
user_ids=persona_share_request.user_ids,
|
||||
group_ids=persona_share_request.group_ids,
|
||||
is_public=persona_share_request.is_public,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -87,6 +87,11 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
entries = [
|
||||
entry for entry in all_entries if is_version_gte(entry.version, __version__)
|
||||
]
|
||||
elif "nightly" in __version__:
|
||||
# Just show the latest entry for nightly versions
|
||||
entries = sorted(
|
||||
all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
|
||||
)[:1]
|
||||
else:
|
||||
# If the version is not recognized, it is
|
||||
# likely `development` and we should show all entries
|
||||
|
||||
@@ -410,26 +410,20 @@ def list_llm_provider_basics(
|
||||
|
||||
all_providers = fetch_existing_llm_providers(db_session)
|
||||
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
|
||||
is_admin = user and user.role == UserRole.ADMIN
|
||||
is_admin = user is not None and user.role == UserRole.ADMIN
|
||||
|
||||
accessible_providers = []
|
||||
|
||||
for provider in all_providers:
|
||||
# Include all public providers
|
||||
if provider.is_public:
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
continue
|
||||
|
||||
# Include restricted providers user has access to via groups
|
||||
if is_admin:
|
||||
# Admins see all providers
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif provider.groups:
|
||||
# User must be in at least one of the provider's groups
|
||||
if user_group_ids.intersection({g.id for g in provider.groups}):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
elif not provider.personas:
|
||||
# No restrictions = accessible
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
):
|
||||
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
@@ -4,10 +4,13 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import InternetContentProvider
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.models import User
|
||||
from onyx.db.web_search import deactivate_web_content_provider
|
||||
from onyx.db.web_search import deactivate_web_search_provider
|
||||
@@ -94,6 +97,28 @@ def upsert_search_provider_endpoint(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Sync Exa key of search engine to content provider
|
||||
if (
|
||||
request.provider_type == WebSearchProviderType.EXA
|
||||
and request.api_key_changed
|
||||
and request.api_key
|
||||
):
|
||||
stmt = (
|
||||
insert(InternetContentProvider)
|
||||
.values(
|
||||
name="Exa",
|
||||
provider_type=WebContentProviderType.EXA.value,
|
||||
api_key=request.api_key,
|
||||
is_active=False,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["name"],
|
||||
set_={"api_key": request.api_key},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.flush()
|
||||
|
||||
db_session.commit()
|
||||
return WebSearchProviderView(
|
||||
id=provider.id,
|
||||
@@ -245,6 +270,28 @@ def upsert_content_provider_endpoint(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Sync Exa key of content provider to search provider
|
||||
if (
|
||||
request.provider_type == WebContentProviderType.EXA
|
||||
and request.api_key_changed
|
||||
and request.api_key
|
||||
):
|
||||
stmt = (
|
||||
insert(InternetSearchProvider)
|
||||
.values(
|
||||
name="Exa",
|
||||
provider_type=WebSearchProviderType.EXA.value,
|
||||
api_key=request.api_key,
|
||||
is_active=False,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["name"],
|
||||
set_={"api_key": request.api_key},
|
||||
)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.flush()
|
||||
|
||||
db_session.commit()
|
||||
return WebContentProviderView(
|
||||
id=provider.id,
|
||||
|
||||
@@ -11,7 +11,7 @@ from onyx.db.chat import get_db_search_doc_by_id
|
||||
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.tools import get_tool_by_id
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_IN_CODE_ID
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TASK_KEY
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
@@ -401,7 +401,7 @@ def translate_assistant_message_to_packets(
|
||||
# Here we do a try because some tools may get deleted before the session is reloaded.
|
||||
try:
|
||||
tool = get_tool_by_id(tool_call.tool_id, db_session)
|
||||
if tool.in_code_tool_id == RESEARCH_AGENT_DB_NAME:
|
||||
if tool.in_code_tool_id == RESEARCH_AGENT_IN_CODE_ID:
|
||||
research_agent_count += 1
|
||||
|
||||
# Handle different tool types
|
||||
@@ -457,7 +457,7 @@ def translate_assistant_message_to_packets(
|
||||
)
|
||||
)
|
||||
|
||||
elif tool.in_code_tool_id == RESEARCH_AGENT_DB_NAME:
|
||||
elif tool.in_code_tool_id == RESEARCH_AGENT_IN_CODE_ID:
|
||||
# Not ideal but not a huge issue if the research task is lost.
|
||||
research_task = cast(
|
||||
str,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
|
||||
from exa_py import Exa
|
||||
@@ -19,6 +20,21 @@ from onyx.utils.retry_wrapper import retry_builder
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_site_operators(query: str) -> tuple[str, list[str]]:
|
||||
"""Extract site: operators and return cleaned query + full domains.
|
||||
|
||||
Returns (cleaned_query, full_domains) where full_domains contains the full
|
||||
values after site: (e.g., ["reddit.com/r/leagueoflegends"]).
|
||||
"""
|
||||
full_domains = re.findall(r"site:\s*([^\s]+)", query, re.IGNORECASE)
|
||||
cleaned_query = re.sub(r"site:\s*\S+\s*", "", query, flags=re.IGNORECASE).strip()
|
||||
|
||||
if not cleaned_query and full_domains:
|
||||
cleaned_query = full_domains[0]
|
||||
|
||||
return cleaned_query, full_domains
|
||||
|
||||
|
||||
class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
def __init__(self, api_key: str, num_results: int = 10) -> None:
|
||||
self.exa = Exa(api_key=api_key)
|
||||
@@ -28,8 +44,9 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
def supports_site_filter(self) -> bool:
|
||||
return False
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
def _search_exa(
|
||||
self, query: str, include_domains: list[str] | None = None
|
||||
) -> list[WebSearchResult]:
|
||||
response = self.exa.search_and_contents(
|
||||
query,
|
||||
type="auto",
|
||||
@@ -38,6 +55,7 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
highlights_per_url=1,
|
||||
),
|
||||
num_results=self._num_results,
|
||||
include_domains=include_domains,
|
||||
)
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
@@ -60,6 +78,21 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
|
||||
return results
|
||||
|
||||
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
    """Run an Exa search, honouring any ``site:`` operators in ``query``.

    When the query contains ``site:`` operators, a domain-restricted search is
    attempted first using only the host part of each operator value (path and
    a leading ``www.`` removed). If that yields nothing, or if no operators
    were present, an unrestricted search is run with the full operator values
    appended to the query as plain keywords.

    NOTE(review): retry behaviour comes from ``retry_builder``; presumably the
    whole method (both attempts) is re-run on failure - confirm against the
    decorator.
    """
    keywords, site_targets = _extract_site_operators(query)

    if site_targets:
        # Reduce e.g. "www.reddit.com/r/foo" to "reddit.com" for include_domains.
        hosts = []
        for target in site_targets:
            host, _, _ = target.partition("/")
            hosts.append(host.removeprefix("www."))

        restricted = self._search_exa(keywords, include_domains=hosts)
        if restricted:
            return restricted

    # Fallback: fold the full site values back into the query as keywords.
    keyword_query = " ".join([keywords, *site_targets]).strip()
    return self._search_exa(keyword_query)
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
test_results = self.search("test")
|
||||
@@ -113,6 +146,7 @@ class ExaClient(WebSearchProvider, WebContentProvider):
|
||||
if result.published_date
|
||||
else None
|
||||
),
|
||||
scrape_successful=bool(full_content),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -98,6 +98,9 @@ def build_content_provider_from_config(
|
||||
timeout_seconds=config.timeout_seconds,
|
||||
)
|
||||
|
||||
if provider_type == WebContentProviderType.EXA:
|
||||
return ExaClient(api_key=api_key)
|
||||
|
||||
|
||||
def get_default_provider() -> WebSearchProvider | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
@@ -265,13 +265,22 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
# Format for LLM
|
||||
docs_str, citation_mapping = convert_inference_sections_to_llm_string(
|
||||
top_sections=inference_sections,
|
||||
citation_start=override_kwargs.starting_citation_num,
|
||||
limit=None, # Already truncated
|
||||
include_source_type=False,
|
||||
include_link=True,
|
||||
)
|
||||
if not all_search_results:
|
||||
docs_str = json.dumps(
|
||||
{
|
||||
"results": [],
|
||||
"message": "The web search completed but returned no results for any of the queries. Do not search again.",
|
||||
}
|
||||
)
|
||||
citation_mapping: dict[int, str] = {}
|
||||
else:
|
||||
docs_str, citation_mapping = convert_inference_sections_to_llm_string(
|
||||
top_sections=inference_sections,
|
||||
citation_start=override_kwargs.starting_citation_num,
|
||||
limit=None, # Already truncated
|
||||
include_source_type=False,
|
||||
include_link=True,
|
||||
)
|
||||
|
||||
return ToolResponse(
|
||||
rich_response=SearchDocsResponse(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "onyx-backend"
|
||||
version = "0.0.0"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"onyx[backend,dev,ee]",
|
||||
]
|
||||
|
||||
@@ -5,7 +5,9 @@ aioboto3==15.1.0
|
||||
aiobotocore==2.24.0
|
||||
# via aioboto3
|
||||
aiofiles==25.1.0
|
||||
# via aioboto3
|
||||
# via
|
||||
# aioboto3
|
||||
# unstructured-client
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.3
|
||||
@@ -115,7 +117,6 @@ certifi==2025.11.12
|
||||
# requests
|
||||
# sentry-sdk
|
||||
# trafilatura
|
||||
# unstructured-client
|
||||
cffi==2.0.0
|
||||
# via
|
||||
# argon2-cffi-bindings
|
||||
@@ -123,9 +124,7 @@ cffi==2.0.0
|
||||
# pynacl
|
||||
# zstandard
|
||||
chardet==5.2.0
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
# via onyx
|
||||
charset-normalizer==3.4.4
|
||||
# via
|
||||
# htmldate
|
||||
@@ -133,7 +132,7 @@ charset-normalizer==3.4.4
|
||||
# pdfminer-six
|
||||
# requests
|
||||
# trafilatura
|
||||
# unstructured-client
|
||||
# unstructured
|
||||
chevron==0.14.0
|
||||
# via braintrust
|
||||
chonkie==1.0.10
|
||||
@@ -149,6 +148,7 @@ click==8.3.1
|
||||
# litellm
|
||||
# magika
|
||||
# nltk
|
||||
# python-oxmsg
|
||||
# typer
|
||||
# uvicorn
|
||||
# zulip
|
||||
@@ -185,6 +185,7 @@ cryptography==46.0.3
|
||||
# pyjwt
|
||||
# secretstorage
|
||||
# sendgrid
|
||||
# unstructured-client
|
||||
cyclopts==4.2.4
|
||||
# via fastmcp
|
||||
dask==2023.8.1
|
||||
@@ -192,17 +193,13 @@ dask==2023.8.1
|
||||
# distributed
|
||||
# onyx
|
||||
dataclasses-json==0.6.7
|
||||
# via
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
# via unstructured
|
||||
dateparser==1.2.2
|
||||
# via htmldate
|
||||
ddtrace==3.10.0
|
||||
# via onyx
|
||||
decorator==5.2.1
|
||||
# via retry
|
||||
deepdiff==8.6.1
|
||||
# via unstructured-client
|
||||
defusedxml==0.7.1
|
||||
# via
|
||||
# jira
|
||||
@@ -354,7 +351,7 @@ greenlet==3.2.4
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -362,7 +359,17 @@ grpcio==1.67.1
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -374,12 +381,15 @@ hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or
|
||||
# via huggingface-hub
|
||||
hpack==4.1.0
|
||||
# via h2
|
||||
html5lib==1.1
|
||||
# via unstructured
|
||||
htmldate==1.9.1
|
||||
# via trafilatura
|
||||
httpcore==1.0.9
|
||||
# via
|
||||
# httpx
|
||||
# onyx
|
||||
# unstructured-client
|
||||
httplib2==0.31.0
|
||||
# via
|
||||
# google-api-python-client
|
||||
@@ -420,7 +430,6 @@ idna==3.11
|
||||
# email-validator
|
||||
# httpx
|
||||
# requests
|
||||
# unstructured-client
|
||||
# yarl
|
||||
importlib-metadata==8.7.0
|
||||
# via
|
||||
@@ -466,8 +475,6 @@ joblib==1.5.2
|
||||
# via nltk
|
||||
jsonpatch==1.33
|
||||
# via langchain-core
|
||||
jsonpath-python==1.0.6
|
||||
# via unstructured-client
|
||||
jsonpointer==3.0.0
|
||||
# via jsonpatch
|
||||
jsonref==1.1.0
|
||||
@@ -509,6 +516,8 @@ langsmith==0.3.45
|
||||
# langchain-core
|
||||
lazy-imports==1.0.1
|
||||
# via onyx
|
||||
legacy-cgi==2.6.4 ; python_full_version >= '3.13'
|
||||
# via ddtrace
|
||||
litellm==1.80.11
|
||||
# via onyx
|
||||
locket==1.0.0
|
||||
@@ -555,9 +564,7 @@ markupsafe==3.0.3
|
||||
# mako
|
||||
# werkzeug
|
||||
marshmallow==3.26.2
|
||||
# via
|
||||
# dataclasses-json
|
||||
# unstructured-client
|
||||
# via dataclasses-json
|
||||
matrix-client==0.3.2
|
||||
# via zulip
|
||||
mcp==1.25.0
|
||||
@@ -598,16 +605,13 @@ mypy-extensions==1.0.0
|
||||
# via
|
||||
# mypy
|
||||
# typing-inspect
|
||||
# unstructured-client
|
||||
nest-asyncio==1.6.0
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
# via onyx
|
||||
nltk==3.9.1
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
numpy==1.26.4
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# magika
|
||||
# onnxruntime
|
||||
@@ -623,7 +627,9 @@ oauthlib==3.2.2
|
||||
office365-rest-python-client==2.5.9
|
||||
# via onyx
|
||||
olefile==0.47
|
||||
# via msoffcrypto-tool
|
||||
# via
|
||||
# msoffcrypto-tool
|
||||
# python-oxmsg
|
||||
onnxruntime==1.20.1
|
||||
# via magika
|
||||
openai==2.14.0
|
||||
@@ -678,8 +684,6 @@ opentelemetry-semantic-conventions==0.60b1
|
||||
# via
|
||||
# opentelemetry-instrumentation
|
||||
# opentelemetry-sdk
|
||||
orderly-set==5.5.0
|
||||
# via deepdiff
|
||||
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
|
||||
# via langsmith
|
||||
packaging==24.2
|
||||
@@ -700,7 +704,6 @@ packaging==24.2
|
||||
# opentelemetry-instrumentation
|
||||
# pytest
|
||||
# pywikibot
|
||||
# unstructured-client
|
||||
pandas==2.2.3
|
||||
# via markitdown
|
||||
parameterized==0.9.0
|
||||
@@ -748,7 +751,19 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
@@ -810,6 +825,7 @@ pydantic==2.11.7
|
||||
# openapi-pydantic
|
||||
# pyairtable
|
||||
# pydantic-settings
|
||||
# unstructured-client
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
@@ -835,7 +851,7 @@ pynacl==1.6.2
|
||||
# via pygithub
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.1.3
|
||||
pypdf==6.6.0
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
@@ -867,7 +883,6 @@ python-dateutil==2.8.2
|
||||
# onyx
|
||||
# opensearch-py
|
||||
# pandas
|
||||
# unstructured-client
|
||||
python-docx==1.1.2
|
||||
# via onyx
|
||||
python-dotenv==1.1.1
|
||||
@@ -894,6 +909,8 @@ python-multipart==0.0.20
|
||||
# fastapi-users
|
||||
# mcp
|
||||
# onyx
|
||||
python-oxmsg==0.0.2
|
||||
# via unstructured
|
||||
python-pptx==0.6.23
|
||||
# via
|
||||
# markitdown
|
||||
@@ -985,7 +1002,6 @@ requests==2.32.5
|
||||
# stripe
|
||||
# tiktoken
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
# voyageai
|
||||
# zeep
|
||||
# zulip
|
||||
@@ -1045,12 +1061,12 @@ six==1.17.0
|
||||
# atlassian-python-api
|
||||
# dropbox
|
||||
# google-auth-httplib2
|
||||
# html5lib
|
||||
# hubspot-api-client
|
||||
# langdetect
|
||||
# markdownify
|
||||
# python-dateutil
|
||||
# stone
|
||||
# unstructured-client
|
||||
slack-sdk==3.20.2
|
||||
# via onyx
|
||||
smmap==5.0.2
|
||||
@@ -1089,8 +1105,6 @@ supervisor==4.3.0
|
||||
# via onyx
|
||||
sympy==1.13.1
|
||||
# via onnxruntime
|
||||
tabulate==0.9.0
|
||||
# via unstructured
|
||||
tblib==3.2.2
|
||||
# via distributed
|
||||
tenacity==9.1.2
|
||||
@@ -1158,6 +1172,7 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# jira
|
||||
# langchain-core
|
||||
@@ -1178,6 +1193,7 @@ typing-extensions==4.15.0
|
||||
# pyee
|
||||
# pygithub
|
||||
# python-docx
|
||||
# python-oxmsg
|
||||
# referencing
|
||||
# simple-salesforce
|
||||
# sqlalchemy
|
||||
@@ -1187,12 +1203,9 @@ typing-extensions==4.15.0
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# unstructured
|
||||
# unstructured-client
|
||||
# zulip
|
||||
typing-inspect==0.9.0
|
||||
# via
|
||||
# dataclasses-json
|
||||
# unstructured-client
|
||||
# via dataclasses-json
|
||||
typing-inspection==0.4.2
|
||||
# via
|
||||
# mcp
|
||||
@@ -1205,9 +1218,9 @@ tzdata==2025.2
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via dateparser
|
||||
unstructured==0.15.1
|
||||
unstructured==0.18.27
|
||||
# via onyx
|
||||
unstructured-client==0.25.4
|
||||
unstructured-client==0.42.6
|
||||
# via
|
||||
# onyx
|
||||
# unstructured
|
||||
@@ -1229,7 +1242,6 @@ urllib3==2.6.3
|
||||
# sentry-sdk
|
||||
# trafilatura
|
||||
# types-requests
|
||||
# unstructured-client
|
||||
uvicorn==0.35.0
|
||||
# via
|
||||
# fastmcp
|
||||
@@ -1244,6 +1256,8 @@ voyageai==0.2.3
|
||||
# via onyx
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
webencodings==0.5.1
|
||||
# via html5lib
|
||||
websockets==15.0.1
|
||||
# via
|
||||
# fastmcp
|
||||
|
||||
@@ -175,7 +175,7 @@ greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -183,7 +183,17 @@ grpcio==1.67.1
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -278,14 +288,14 @@ nest-asyncio==1.6.0
|
||||
# via ipykernel
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
numpy==1.26.4
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# contourpy
|
||||
# matplotlib
|
||||
# pandas-stubs
|
||||
# shapely
|
||||
# voyageai
|
||||
onyx-devtools==0.2.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -347,7 +357,16 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -546,6 +565,7 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# mypy
|
||||
|
||||
@@ -132,7 +132,7 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -140,7 +140,17 @@ grpcio==1.67.1
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -192,7 +202,7 @@ multidict==6.7.0
|
||||
# aiobotocore
|
||||
# aiohttp
|
||||
# yarl
|
||||
numpy==1.26.4
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# shapely
|
||||
# voyageai
|
||||
@@ -224,7 +234,16 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -329,6 +348,7 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pydantic
|
||||
|
||||
@@ -157,7 +157,7 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
@@ -165,7 +165,17 @@ grpcio==1.67.1
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# litellm
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -229,7 +239,7 @@ multidict==6.7.0
|
||||
# yarl
|
||||
networkx==3.5
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
numpy==2.4.1
|
||||
# via
|
||||
# accelerate
|
||||
# onyx
|
||||
@@ -306,7 +316,16 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -450,6 +469,7 @@ typing-extensions==4.15.0
|
||||
# fastapi
|
||||
# google-cloud-aiplatform
|
||||
# google-genai
|
||||
# grpcio
|
||||
# huggingface-hub
|
||||
# openai
|
||||
# pydantic
|
||||
|
||||
@@ -31,3 +31,4 @@ class WebSearchProviderType(str, Enum):
|
||||
class WebContentProviderType(str, Enum):
|
||||
ONYX_WEB_CRAWLER = "onyx_web_crawler"
|
||||
FIRECRAWL = "firecrawl"
|
||||
EXA = "exa"
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
External dependency unit tests for user file processing queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_user_file_processing work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in PROCESSING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that process_single_user_file clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_user_file_processing,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in PROCESSING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.PROCESSING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "bp_user")
|
||||
_create_processing_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(
|
||||
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
|
||||
),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestPerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "guard_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "guard_set_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
|
||||
f"Guard key TTL {ttl}s is outside the expected range "
|
||||
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "expires_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires")
|
||||
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), (
|
||||
"Task must be submitted with the correct expires value to prevent "
|
||||
"stale task accumulation"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestWorkerClearsGuardKey:
|
||||
"""process_single_user_file removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so process_single_user_file returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file processing lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_lock_key(user_file_id)
|
||||
processing_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = processing_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the processing lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if processing_lock.owned():
|
||||
processing_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -4,9 +4,11 @@ These tests assume OpenSearch is running and test all implemented methods
|
||||
using real schemas, pipelines, and search queries from the codebase.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -21,18 +23,31 @@ from onyx.document_index.opensearch.search import (
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) -> None:
|
||||
"""Patches MULTI_TENANT wherever necessary for this test file.
|
||||
|
||||
Args:
|
||||
monkeypatch: The test instance's monkeypatch instance, used for
|
||||
patching.
|
||||
state: The intended state of MULTI_TENANT.
|
||||
"""
|
||||
monkeypatch.setattr("shared_configs.configs.MULTI_TENANT", state)
|
||||
monkeypatch.setattr("onyx.document_index.opensearch.schema.MULTI_TENANT", state)
|
||||
|
||||
|
||||
def _create_test_document_chunk(
|
||||
document_id: str = "test-doc-1",
|
||||
chunk_index: int = 0,
|
||||
content: str = "Test content",
|
||||
document_id: str,
|
||||
chunk_index: int,
|
||||
content: str,
|
||||
tenant_state: TenantState,
|
||||
content_vector: list[float] | None = None,
|
||||
title: str | None = None,
|
||||
title_vector: list[float] | None = None,
|
||||
public: bool = True,
|
||||
hidden: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> DocumentChunk:
|
||||
if content_vector is None:
|
||||
# Generate dummy vector - 128 dimensions for fast testing.
|
||||
@@ -42,31 +57,51 @@ def _create_test_document_chunk(
|
||||
if title is not None and title_vector is None:
|
||||
title_vector = [0.2] * 128
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# We only store millisecond precision, so to make sure asserts work in this
|
||||
# test file manually lose some precision from datetime.now().
|
||||
now = now.replace(microsecond=(now.microsecond // 1000) * 1000)
|
||||
|
||||
return DocumentChunk(
|
||||
document_id=document_id,
|
||||
chunk_index=chunk_index,
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
title=title,
|
||||
title_vector=title_vector,
|
||||
# This is not how tokenization necessarily works, this is just for quick
|
||||
# testing.
|
||||
num_tokens=len(content.split()),
|
||||
content=content,
|
||||
content_vector=content_vector,
|
||||
source_type="test_source",
|
||||
metadata=json.dumps({}),
|
||||
last_updated=now,
|
||||
public=public,
|
||||
access_control_list=[],
|
||||
hidden=hidden,
|
||||
**kwargs,
|
||||
global_boost=0,
|
||||
semantic_identifier="Test semantic identifier",
|
||||
image_file_id=None,
|
||||
source_links=None,
|
||||
blurb="Test blurb",
|
||||
doc_summary="Test doc summary",
|
||||
chunk_context="Test chunk context",
|
||||
document_sets=None,
|
||||
project_ids=None,
|
||||
primary_owners=None,
|
||||
secondary_owners=None,
|
||||
tenant_id=tenant_state,
|
||||
)
|
||||
|
||||
|
||||
def _generate_test_vector(base_value: float = 0.1, dimension: int = 128) -> list[float]:
|
||||
"""Generate a test vector with slight variations."""
|
||||
return [base_value + (i * 0.001) for i in range(dimension)]
|
||||
"""Generates a test vector with slight variations.
|
||||
|
||||
We round to eliminate floating point precision errors when comparing chunks
|
||||
for equality.
|
||||
"""
|
||||
return [round(base_value + (i * 0.001), 5) for i in range(dimension)]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def opensearch_available() -> None:
|
||||
"""Verify OpenSearch is running, skip all tests if not."""
|
||||
"""Verifies OpenSearch is running, skips all tests if not."""
|
||||
client = OpenSearchClient(index_name="test_ping")
|
||||
try:
|
||||
if not client.ping():
|
||||
@@ -228,11 +263,15 @@ class TestOpenSearchClient:
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
|
||||
)
|
||||
|
||||
def test_index_document(self, test_client: OpenSearchClient) -> None:
|
||||
def test_index_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a document."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -241,17 +280,22 @@ class TestOpenSearchClient:
|
||||
document_id="test-doc-1",
|
||||
chunk_index=0,
|
||||
content="Test content for indexing",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test and postcondition.
|
||||
# Should not raise.
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
def test_index_duplicate_document(self, test_client: OpenSearchClient) -> None:
|
||||
def test_index_duplicate_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests indexing a duplicate document raises an error."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -260,6 +304,7 @@ class TestOpenSearchClient:
|
||||
document_id="test-doc-duplicate",
|
||||
chunk_index=0,
|
||||
content="Duplicate test",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Index once - should succeed.
|
||||
@@ -270,11 +315,15 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="already exists"):
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
def test_get_document(self, test_client: OpenSearchClient) -> None:
|
||||
def test_get_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a document."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -283,6 +332,7 @@ class TestOpenSearchClient:
|
||||
document_id="test-doc-get",
|
||||
chunk_index=0,
|
||||
content="Content to retrieve",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
test_client.index_document(document=original_doc)
|
||||
|
||||
@@ -297,11 +347,14 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert retrieved_doc == original_doc
|
||||
|
||||
def test_get_nonexistent_document(self, test_client: OpenSearchClient) -> None:
|
||||
def test_get_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests getting a nonexistent document raises an error."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=False
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -312,11 +365,15 @@ class TestOpenSearchClient:
|
||||
document_chunk_id="test_source__nonexistent__512__0"
|
||||
)
|
||||
|
||||
def test_delete_existing_document(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_existing_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting an existing document returns True."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -325,6 +382,7 @@ class TestOpenSearchClient:
|
||||
document_id="test-doc-delete",
|
||||
chunk_index=0,
|
||||
content="Content to delete",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
@@ -342,11 +400,15 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(Exception, match="404"):
|
||||
test_client.get_document(document_chunk_id=doc_chunk_id)
|
||||
|
||||
def test_delete_nonexistent_document(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_nonexistent_document(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting a nonexistent document returns False."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -359,11 +421,15 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert result is False
|
||||
|
||||
def test_delete_by_query(self, test_client: OpenSearchClient) -> None:
|
||||
def test_delete_by_query(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests deleting documents by query."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -374,7 +440,7 @@ class TestOpenSearchClient:
|
||||
document_id="delete-me",
|
||||
chunk_index=i,
|
||||
content=f"Delete this {i}",
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
@@ -383,7 +449,7 @@ class TestOpenSearchClient:
|
||||
document_id="keep-me",
|
||||
chunk_index=0,
|
||||
content="Keep this",
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -393,7 +459,7 @@ class TestOpenSearchClient:
|
||||
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id="delete-me",
|
||||
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -406,7 +472,7 @@ class TestOpenSearchClient:
|
||||
test_client.refresh_index()
|
||||
search_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="delete-me",
|
||||
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
|
||||
tenant_state=tenant_state,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -418,7 +484,7 @@ class TestOpenSearchClient:
|
||||
# Verify other documents still exist.
|
||||
keep_query = DocumentQuery.get_from_document_id_query(
|
||||
document_id="keep-me",
|
||||
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
|
||||
tenant_state=tenant_state,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -432,37 +498,44 @@ class TestOpenSearchClient:
|
||||
with pytest.raises(NotImplementedError):
|
||||
test_client.update_document()
|
||||
|
||||
def test_search_basic(self, test_client: OpenSearchClient) -> None:
|
||||
def test_search_basic(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests basic search functionality."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index multiple documents with different content and vectors.
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
docs = {
|
||||
"search-doc-1": _create_test_document_chunk(
|
||||
document_id="search-doc-1",
|
||||
chunk_index=0,
|
||||
content="Python programming language tutorial",
|
||||
content_vector=_generate_test_vector(0.1),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
"search-doc-2": _create_test_document_chunk(
|
||||
document_id="search-doc-2",
|
||||
chunk_index=0,
|
||||
content="How to make cheese",
|
||||
content_vector=_generate_test_vector(0.2),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
"search-doc-3": _create_test_document_chunk(
|
||||
document_id="search-doc-3",
|
||||
chunk_index=0,
|
||||
content="C++ for newborns",
|
||||
content_vector=_generate_test_vector(0.15),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
@@ -476,47 +549,57 @@ class TestOpenSearchClient:
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=TenantState(tenant_id="", multitenant=False),
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
results = test_client.search(body=search_body, search_pipeline_id=None)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) > 0
|
||||
assert len(results) == 3
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_id in ["search-doc-1", "search-doc-2", "search-doc-3"]
|
||||
for chunk in results
|
||||
)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
assert chunk == docs[chunk.document_id]
|
||||
|
||||
def test_search_with_pipeline(
|
||||
self, test_client: OpenSearchClient, search_pipeline: None
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Tests search with a normalization pipeline."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents.
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
docs = {
|
||||
"pipeline-doc-1": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-1",
|
||||
chunk_index=0,
|
||||
content="Machine learning algorithms for single-celled organisms",
|
||||
content_vector=_generate_test_vector(0.3),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
"pipeline-doc-2": _create_test_document_chunk(
|
||||
document_id="pipeline-doc-2",
|
||||
chunk_index=0,
|
||||
content="Deep learning shallow neural networks",
|
||||
content_vector=_generate_test_vector(0.35),
|
||||
tenant_state=tenant_state,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable
|
||||
@@ -530,7 +613,7 @@ class TestOpenSearchClient:
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=TenantState(tenant_id="", multitenant=False),
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -539,18 +622,25 @@ class TestOpenSearchClient:
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert len(results) > 0
|
||||
assert len(results) == 2
|
||||
# Assert that all the chunks above are present.
|
||||
assert all(
|
||||
chunk.document_id in ["pipeline-doc-1", "pipeline-doc-2"]
|
||||
for chunk in results
|
||||
)
|
||||
# Make sure the chunk contents are preserved.
|
||||
for chunk in results:
|
||||
assert chunk == docs[chunk.document_id]
|
||||
|
||||
def test_search_empty_index(self, test_client: OpenSearchClient) -> None:
|
||||
def test_search_empty_index(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search on an empty index returns an empty list."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -564,7 +654,7 @@ class TestOpenSearchClient:
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=TenantState(tenant_id="", multitenant=False),
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -573,43 +663,60 @@ class TestOpenSearchClient:
|
||||
# Postcondition.
|
||||
assert len(results) == 0
|
||||
|
||||
def test_search_filters(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests search filters for public/hidden documents."""
|
||||
def test_search_filters(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests search filters for public/hidden documents and tenant isolation.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index documents with different public/hidden states.
|
||||
docs = [
|
||||
_create_test_document_chunk(
|
||||
# Index documents with different public/hidden and tenant states.
|
||||
docs = {
|
||||
"public-doc-1": _create_test_document_chunk(
|
||||
document_id="public-doc-1",
|
||||
chunk_index=0,
|
||||
content="Public document content",
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
"hidden-doc-1": _create_test_document_chunk(
|
||||
document_id="hidden-doc-1",
|
||||
chunk_index=0,
|
||||
content="Hidden document content, spooky",
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
"private-doc-1": _create_test_document_chunk(
|
||||
document_id="private-doc-1",
|
||||
chunk_index=0,
|
||||
content="Private document content, btw my SSN is 123-45-6789",
|
||||
public=False,
|
||||
hidden=False,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
|
||||
document_id="should-not-exist-from-tenant-x-pov",
|
||||
chunk_index=0,
|
||||
content="This is an entirely different tenant, x should never see this",
|
||||
# Make this as permissive as possible to exercise tenant
|
||||
# isolation.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_state=tenant_y,
|
||||
),
|
||||
}
|
||||
for doc in docs.values():
|
||||
test_client.index_document(document=doc)
|
||||
|
||||
# Refresh index to make documents searchable.
|
||||
@@ -625,7 +732,7 @@ class TestOpenSearchClient:
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
|
||||
tenant_state=tenant_x,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -635,19 +742,24 @@ class TestOpenSearchClient:
|
||||
# Should only get the public, non-hidden document.
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == "public-doc-1"
|
||||
assert results[0].public is True
|
||||
assert results[0].hidden is False
|
||||
# Make sure the chunk contents are preserved.
|
||||
assert results[0] == docs["public-doc-1"]
|
||||
|
||||
def test_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
|
||||
self, test_client: OpenSearchClient, search_pipeline: None
|
||||
self,
|
||||
test_client: OpenSearchClient,
|
||||
search_pipeline: None,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Tests search with a normalization pipeline and filters returns chunks
|
||||
with related content first.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -664,7 +776,7 @@ class TestOpenSearchClient:
|
||||
), # Very close to query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="somewhat-relevant-1",
|
||||
@@ -673,7 +785,7 @@ class TestOpenSearchClient:
|
||||
content_vector=_generate_test_vector(0.5), # Far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="not-very-relevant-1",
|
||||
@@ -684,7 +796,7 @@ class TestOpenSearchClient:
|
||||
), # Very far from query vector.
|
||||
public=True,
|
||||
hidden=False,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
# These should be filtered out by public/hidden filters.
|
||||
_create_test_document_chunk(
|
||||
@@ -694,7 +806,7 @@ class TestOpenSearchClient:
|
||||
content_vector=_generate_test_vector(0.05), # Very close but hidden.
|
||||
public=True,
|
||||
hidden=True,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
_create_test_document_chunk(
|
||||
document_id="private-but-relevant-1",
|
||||
@@ -703,7 +815,7 @@ class TestOpenSearchClient:
|
||||
content_vector=_generate_test_vector(0.08), # Very close but private.
|
||||
public=False,
|
||||
hidden=False,
|
||||
tenant_id="tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
),
|
||||
]
|
||||
for doc in docs:
|
||||
@@ -720,7 +832,7 @@ class TestOpenSearchClient:
|
||||
query_vector=query_vector,
|
||||
num_candidates=10,
|
||||
num_hits=5,
|
||||
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
|
||||
tenant_state=tenant_x,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
@@ -742,72 +854,18 @@ class TestOpenSearchClient:
|
||||
# Most relevant document should be first due to normalization pipeline.
|
||||
assert results[0].document_id == "highly-relevant-1"
|
||||
|
||||
def test_search_for_ids_basic(self, test_client: OpenSearchClient) -> None:
|
||||
"""Tests search_for_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
|
||||
# Index chunks for two different documents.
|
||||
doc1_chunks = [
|
||||
_create_test_document_chunk(
|
||||
document_id="doc-1", chunk_index=i, content=f"Doc 1 Chunk {i}"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
doc2_chunks = [
|
||||
_create_test_document_chunk(
|
||||
document_id="doc-2", chunk_index=i, content=f"Doc 2 Chunk {i}"
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
for chunk in doc1_chunks + doc2_chunks:
|
||||
test_client.index_document(document=chunk)
|
||||
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build query for doc-1.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-1",
|
||||
tenant_state=TenantState(tenant_id="", multitenant=False),
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
chunk_ids = test_client.search_for_document_ids(body=query_body)
|
||||
|
||||
# Postcondition.
|
||||
# Should get 3 IDs for doc-1.
|
||||
assert len(chunk_ids) == 3
|
||||
|
||||
# Verify IDs match expected chunk IDs.
|
||||
expected_ids = {
|
||||
get_opensearch_doc_chunk_id(
|
||||
document_id=chunk.document_id,
|
||||
chunk_index=chunk.chunk_index,
|
||||
max_chunk_size=chunk.max_chunk_size,
|
||||
)
|
||||
for chunk in doc1_chunks
|
||||
}
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
def test_delete_by_query_multitenant_isolation(
|
||||
self, test_client: OpenSearchClient
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, True)
|
||||
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
|
||||
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=True
|
||||
vector_dimension=128, multitenant=tenant_x.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -815,103 +873,88 @@ class TestOpenSearchClient:
|
||||
# Index chunks for different doc IDs for different tenants.
|
||||
# NOTE: Since get_opensearch_doc_chunk_id doesn't include tenant_id yet,
|
||||
# we use different document IDs to avoid ID conflicts.
|
||||
tenant_a_chunks = [
|
||||
tenant_x_chunks = [
|
||||
_create_test_document_chunk(
|
||||
document_id="doc-tenant-a",
|
||||
document_id="doc-tenant-x",
|
||||
chunk_index=i,
|
||||
content=f"Tenant A Chunk {i}",
|
||||
tenant_id="tenant-a",
|
||||
tenant_state=tenant_x,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
tenant_b_chunks = [
|
||||
tenant_y_chunks = [
|
||||
_create_test_document_chunk(
|
||||
document_id="doc-tenant-b",
|
||||
document_id="doc-tenant-y",
|
||||
chunk_index=i,
|
||||
content=f"Tenant B Chunk {i}",
|
||||
tenant_id="tenant-b",
|
||||
tenant_state=tenant_y,
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
|
||||
for chunk in tenant_a_chunks + tenant_b_chunks:
|
||||
for chunk in tenant_x_chunks + tenant_y_chunks:
|
||||
test_client.index_document(document=chunk)
|
||||
test_client.refresh_index()
|
||||
|
||||
# Build deletion query for tenant-a only.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-a",
|
||||
tenant_state=TenantState(tenant_id="tenant-a", multitenant=True),
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
# Build deletion query for tenant-x only.
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id="doc-tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
)
|
||||
|
||||
chunk_ids = test_client.search_for_document_ids(body=query_body)
|
||||
|
||||
assert len(chunk_ids) == 3
|
||||
expected_ids = {
|
||||
get_opensearch_doc_chunk_id(
|
||||
document_id=chunk.document_id,
|
||||
chunk_index=chunk.chunk_index,
|
||||
max_chunk_size=chunk.max_chunk_size,
|
||||
)
|
||||
for chunk in tenant_a_chunks
|
||||
}
|
||||
assert set(chunk_ids) == expected_ids
|
||||
|
||||
# Under test.
|
||||
# Delete tenant-a chunks.
|
||||
for chunk_id in chunk_ids:
|
||||
result = test_client.delete_document(chunk_id)
|
||||
assert result is True
|
||||
# Delete tenant-x chunks using delete_by_query.
|
||||
num_deleted = test_client.delete_by_query(query_body=query_body)
|
||||
|
||||
# Postcondition.
|
||||
# Verify tenant-a chunks are deleted.
|
||||
assert num_deleted == 3
|
||||
|
||||
# Verify tenant-x chunks are deleted.
|
||||
test_client.refresh_index()
|
||||
verify_query_a = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-a",
|
||||
tenant_state=TenantState(tenant_id="tenant-a", multitenant=True),
|
||||
verify_query_x = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-x",
|
||||
tenant_state=tenant_x,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
)
|
||||
remaining_a_ids = test_client.search_for_document_ids(body=verify_query_a)
|
||||
remaining_a_ids = test_client.search_for_document_ids(body=verify_query_x)
|
||||
assert len(remaining_a_ids) == 0
|
||||
|
||||
# Verify tenant-b chunks still exist.
|
||||
verify_query_b = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-b",
|
||||
tenant_state=TenantState(tenant_id="tenant-b", multitenant=True),
|
||||
# Verify tenant-y chunks still exist.
|
||||
verify_query_y = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-tenant-y",
|
||||
tenant_state=tenant_y,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
)
|
||||
remaining_b_ids = test_client.search_for_document_ids(body=verify_query_b)
|
||||
assert len(remaining_b_ids) == 2
|
||||
expected_b_ids = {
|
||||
remaining_y_ids = test_client.search_for_document_ids(body=verify_query_y)
|
||||
assert len(remaining_y_ids) == 2
|
||||
expected_y_ids = {
|
||||
get_opensearch_doc_chunk_id(
|
||||
document_id=chunk.document_id,
|
||||
chunk_index=chunk.chunk_index,
|
||||
max_chunk_size=chunk.max_chunk_size,
|
||||
)
|
||||
for chunk in tenant_b_chunks
|
||||
for chunk in tenant_y_chunks
|
||||
}
|
||||
assert set(remaining_b_ids) == expected_b_ids
|
||||
assert set(remaining_y_ids) == expected_y_ids
|
||||
|
||||
def test_delete_by_query_nonexistent_document(
|
||||
self, test_client: OpenSearchClient
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""
|
||||
Tests delete_by_query for non-existent document returns 0 deleted.
|
||||
"""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -919,26 +962,26 @@ class TestOpenSearchClient:
|
||||
# Don't index any documents.
|
||||
|
||||
# Build deletion query.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
query_body = DocumentQuery.delete_from_document_id_query(
|
||||
document_id="nonexistent-doc",
|
||||
tenant_state=TenantState(tenant_id="", multitenant=False),
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
get_full_document=False,
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# Under test.
|
||||
chunk_ids = test_client.search_for_document_ids(body=query_body)
|
||||
num_deleted = test_client.delete_by_query(query_body=query_body)
|
||||
|
||||
# Postcondition.
|
||||
assert len(chunk_ids) == 0
|
||||
assert num_deleted == 0
|
||||
|
||||
def test_search_for_document_ids(self, test_client: OpenSearchClient) -> None:
|
||||
def test_search_for_document_ids(
|
||||
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Tests search_for_document_ids method returns correct chunk IDs."""
|
||||
# Precondition.
|
||||
_patch_global_tenant_state(monkeypatch, False)
|
||||
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
|
||||
mappings = DocumentSchema.get_document_schema(
|
||||
vector_dimension=128, multitenant=False
|
||||
vector_dimension=128, multitenant=tenant_state.multitenant
|
||||
)
|
||||
settings = DocumentSchema.get_index_settings()
|
||||
test_client.create_index(mappings=mappings, settings=settings)
|
||||
@@ -946,13 +989,19 @@ class TestOpenSearchClient:
|
||||
# Index chunks for two different documents.
|
||||
doc1_chunks = [
|
||||
_create_test_document_chunk(
|
||||
document_id="doc-1", chunk_index=i, content=f"Doc 1 Chunk {i}"
|
||||
document_id="doc-1",
|
||||
chunk_index=i,
|
||||
content=f"Doc 1 Chunk {i}",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
doc2_chunks = [
|
||||
_create_test_document_chunk(
|
||||
document_id="doc-2", chunk_index=i, content=f"Doc 2 Chunk {i}"
|
||||
document_id="doc-2",
|
||||
chunk_index=i,
|
||||
content=f"Doc 2 Chunk {i}",
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
for i in range(2)
|
||||
]
|
||||
@@ -964,7 +1013,7 @@ class TestOpenSearchClient:
|
||||
# Build query for doc-1.
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id="doc-1",
|
||||
tenant_state=TenantState(tenant_id="", multitenant=False),
|
||||
tenant_state=tenant_state,
|
||||
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
|
||||
@@ -41,6 +41,12 @@ API_KEY_RECORDS: Dict[str, Dict[str, Any]] = {
|
||||
},
|
||||
}
|
||||
|
||||
# These are inferrable from the file anyways, no need to obfuscate.
|
||||
# use them to test your auth with this server
|
||||
#
|
||||
# mcp_live-kid_alice_001-S3cr3tAlice
|
||||
# mcp_live-kid_bob_001-S3cr3tBob
|
||||
|
||||
|
||||
# ---- verifier ---------------------------------------------------------------
|
||||
class ApiKeyVerifier(TokenVerifier):
|
||||
|
||||
@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
assert fallback_llm.config.model_name == default_provider.default_model_name
|
||||
|
||||
|
||||
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Test that the /llm/provider endpoint correctly excludes non-public providers
|
||||
with no group/persona restrictions.
|
||||
|
||||
This tests the fix for the bug where non-public providers with no restrictions
|
||||
were incorrectly shown to all users instead of being admin-only.
|
||||
"""
|
||||
admin_user, basic_user = users
|
||||
|
||||
# Create a public provider (should be visible to all)
|
||||
public_provider = LLMProviderManager.create(
|
||||
name="public-provider",
|
||||
is_public=True,
|
||||
set_as_default=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create a non-public provider with no restrictions (should be admin-only)
|
||||
non_public_provider = LLMProviderManager.create(
|
||||
name="non-public-unrestricted",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
personas=[],
|
||||
set_as_default=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Non-admin user calls the /llm/provider endpoint
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=basic_user.headers,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
providers = response.json()
|
||||
provider_names = [p["name"] for p in providers]
|
||||
|
||||
# Public provider should be visible
|
||||
assert public_provider.name in provider_names
|
||||
|
||||
# Non-public provider with no restrictions should NOT be visible to non-admin
|
||||
assert non_public_provider.name not in provider_names
|
||||
|
||||
# Admin user should see both providers
|
||||
admin_response = requests.get(
|
||||
f"{API_SERVER_URL}/llm/provider",
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
admin_providers = admin_response.json()
|
||||
admin_provider_names = [p["name"] for p in admin_providers]
|
||||
|
||||
assert public_provider.name in admin_provider_names
|
||||
assert non_public_provider.name in admin_provider_names
|
||||
|
||||
|
||||
def test_provider_delete_clears_persona_references(reset: None) -> None:
|
||||
"""Test that deleting a provider automatically clears persona references."""
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -61,13 +61,13 @@ def test_cold_startup_default_assistant() -> None:
|
||||
|
||||
# Verify all three main tools are attached
|
||||
assert (
|
||||
"SearchTool" in tool_names
|
||||
"internal_search" in tool_names
|
||||
), "Default assistant should have SearchTool attached"
|
||||
assert (
|
||||
"ImageGenerationTool" in tool_names
|
||||
"generate_image" in tool_names
|
||||
), "Default assistant should have ImageGenerationTool attached"
|
||||
assert (
|
||||
"WebSearchTool" in tool_names
|
||||
"web_search" in tool_names
|
||||
), "Default assistant should have WebSearchTool attached"
|
||||
|
||||
# Also verify by display names for clarity
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
@@ -5,6 +6,53 @@ from tests.integration.common_utils.reset import downgrade_postgres
|
||||
from tests.integration.common_utils.reset import upgrade_postgres
|
||||
|
||||
|
||||
class ToolSeedingExpectedResult(BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
in_code_tool_id: str
|
||||
user_id: str | None
|
||||
|
||||
|
||||
EXPECTED_TOOLS = {
|
||||
"SearchTool": ToolSeedingExpectedResult(
|
||||
name="internal_search",
|
||||
display_name="Internal Search",
|
||||
in_code_tool_id="SearchTool",
|
||||
user_id=None,
|
||||
),
|
||||
"ImageGenerationTool": ToolSeedingExpectedResult(
|
||||
name="generate_image",
|
||||
display_name="Image Generation",
|
||||
in_code_tool_id="ImageGenerationTool",
|
||||
user_id=None,
|
||||
),
|
||||
"WebSearchTool": ToolSeedingExpectedResult(
|
||||
name="web_search",
|
||||
display_name="Web Search",
|
||||
in_code_tool_id="WebSearchTool",
|
||||
user_id=None,
|
||||
),
|
||||
"KnowledgeGraphTool": ToolSeedingExpectedResult(
|
||||
name="run_kg_search",
|
||||
display_name="Knowledge Graph Search",
|
||||
in_code_tool_id="KnowledgeGraphTool",
|
||||
user_id=None,
|
||||
),
|
||||
"PythonTool": ToolSeedingExpectedResult(
|
||||
name="python",
|
||||
display_name="Code Interpreter",
|
||||
in_code_tool_id="PythonTool",
|
||||
user_id=None,
|
||||
),
|
||||
"ResearchAgent": ToolSeedingExpectedResult(
|
||||
name="research_agent",
|
||||
display_name="Research Agent",
|
||||
in_code_tool_id="ResearchAgent",
|
||||
user_id=None,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def test_tool_seeding_migration() -> None:
|
||||
"""Test that migration from base to head correctly seeds builtin tools."""
|
||||
# Start from base and upgrade to just before tool seeding
|
||||
@@ -49,56 +97,33 @@ def test_tool_seeding_migration() -> None:
|
||||
len(tools) == 8
|
||||
), f"Should have created exactly 8 builtin tools, got {len(tools)}"
|
||||
|
||||
def validate_tool(expected: ToolSeedingExpectedResult) -> None:
|
||||
tool = next((t for t in tools if t[1] == expected.name), None)
|
||||
assert tool is not None, f"{expected.name} should exist"
|
||||
assert (
|
||||
tool[2] == expected.display_name
|
||||
), f"{expected.name} display name should be '{expected.display_name}'"
|
||||
assert (
|
||||
tool[4] == expected.in_code_tool_id
|
||||
), f"{expected.name} in_code_tool_id should be '{expected.in_code_tool_id}'"
|
||||
assert (
|
||||
tool[5] is None
|
||||
), f"{expected.name} should not have a user_id (builtin)"
|
||||
|
||||
# Check SearchTool
|
||||
search_tool = next((t for t in tools if t[1] == "SearchTool"), None)
|
||||
assert search_tool is not None, "SearchTool should exist"
|
||||
assert (
|
||||
search_tool[2] == "Internal Search"
|
||||
), "SearchTool display name should be 'Internal Search'"
|
||||
assert search_tool[5] is None, "SearchTool should not have a user_id (builtin)"
|
||||
validate_tool(EXPECTED_TOOLS["SearchTool"])
|
||||
|
||||
# Check ImageGenerationTool
|
||||
img_tool = next((t for t in tools if t[1] == "ImageGenerationTool"), None)
|
||||
assert img_tool is not None, "ImageGenerationTool should exist"
|
||||
assert (
|
||||
img_tool[2] == "Image Generation"
|
||||
), "ImageGenerationTool display name should be 'Image Generation'"
|
||||
assert (
|
||||
img_tool[5] is None
|
||||
), "ImageGenerationTool should not have a user_id (builtin)"
|
||||
validate_tool(EXPECTED_TOOLS["ImageGenerationTool"])
|
||||
|
||||
# Check WebSearchTool
|
||||
web_tool = next((t for t in tools if t[1] == "WebSearchTool"), None)
|
||||
assert web_tool is not None, "WebSearchTool should exist"
|
||||
assert (
|
||||
web_tool[2] == "Web Search"
|
||||
), "WebSearchTool display name should be 'Web Search'"
|
||||
assert web_tool[5] is None, "WebSearchTool should not have a user_id (builtin)"
|
||||
validate_tool(EXPECTED_TOOLS["WebSearchTool"])
|
||||
|
||||
# Check KnowledgeGraphTool
|
||||
kg_tool = next((t for t in tools if t[1] == "KnowledgeGraphTool"), None)
|
||||
assert kg_tool is not None, "KnowledgeGraphTool should exist"
|
||||
assert (
|
||||
kg_tool[2] == "Knowledge Graph Search"
|
||||
), "KnowledgeGraphTool display name should be 'Knowledge Graph Search'"
|
||||
assert (
|
||||
kg_tool[5] is None
|
||||
), "KnowledgeGraphTool should not have a user_id (builtin)"
|
||||
validate_tool(EXPECTED_TOOLS["KnowledgeGraphTool"])
|
||||
|
||||
# Check PythonTool
|
||||
python_tool = next((t for t in tools if t[1] == "PythonTool"), None)
|
||||
assert python_tool is not None, "PythonTool should exist"
|
||||
assert (
|
||||
python_tool[2] == "Code Interpreter"
|
||||
), "PythonTool display name should be 'Code Interpreter'"
|
||||
assert python_tool[5] is None, "PythonTool should not have a user_id (builtin)"
|
||||
validate_tool(EXPECTED_TOOLS["PythonTool"])
|
||||
|
||||
# Check ResearchAgent (Deep Research as a tool)
|
||||
research_agent = next((t for t in tools if t[1] == "ResearchAgent"), None)
|
||||
assert research_agent is not None, "ResearchAgent should exist"
|
||||
assert (
|
||||
research_agent[2] == "Research Agent"
|
||||
), "ResearchAgent display name should be 'Research Agent'"
|
||||
assert (
|
||||
research_agent[5] is None
|
||||
), "ResearchAgent should not have a user_id (builtin)"
|
||||
validate_tool(EXPECTED_TOOLS["ResearchAgent"])
|
||||
|
||||
@@ -38,11 +38,11 @@ def test_unified_assistant(reset: None, admin_user: DATestUser) -> None:
|
||||
# Verify tools
|
||||
tools = unified_assistant.tools
|
||||
tool_names = [tool.name for tool in tools]
|
||||
assert "SearchTool" in tool_names, "SearchTool not found in unified assistant"
|
||||
assert "internal_search" in tool_names, "SearchTool not found in unified assistant"
|
||||
assert (
|
||||
"ImageGenerationTool" in tool_names
|
||||
"generate_image" in tool_names
|
||||
), "ImageGenerationTool not found in unified assistant"
|
||||
assert "WebSearchTool" in tool_names, "WebSearchTool not found in unified assistant"
|
||||
assert "web_search" in tool_names, "WebSearchTool not found in unified assistant"
|
||||
|
||||
# Verify no starter messages
|
||||
starter_messages = unified_assistant.starter_messages or []
|
||||
|
||||
@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
|
||||
provider_id = _activate_exa_provider(admin_user)
|
||||
assert isinstance(provider_id, int)
|
||||
|
||||
search_request = {"queries": ["latest ai research news"], "max_results": 3}
|
||||
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
|
||||
|
||||
lite_response = requests.post(
|
||||
f"{API_SERVER_URL}/web-search/search-lite",
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# This file exposes service ports for development and testing purposes
|
||||
#
|
||||
# Usage:
|
||||
# docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
# docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
|
||||
#
|
||||
# Or set COMPOSE_FILE environment variable:
|
||||
# export COMPOSE_FILE=docker-compose.yml:docker-compose.dev.yml
|
||||
# docker compose up -d
|
||||
# docker compose up -d --wait
|
||||
|
||||
services:
|
||||
api_server:
|
||||
|
||||
@@ -58,7 +58,7 @@ services:
|
||||
- minio
|
||||
restart: unless-stopped
|
||||
# DEV: To expose ports, either:
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
|
||||
# 2. Uncomment the ports below
|
||||
# ports:
|
||||
# - "8080:8080"
|
||||
@@ -83,7 +83,13 @@ services:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"]
|
||||
test:
|
||||
[
|
||||
"CMD",
|
||||
"python",
|
||||
"-c",
|
||||
"import urllib.request; urllib.request.urlopen('http://localhost:8080/health')",
|
||||
]
|
||||
interval: 30s
|
||||
timeout: 20s
|
||||
retries: 3
|
||||
@@ -299,7 +305,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
# DEV: To expose ports, either:
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
|
||||
# 2. Uncomment the ports below
|
||||
# ports:
|
||||
# - "5432:5432"
|
||||
@@ -321,7 +327,7 @@ services:
|
||||
environment:
|
||||
- VESPA_SKIP_UPGRADE_CHECK=${VESPA_SKIP_UPGRADE_CHECK:-true}
|
||||
# DEV: To expose ports, either:
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
|
||||
# 2. Uncomment the ports below
|
||||
# ports:
|
||||
# - "19071:19071"
|
||||
@@ -378,7 +384,7 @@ services:
|
||||
image: redis:7.4-alpine
|
||||
restart: unless-stopped
|
||||
# DEV: To expose ports, either:
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
|
||||
# 2. Uncomment the ports below
|
||||
# ports:
|
||||
# - "6379:6379"
|
||||
@@ -396,7 +402,7 @@ services:
|
||||
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
|
||||
restart: unless-stopped
|
||||
# DEV: To expose ports, either:
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
|
||||
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
|
||||
# 2. Uncomment the ports below
|
||||
# ports:
|
||||
# - "9004:9000"
|
||||
|
||||
@@ -21,9 +21,9 @@ use tauri::{
|
||||
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
|
||||
};
|
||||
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use tokio::time::sleep;
|
||||
use url::Url;
|
||||
#[cfg(target_os = "macos")]
|
||||
use window_vibrancy::{apply_vibrancy, NSVisualEffectMaterial};
|
||||
|
||||
@@ -76,39 +76,25 @@ fn get_config_path() -> Option<PathBuf> {
|
||||
}
|
||||
|
||||
/// Load config from file, or create default if it doesn't exist
|
||||
fn load_config() -> AppConfig {
|
||||
fn load_config() -> (AppConfig, bool) {
|
||||
let config_path = match get_config_path() {
|
||||
Some(path) => path,
|
||||
None => {
|
||||
eprintln!("Could not determine config directory, using defaults");
|
||||
return AppConfig::default();
|
||||
return (AppConfig::default(), false);
|
||||
}
|
||||
};
|
||||
|
||||
if config_path.exists() {
|
||||
match fs::read_to_string(&config_path) {
|
||||
Ok(contents) => match serde_json::from_str(&contents) {
|
||||
Ok(config) => {
|
||||
return config;
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to parse config: {}, using defaults", e);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
eprintln!("Failed to read config: {}, using defaults", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Create default config file
|
||||
if let Err(e) = save_config(&AppConfig::default()) {
|
||||
eprintln!("Failed to create default config: {}", e);
|
||||
} else {
|
||||
println!("Created default config at {:?}", config_path);
|
||||
}
|
||||
if !config_path.exists() {
|
||||
return (AppConfig::default(), false);
|
||||
}
|
||||
|
||||
AppConfig::default()
|
||||
match fs::read_to_string(&config_path) {
|
||||
Ok(contents) => match serde_json::from_str(&contents) {
|
||||
Ok(config) => (config, true),
|
||||
Err(_) => (AppConfig::default(), false),
|
||||
},
|
||||
Err(_) => (AppConfig::default(), false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Save config to file
|
||||
@@ -128,7 +114,11 @@ fn save_config(config: &AppConfig) -> Result<(), String> {
|
||||
}
|
||||
|
||||
// Global config state
|
||||
struct ConfigState(RwLock<AppConfig>);
|
||||
struct ConfigState {
|
||||
config: RwLock<AppConfig>,
|
||||
config_initialized: RwLock<bool>,
|
||||
app_base_url: RwLock<Option<Url>>,
|
||||
}
|
||||
|
||||
fn focus_main_window(app: &AppHandle) {
|
||||
if let Some(window) = app.get_webview_window("main") {
|
||||
@@ -142,7 +132,7 @@ fn focus_main_window(app: &AppHandle) {
|
||||
|
||||
fn trigger_new_chat(app: &AppHandle) {
|
||||
let state = app.state::<ConfigState>();
|
||||
let server_url = state.0.read().unwrap().server_url.clone();
|
||||
let server_url = state.config.read().unwrap().server_url.clone();
|
||||
|
||||
if let Some(window) = app.get_webview_window("main") {
|
||||
let url = format!("{}/chat", server_url);
|
||||
@@ -152,7 +142,7 @@ fn trigger_new_chat(app: &AppHandle) {
|
||||
|
||||
fn trigger_new_window(app: &AppHandle) {
|
||||
let state = app.state::<ConfigState>();
|
||||
let server_url = state.0.read().unwrap().server_url.clone();
|
||||
let server_url = state.config.read().unwrap().server_url.clone();
|
||||
let handle = app.clone();
|
||||
|
||||
tauri::async_runtime::spawn(async move {
|
||||
@@ -206,6 +196,30 @@ fn open_docs() {
|
||||
}
|
||||
}
|
||||
|
||||
fn open_settings(app: &AppHandle) {
|
||||
// Navigate main window to the settings page (index.html) with settings flag
|
||||
let state = app.state::<ConfigState>();
|
||||
let settings_url = state
|
||||
.app_base_url
|
||||
.read()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.and_then(|mut url| {
|
||||
url.set_query(None);
|
||||
url.set_fragment(Some("settings"));
|
||||
url.set_path("/");
|
||||
Some(url)
|
||||
})
|
||||
.or_else(|| Url::parse("tauri://localhost/#settings").ok());
|
||||
|
||||
if let Some(window) = app.get_webview_window("main") {
|
||||
if let Some(url) = settings_url {
|
||||
let _ = window.navigate(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tauri Commands
|
||||
// ============================================================================
|
||||
@@ -213,7 +227,27 @@ fn open_docs() {
|
||||
/// Get the current server URL
|
||||
#[tauri::command]
|
||||
fn get_server_url(state: tauri::State<ConfigState>) -> String {
|
||||
state.0.read().unwrap().server_url.clone()
|
||||
state.config.read().unwrap().server_url.clone()
|
||||
}
|
||||
|
||||
/// Value returned by the `get_bootstrap_state` command.
#[derive(Serialize)]
|
||||
struct BootstrapState {
|
||||
/// Clone of the `server_url` held in the current config.
server_url: String,
|
||||
/// `true` only when the config is marked initialized and the config path exists on disk.
config_exists: bool,
|
||||
}
|
||||
|
||||
/// Get the server URL plus whether a config file exists
|
||||
#[tauri::command]
|
||||
fn get_bootstrap_state(state: tauri::State<ConfigState>) -> BootstrapState {
|
||||
let server_url = state.config.read().unwrap().server_url.clone();
|
||||
let config_initialized = *state.config_initialized.read().unwrap();
|
||||
let config_exists = config_initialized
|
||||
&& get_config_path().map(|path| path.exists()).unwrap_or(false);
|
||||
|
||||
BootstrapState {
|
||||
server_url,
|
||||
config_exists,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a new server URL and save to config
|
||||
@@ -224,9 +258,10 @@ fn set_server_url(state: tauri::State<ConfigState>, url: String) -> Result<Strin
|
||||
return Err("URL must start with http:// or https://".to_string());
|
||||
}
|
||||
|
||||
let mut config = state.0.write().unwrap();
|
||||
let mut config = state.config.write().unwrap();
|
||||
config.server_url = url.trim_end_matches('/').to_string();
|
||||
save_config(&config)?;
|
||||
*state.config_initialized.write().unwrap() = true;
|
||||
|
||||
Ok(config.server_url.clone())
|
||||
}
|
||||
@@ -315,7 +350,7 @@ fn open_config_directory() -> Result<(), String> {
|
||||
/// Navigate to a specific path on the configured server
|
||||
#[tauri::command]
|
||||
fn navigate_to(window: tauri::WebviewWindow, state: tauri::State<ConfigState>, path: &str) {
|
||||
let base_url = state.0.read().unwrap().server_url.clone();
|
||||
let base_url = state.config.read().unwrap().server_url.clone();
|
||||
let url = format!("{}{}", base_url, path);
|
||||
let _ = window.eval(&format!("window.location.href = '{}'", url));
|
||||
}
|
||||
@@ -341,7 +376,7 @@ fn go_forward(window: tauri::WebviewWindow) {
|
||||
/// Open a new window
|
||||
#[tauri::command]
|
||||
async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Result<(), String> {
|
||||
let server_url = state.0.read().unwrap().server_url.clone();
|
||||
let server_url = state.config.read().unwrap().server_url.clone();
|
||||
let window_label = format!("onyx-{}", uuid::Uuid::new_v4());
|
||||
|
||||
let builder = WebviewWindowBuilder::new(
|
||||
@@ -385,9 +420,10 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
|
||||
/// Reset config to defaults
|
||||
#[tauri::command]
|
||||
fn reset_config(state: tauri::State<ConfigState>) -> Result<(), String> {
|
||||
let mut config = state.0.write().unwrap();
|
||||
let mut config = state.config.write().unwrap();
|
||||
*config = AppConfig::default();
|
||||
save_config(&config)?;
|
||||
*state.config_initialized.write().unwrap() = true;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -423,7 +459,7 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
|
||||
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
|
||||
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
|
||||
let open_settings = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
|
||||
|
||||
let app_handle = app.clone();
|
||||
|
||||
@@ -435,7 +471,7 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
@@ -446,7 +482,7 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
forward,
|
||||
new_window_shortcut,
|
||||
show_app,
|
||||
open_settings,
|
||||
open_settings_shortcut,
|
||||
];
|
||||
|
||||
app.global_shortcut().on_shortcuts(
|
||||
@@ -463,9 +499,8 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = window.eval("window.history.back()");
|
||||
} else if shortcut == &forward {
|
||||
let _ = window.eval("window.history.forward()");
|
||||
} else if shortcut == &open_settings {
|
||||
// Open config file for editing
|
||||
let _ = open_config_file();
|
||||
} else if shortcut == &open_settings_shortcut {
|
||||
open_settings(&app_handle);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,6 +530,7 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
true,
|
||||
Some("CmdOrCtrl+Shift+N"),
|
||||
)?;
|
||||
let settings_item = MenuItem::with_id(app, "open_settings", "Settings...", true, Some("CmdOrCtrl+Comma"))?;
|
||||
let docs_item = MenuItem::with_id(app, "open_docs", "Onyx Documentation", true, None::<&str>)?;
|
||||
|
||||
if let Some(file_menu) = menu
|
||||
@@ -503,12 +539,13 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
.filter_map(|item| item.as_submenu().cloned())
|
||||
.find(|submenu| submenu.text().ok().as_deref() == Some("File"))
|
||||
{
|
||||
file_menu.insert_items(&[&new_chat_item, &new_window_item], 0)?;
|
||||
file_menu.insert_items(&[&new_chat_item, &new_window_item, &settings_item], 0)?;
|
||||
} else {
|
||||
let file_menu = SubmenuBuilder::new(app, "File")
|
||||
.items(&[
|
||||
&new_chat_item,
|
||||
&new_window_item,
|
||||
&settings_item,
|
||||
&PredefinedMenuItem::close_window(app, None)?,
|
||||
])
|
||||
.build()?;
|
||||
@@ -625,22 +662,20 @@ fn setup_tray_icon(app: &AppHandle) -> tauri::Result<()> {
|
||||
|
||||
fn main() {
|
||||
// Load config at startup
|
||||
let config = load_config();
|
||||
let server_url = config.server_url.clone();
|
||||
|
||||
println!("Starting Onyx Desktop");
|
||||
println!("Server URL: {}", server_url);
|
||||
if let Some(path) = get_config_path() {
|
||||
println!("Config file: {:?}", path);
|
||||
}
|
||||
let (config, config_initialized) = load_config();
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_shell::init())
|
||||
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
|
||||
.plugin(tauri_plugin_window_state::Builder::default().build())
|
||||
.manage(ConfigState(RwLock::new(config)))
|
||||
.manage(ConfigState {
|
||||
config: RwLock::new(config),
|
||||
config_initialized: RwLock::new(config_initialized),
|
||||
app_base_url: RwLock::new(None),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
get_server_url,
|
||||
get_bootstrap_state,
|
||||
set_server_url,
|
||||
get_config_path_cmd,
|
||||
open_config_file,
|
||||
@@ -657,6 +692,7 @@ fn main() {
|
||||
"open_docs" => open_docs(),
|
||||
"new_chat" => trigger_new_chat(app),
|
||||
"new_window" => trigger_new_window(app),
|
||||
"open_settings" => open_settings(app),
|
||||
_ => {}
|
||||
})
|
||||
.setup(move |app| {
|
||||
@@ -675,7 +711,7 @@ fn main() {
|
||||
eprintln!("Failed to setup tray icon: {}", e);
|
||||
}
|
||||
|
||||
// Update main window URL to configured server and inject title bar
|
||||
// Setup main window with vibrancy effect
|
||||
if let Some(window) = app.get_webview_window("main") {
|
||||
// Apply vibrancy effect for translucent glass look
|
||||
#[cfg(target_os = "macos")]
|
||||
@@ -683,14 +719,12 @@ fn main() {
|
||||
let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None);
|
||||
}
|
||||
|
||||
if let Ok(target) = Url::parse(&server_url) {
|
||||
if let Ok(current) = window.url() {
|
||||
if current != target {
|
||||
let _ = window.navigate(target);
|
||||
}
|
||||
} else {
|
||||
let _ = window.navigate(target);
|
||||
}
|
||||
if let Ok(url) = window.url() {
|
||||
let mut base_url = url;
|
||||
base_url.set_query(None);
|
||||
base_url.set_fragment(None);
|
||||
base_url.set_path("/");
|
||||
*app.state::<ConfigState>().app_base_url.write().unwrap() = Some(base_url);
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
{
|
||||
"title": "Onyx",
|
||||
"label": "main",
|
||||
"url": "https://cloud.onyx.app",
|
||||
"url": "index.html",
|
||||
"width": 1200,
|
||||
"height": 800,
|
||||
"minWidth": 800,
|
||||
@@ -52,7 +52,7 @@
|
||||
"entitlements": null,
|
||||
"exceptionDomain": "cloud.onyx.app",
|
||||
"minimumSystemVersion": "10.15",
|
||||
"signingIdentity": "-",
|
||||
"signingIdentity": null,
|
||||
"dmg": {
|
||||
"windowSize": {
|
||||
"width": 660,
|
||||
|
||||
@@ -4,28 +4,43 @@
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Onyx</title>
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Hanken+Grotesk:wght@400;500;600;700&display=swap"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
<style>
|
||||
:root {
|
||||
--background-900: #f5f5f5;
|
||||
--background-800: #ffffff;
|
||||
--text-light-05: rgba(0, 0, 0, 0.95);
|
||||
--text-light-03: rgba(0, 0, 0, 0.6);
|
||||
--white-10: rgba(0, 0, 0, 0.1);
|
||||
--white-15: rgba(0, 0, 0, 0.15);
|
||||
--white-20: rgba(0, 0, 0, 0.2);
|
||||
--white-30: rgba(0, 0, 0, 0.3);
|
||||
--font-hanken-grotesk: "Hanken Grotesk", -apple-system,
|
||||
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
|
||||
Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
|
||||
color: #fff;
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
|
||||
min-height: 100vh;
|
||||
color: var(--text-light-05);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 20px;
|
||||
-webkit-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
/* Draggable titlebar area for macOS */
|
||||
.titlebar {
|
||||
position: fixed;
|
||||
top: 0;
|
||||
@@ -33,198 +48,451 @@
|
||||
right: 0;
|
||||
height: 28px;
|
||||
-webkit-app-region: drag;
|
||||
z-index: 10000;
|
||||
}
|
||||
|
||||
.container {
|
||||
text-align: center;
|
||||
padding: 2rem;
|
||||
.settings-container {
|
||||
max-width: 500px;
|
||||
width: 100%;
|
||||
opacity: 0;
|
||||
transform: translateY(8px);
|
||||
pointer-events: none;
|
||||
transition:
|
||||
opacity 0.18s ease,
|
||||
transform 0.18s ease;
|
||||
}
|
||||
|
||||
.logo {
|
||||
width: 80px;
|
||||
height: 80px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border-radius: 20px;
|
||||
margin: 0 auto 1.5rem;
|
||||
body.show-settings .settings-container {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
pointer-events: auto;
|
||||
}
|
||||
|
||||
.settings-panel {
|
||||
background: linear-gradient(
|
||||
to bottom,
|
||||
rgba(255, 255, 255, 0.95),
|
||||
rgba(245, 245, 245, 0.95)
|
||||
);
|
||||
backdrop-filter: blur(24px);
|
||||
border-radius: 16px;
|
||||
border: 1px solid var(--white-10);
|
||||
overflow: hidden;
|
||||
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.settings-header {
|
||||
padding: 24px;
|
||||
border-bottom: 1px solid var(--white-10);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.settings-icon {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border-radius: 12px;
|
||||
background: white;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 2.5rem;
|
||||
font-weight: bold;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 2rem;
|
||||
margin-bottom: 0.5rem;
|
||||
.settings-icon svg {
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.settings-title {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: var(--text-light-05);
|
||||
}
|
||||
|
||||
p {
|
||||
color: #a0a0a0;
|
||||
margin-bottom: 2rem;
|
||||
.settings-content {
|
||||
padding: 24px;
|
||||
}
|
||||
|
||||
.loading {
|
||||
.settings-section {
|
||||
margin-bottom: 32px;
|
||||
}
|
||||
|
||||
.settings-section:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-size: 11px;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: var(--text-light-03);
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.settings-group {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border-radius: 16px;
|
||||
padding: 4px;
|
||||
}
|
||||
|
||||
.setting-row {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
justify-content: center;
|
||||
margin-bottom: 2rem;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 12px;
|
||||
}
|
||||
|
||||
.loading span {
|
||||
width: 10px;
|
||||
height: 10px;
|
||||
background: #667eea;
|
||||
border-radius: 50%;
|
||||
animation: bounce 1.4s ease-in-out infinite;
|
||||
.setting-row-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.loading span:nth-child(1) {
|
||||
animation-delay: 0s;
|
||||
}
|
||||
.loading span:nth-child(2) {
|
||||
animation-delay: 0.2s;
|
||||
}
|
||||
.loading span:nth-child(3) {
|
||||
animation-delay: 0.4s;
|
||||
.setting-label {
|
||||
font-size: 14px;
|
||||
font-weight: 400;
|
||||
color: var(--text-light-05);
|
||||
}
|
||||
|
||||
@keyframes bounce {
|
||||
0%,
|
||||
80%,
|
||||
100% {
|
||||
transform: scale(0.8);
|
||||
opacity: 0.5;
|
||||
}
|
||||
40% {
|
||||
transform: scale(1.2);
|
||||
opacity: 1;
|
||||
}
|
||||
.setting-description {
|
||||
font-size: 12px;
|
||||
color: var(--text-light-03);
|
||||
}
|
||||
|
||||
.btn {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 0.75rem 2rem;
|
||||
.setting-divider {
|
||||
height: 1px;
|
||||
background: var(--white-10);
|
||||
margin: 0 4px;
|
||||
}
|
||||
|
||||
.input-field {
|
||||
width: 100%;
|
||||
padding: 10px 12px;
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 8px;
|
||||
font-size: 1rem;
|
||||
cursor: pointer;
|
||||
transition:
|
||||
transform 0.2s,
|
||||
box-shadow 0.2s;
|
||||
font-size: 14px;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
color: var(--text-light-05);
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
transition: all 0.2s;
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.btn:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 20px rgba(102, 126, 234, 0.4);
|
||||
.input-field:focus {
|
||||
outline: none;
|
||||
border-color: var(--white-30);
|
||||
background: rgba(0, 0, 0, 0.08);
|
||||
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.shortcuts {
|
||||
margin-top: 3rem;
|
||||
padding: 1.5rem;
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
border-radius: 12px;
|
||||
text-align: left;
|
||||
.input-field::placeholder {
|
||||
color: var(--text-light-03);
|
||||
}
|
||||
|
||||
.shortcuts h3 {
|
||||
font-size: 0.875rem;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
color: #a0a0a0;
|
||||
margin-bottom: 1rem;
|
||||
.input-field.error {
|
||||
border-color: #ef4444;
|
||||
}
|
||||
|
||||
.shortcut {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
padding: 0.5rem 0;
|
||||
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
|
||||
.error-message {
|
||||
color: #ef4444;
|
||||
font-size: 12px;
|
||||
margin-top: 4px;
|
||||
padding-left: 12px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.shortcut:last-child {
|
||||
border-bottom: none;
|
||||
.error-message.visible {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.shortcut-key {
|
||||
font-family:
|
||||
SF Mono,
|
||||
Monaco,
|
||||
monospace;
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
padding: 0.25rem 0.5rem;
|
||||
.toggle-switch {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 44px;
|
||||
height: 24px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.toggle-switch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
}
|
||||
|
||||
.toggle-slider {
|
||||
position: absolute;
|
||||
cursor: pointer;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: rgba(0, 0, 0, 0.15);
|
||||
transition: 0.3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
|
||||
.toggle-slider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
transition: 0.3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
transform: translateX(20px);
|
||||
}
|
||||
|
||||
.button {
|
||||
padding: 12px 24px;
|
||||
border-radius: 8px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
transition: all 0.2s;
|
||||
font-family: var(--font-hanken-grotesk);
|
||||
width: 100%;
|
||||
margin-top: 24px;
|
||||
-webkit-app-region: no-drag;
|
||||
}
|
||||
|
||||
.button.primary {
|
||||
background: #286df8;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.button.primary:hover {
|
||||
background: #1e5cd6;
|
||||
box-shadow: 0 4px 12px rgba(40, 109, 248, 0.3);
|
||||
}
|
||||
|
||||
.button.primary:disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
box-shadow: none;
|
||||
}
|
||||
|
||||
kbd {
|
||||
background: rgba(0, 0, 0, 0.1);
|
||||
border: 1px solid var(--white-10);
|
||||
border-radius: 4px;
|
||||
font-size: 0.75rem;
|
||||
padding: 2px 6px;
|
||||
font-family: monospace;
|
||||
font-weight: 500;
|
||||
color: var(--text-light-05);
|
||||
font-size: 11px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="titlebar"></div>
|
||||
|
||||
<div class="container">
|
||||
<div class="logo">O</div>
|
||||
<h1>Onyx</h1>
|
||||
<p>Connecting to Onyx Cloud...</p>
|
||||
<div class="settings-container">
|
||||
<div class="settings-panel">
|
||||
<div class="settings-header">
|
||||
<div class="settings-icon">
|
||||
<svg
|
||||
viewBox="0 0 56 56"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
clip-rule="evenodd"
|
||||
d="M28 0 10.869 7.77 28 15.539l17.131-7.77L28 0Zm0 40.461-17.131 7.77L28 56l17.131-7.77L28 40.461Zm20.231-29.592L56 28.001l-7.769 17.131L40.462 28l7.769-17.131ZM15.538 28 7.77 10.869 0 28l7.769 17.131L15.538 28Z"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
<h1 class="settings-title">Settings</h1>
|
||||
</div>
|
||||
|
||||
<div class="loading">
|
||||
<span></span>
|
||||
<span></span>
|
||||
<span></span>
|
||||
</div>
|
||||
<div class="settings-content">
|
||||
<section class="settings-section">
|
||||
<div class="section-title">GENERAL</div>
|
||||
<div class="settings-group">
|
||||
<div class="setting-row">
|
||||
<div class="setting-row-content">
|
||||
<label class="setting-label" for="onyxDomain"
|
||||
>Root Domain</label
|
||||
>
|
||||
<div class="setting-description">
|
||||
The root URL for your Onyx instance
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-divider"></div>
|
||||
<div class="setting-row" style="padding: 12px">
|
||||
<input
|
||||
type="text"
|
||||
id="onyxDomain"
|
||||
class="input-field"
|
||||
placeholder="https://cloud.onyx.app"
|
||||
autocomplete="off"
|
||||
autocorrect="off"
|
||||
autocapitalize="off"
|
||||
spellcheck="false"
|
||||
/>
|
||||
</div>
|
||||
<div class="error-message" id="errorMessage">
|
||||
Please enter a valid URL starting with http:// or https://
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<button
|
||||
class="btn"
|
||||
onclick="window.location.href='https://cloud.onyx.app'"
|
||||
>
|
||||
Open Onyx Cloud
|
||||
</button>
|
||||
|
||||
<p style="margin-top: 1.5rem; font-size: 0.875rem; color: #666">
|
||||
Self-hosted? Press
|
||||
<span
|
||||
class="shortcut-key"
|
||||
style="display: inline; padding: 0.15rem 0.4rem"
|
||||
>⌘ ,</span
|
||||
>
|
||||
to configure your server URL.
|
||||
</p>
|
||||
|
||||
<div class="shortcuts">
|
||||
<h3>Keyboard Shortcuts</h3>
|
||||
<div class="shortcut">
|
||||
<span>New Chat</span>
|
||||
<span class="shortcut-key">⌘ N</span>
|
||||
</div>
|
||||
<div class="shortcut">
|
||||
<span>New Window</span>
|
||||
<span class="shortcut-key">⌘ ⇧ N</span>
|
||||
</div>
|
||||
<div class="shortcut">
|
||||
<span>Reload</span>
|
||||
<span class="shortcut-key">⌘ R</span>
|
||||
</div>
|
||||
<div class="shortcut">
|
||||
<span>Back</span>
|
||||
<span class="shortcut-key">⌘ [</span>
|
||||
</div>
|
||||
<div class="shortcut">
|
||||
<span>Forward</span>
|
||||
<span class="shortcut-key">⌘ ]</span>
|
||||
</div>
|
||||
<div class="shortcut">
|
||||
<span>Settings / Config</span>
|
||||
<span class="shortcut-key">⌘ ,</span>
|
||||
<button class="button primary" id="saveBtn">Save & Connect</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
// Auto-redirect to Onyx Cloud after a short delay
|
||||
setTimeout(() => {
|
||||
window.location.href = "https://cloud.onyx.app";
|
||||
}, 1500);
|
||||
// Import Tauri API
|
||||
const { invoke } = window.__TAURI__.core;
|
||||
|
||||
// Configuration
|
||||
const DEFAULT_DOMAIN = "https://cloud.onyx.app";
|
||||
let currentServerUrl = "";
|
||||
|
||||
// DOM elements
|
||||
const domainInput = document.getElementById("onyxDomain");
|
||||
const errorMessage = document.getElementById("errorMessage");
|
||||
const saveBtn = document.getElementById("saveBtn");
|
||||
|
||||
// Reveal the settings panel by tagging <body> with the "show-settings" class.
function showSettings() {
  const { classList } = document.body;
  classList.add("show-settings");
}
|
||||
|
||||
// Bootstrap the page: fetch state from the backend, then either reveal the
// settings form or send the webview to the configured server.
async function init() {
  try {
    const { server_url: serverUrl, config_exists: configExists } =
      await invoke("get_bootstrap_state");
    currentServerUrl = serverUrl;

    // Pre-fill the input, falling back to the cloud default when nothing is stored.
    domainInput.value = serverUrl || DEFAULT_DOMAIN;

    // Settings were requested explicitly via "#settings" or "?settings=true".
    const query = new URLSearchParams(window.location.search);
    const requestedSettings =
      window.location.hash === "#settings" ||
      query.get("settings") === "true";

    // Without a config file or a server URL the user must configure one first.
    const needsSetup = !configExists || !serverUrl;

    if (requestedSettings || needsSetup) {
      showSettings();
      return;
    }

    // Configured and settings not requested: go straight to the server.
    window.location.href = serverUrl;
  } catch (error) {
    // Any failure above: offer the cloud default in the settings form.
    domainInput.value = DEFAULT_DOMAIN;
    showSettings();
  }
}
|
||||
|
||||
// Validate a user-entered server URL.
// Returns { valid: true, url } where `url` is the trimmed input with its
// scheme lower-cased, or { valid: false, error } with a message for the user.
function validateUrl(url) {
  const trimmedUrl = url.trim();
  if (!trimmedUrl) {
    return { valid: false, error: "URL cannot be empty" };
  }

  // URL schemes are case-insensitive, so "HTTPS://host" is accepted too.
  const scheme = /^https?:\/\//i.exec(trimmedUrl);
  if (!scheme) {
    return {
      valid: false,
      error: "URL must start with http:// or https://",
    };
  }

  // Lower-case only the scheme so the value still starts with a literal
  // "http://" or "https://". NOTE(review): the backend prefix check is
  // presumably case-sensitive; confirm in set_server_url.
  const normalizedUrl =
    scheme[0].toLowerCase() + trimmedUrl.slice(scheme[0].length);

  try {
    // The constructor throws on anything that is not a parseable URL.
    new URL(normalizedUrl);
    return { valid: true, url: normalizedUrl };
  } catch {
    return { valid: false, error: "Please enter a valid URL" };
  }
}
|
||||
|
||||
// Display `message` under the domain input and flag the input as invalid.
function showError(message) {
  errorMessage.textContent = message;
  errorMessage.classList.toggle("visible", true);
  domainInput.classList.toggle("error", true);
}
|
||||
|
||||
// Hide the validation message and drop the invalid marker from the input.
function clearError() {
  errorMessage.classList.toggle("visible", false);
  domainInput.classList.toggle("error", false);
}
|
||||
|
||||
// Validate the entered URL, persist it through the backend's
// "set_server_url" command, then load that URL in this window.
async function saveConfiguration() {
  // Enter can be pressed again while a save is still in flight; the disabled
  // button doubles as the "busy" flag so only one request is sent.
  if (saveBtn.disabled) {
    return;
  }

  clearError();

  const validation = validateUrl(domainInput.value);
  if (!validation.valid) {
    showError(validation.error);
    return;
  }

  saveBtn.disabled = true;
  saveBtn.textContent = "Saving...";

  try {
    await invoke("set_server_url", { url: validation.url });

    // Saved: leave the settings page for the configured server.
    window.location.href = validation.url;
  } catch (error) {
    // The rejection value may be a plain string or an Error-like object;
    // always show readable text.
    const message =
      (typeof error === "string" && error) ||
      (error && error.message) ||
      "Failed to save configuration";
    showError(message);
    saveBtn.disabled = false;
    saveBtn.textContent = "Save & Connect";
  }
}
|
||||
|
||||
// Event listeners
domainInput.addEventListener("input", clearError);
// "keypress" is deprecated, so listen for "keydown" instead. Enter presses
// that only confirm an IME composition are ignored.
domainInput.addEventListener("keydown", (e) => {
  if (e.key === "Enter" && !e.isComposing) {
    saveConfiguration();
  }
});
saveBtn.addEventListener("click", saveConfiguration);

// Run init once the DOM is available; if parsing has already finished,
// DOMContentLoaded will not fire again, so call it directly.
if (document.readyState === "loading") {
  document.addEventListener("DOMContentLoaded", init);
} else {
  init();
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
// This script injects a draggable title bar that matches Onyx design system
|
||||
|
||||
(function () {
|
||||
console.log("[Onyx Desktop] Title bar script loaded");
|
||||
|
||||
const TITLEBAR_ID = "onyx-desktop-titlebar";
|
||||
const TITLEBAR_HEIGHT = 36;
|
||||
const STYLE_ID = "onyx-desktop-titlebar-style";
|
||||
@@ -31,12 +29,7 @@
|
||||
try {
|
||||
await invoke("start_drag_window");
|
||||
return;
|
||||
} catch (err) {
|
||||
console.error(
|
||||
"[Onyx Desktop] Failed to start dragging via invoke:",
|
||||
err,
|
||||
);
|
||||
}
|
||||
} catch (err) {}
|
||||
}
|
||||
|
||||
const appWindow =
|
||||
@@ -46,14 +39,7 @@
|
||||
if (appWindow?.startDragging) {
|
||||
try {
|
||||
await appWindow.startDragging();
|
||||
} catch (err) {
|
||||
console.error(
|
||||
"[Onyx Desktop] Failed to start dragging via appWindow:",
|
||||
err,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
console.error("[Onyx Desktop] No Tauri drag API available.");
|
||||
} catch (err) {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,7 +163,6 @@
|
||||
|
||||
function mountTitleBar() {
|
||||
if (!document.body) {
|
||||
console.error("[Onyx Desktop] document.body not found");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -193,7 +178,6 @@
|
||||
const titleBar = buildTitleBar();
|
||||
document.body.insertBefore(titleBar, document.body.firstChild);
|
||||
injectStyles();
|
||||
console.log("[Onyx Desktop] Title bar injected");
|
||||
}
|
||||
|
||||
function syncViewportHeight() {
|
||||
|
||||
@@ -5,8 +5,7 @@ build-backend = "setuptools.build_meta"
|
||||
[project]
|
||||
name = "onyx"
|
||||
version = "0.0.0"
|
||||
# TODO(jamison): Upgrade dependencies until they're compatible with python >3.13.
|
||||
requires-python = ">=3.11,<3.13"
|
||||
requires-python = ">=3.11"
|
||||
# Shared dependencies between backend and model_server
|
||||
dependencies = [
|
||||
"aioboto3==15.1.0",
|
||||
@@ -91,7 +90,7 @@ backend = [
|
||||
"python-dateutil==2.8.2",
|
||||
"python-gitlab==5.6.0",
|
||||
"python-pptx==0.6.23",
|
||||
"pypdf==6.1.3",
|
||||
"pypdf==6.6.0",
|
||||
"pytest-mock==3.12.0",
|
||||
"pytest-playwright==0.7.0",
|
||||
"python-docx==1.1.2",
|
||||
@@ -111,8 +110,8 @@ backend = [
|
||||
"tiktoken==0.7.0",
|
||||
"timeago==1.0.16",
|
||||
"types-openpyxl==3.0.4.7",
|
||||
"unstructured==0.15.1",
|
||||
"unstructured-client==0.25.4",
|
||||
"unstructured==0.18.27",
|
||||
"unstructured-client==0.42.6",
|
||||
"zulip==0.8.2",
|
||||
"hubspot-api-client==11.1.0",
|
||||
"asana==5.0.8",
|
||||
@@ -143,7 +142,7 @@ dev = [
|
||||
"matplotlib==3.10.8",
|
||||
"mypy-extensions==1.0.0",
|
||||
"mypy==1.13.0",
|
||||
"onyx-devtools==0.2.0",
|
||||
"onyx-devtools==0.6.2",
|
||||
"openapi-generator-cli==7.17.0",
|
||||
"pandas-stubs==2.2.3.241009",
|
||||
"pre-commit==3.2.2",
|
||||
@@ -181,7 +180,7 @@ ee = [
|
||||
model_server = [
|
||||
"accelerate==1.6.0",
|
||||
"einops==0.8.1",
|
||||
"numpy==1.26.4",
|
||||
"numpy==2.4.1",
|
||||
"safetensors==0.5.3",
|
||||
"sentence-transformers==4.0.2",
|
||||
"torch==2.6.0",
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 548 B After Width: | Height: | Size: 581 B |
@@ -25,7 +25,7 @@ export default function OnyxApiKeyForm({
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content tall>
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header
|
||||
icon={SvgKey}
|
||||
title={isUpdate ? "Update API Key" : "Create a new API Key"}
|
||||
|
||||
@@ -105,7 +105,7 @@ function Main() {
|
||||
{popup}
|
||||
|
||||
<Modal open={!!fullApiKey}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
title="New API Key"
|
||||
icon={SvgKey}
|
||||
|
||||
@@ -10,10 +10,7 @@ import {
|
||||
} from "@/lib/types";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import {
|
||||
FetchAssistantsResponse,
|
||||
fetchAssistantsSS,
|
||||
} from "@/lib/assistants/fetchAssistantsSS";
|
||||
import { FetchAssistantsResponse, fetchAssistantsSS } from "@/lib/agentsSS";
|
||||
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
|
||||
async function EditslackChannelConfigPage(props: {
|
||||
|
||||
@@ -4,7 +4,7 @@ import { fetchSS } from "@/lib/utilsSS";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { DocumentSetSummary, ValidSources } from "@/lib/types";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { fetchAssistantsSS } from "@/lib/assistants/fetchAssistantsSS";
|
||||
import { fetchAssistantsSS } from "@/lib/agentsSS";
|
||||
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { redirect } from "next/navigation";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Form, Formik, FormikProps } from "formik";
|
||||
import { SelectorFormField, TextFormField } from "@/components/Field";
|
||||
@@ -28,13 +30,7 @@ import { DisplayModels } from "./components/DisplayModels";
|
||||
import { fetchBedrockModels } from "../utils";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
Tabs,
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
TabsContent,
|
||||
} from "@/refresh-components/tabs/tabs";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
|
||||
export const BEDROCK_PROVIDER_NAME = "bedrock";
|
||||
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
|
||||
@@ -161,33 +157,25 @@ function BedrockFormInternals({
|
||||
onValueChange={(value) =>
|
||||
formikProps.setFieldValue(FIELD_BEDROCK_AUTH_METHOD, value)
|
||||
}
|
||||
className="mt-2"
|
||||
>
|
||||
<TabsList>
|
||||
<TabsTrigger value={AUTH_METHOD_IAM}>IAM Role</TabsTrigger>
|
||||
<TabsTrigger value={AUTH_METHOD_ACCESS_KEY}>Access Key</TabsTrigger>
|
||||
<TabsTrigger value={AUTH_METHOD_LONG_TERM_API_KEY}>
|
||||
<Tabs.List>
|
||||
<Tabs.Trigger value={AUTH_METHOD_IAM}>IAM Role</Tabs.Trigger>
|
||||
<Tabs.Trigger value={AUTH_METHOD_ACCESS_KEY}>
|
||||
Access Key
|
||||
</Tabs.Trigger>
|
||||
<Tabs.Trigger value={AUTH_METHOD_LONG_TERM_API_KEY}>
|
||||
Long-term API Key
|
||||
</TabsTrigger>
|
||||
</TabsList>
|
||||
</Tabs.Trigger>
|
||||
</Tabs.List>
|
||||
|
||||
<TabsContent
|
||||
value={AUTH_METHOD_IAM}
|
||||
className="data-[state=active]:animate-fade-in-scale"
|
||||
>
|
||||
<Tabs.Content value={AUTH_METHOD_IAM}>
|
||||
<Text as="p" text03>
|
||||
Uses the IAM role attached to your AWS environment. Recommended
|
||||
for EC2, ECS, Lambda, or other AWS services.
|
||||
</Text>
|
||||
</TabsContent>
|
||||
</Tabs.Content>
|
||||
|
||||
<TabsContent
|
||||
value={AUTH_METHOD_ACCESS_KEY}
|
||||
className={cn(
|
||||
"data-[state=active]:animate-fade-in-scale",
|
||||
"mt-4 ml-2"
|
||||
)}
|
||||
>
|
||||
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<TextFormField
|
||||
name={FIELD_AWS_ACCESS_KEY_ID}
|
||||
@@ -200,15 +188,9 @@ function BedrockFormInternals({
|
||||
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs.Content>
|
||||
|
||||
<TabsContent
|
||||
value={AUTH_METHOD_LONG_TERM_API_KEY}
|
||||
className={cn(
|
||||
"data-[state=active]:animate-fade-in-scale",
|
||||
"mt-4 ml-2"
|
||||
)}
|
||||
>
|
||||
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
|
||||
<div className="flex flex-col gap-4">
|
||||
<PasswordInputTypeInField
|
||||
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
|
||||
@@ -216,7 +198,7 @@ function BedrockFormInternals({
|
||||
placeholder="Your long-term API key"
|
||||
/>
|
||||
</div>
|
||||
</TabsContent>
|
||||
</Tabs.Content>
|
||||
</Tabs>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"use client";
|
||||
|
||||
import { useState, ReactNode } from "react";
|
||||
import useSWR, { useSWRConfig, KeyedMutator } from "swr";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import {
|
||||
LLMProviderView,
|
||||
ModelConfiguration,
|
||||
WellKnownLLMProviderDescriptor,
|
||||
} from "../../interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
@@ -114,7 +115,7 @@ export function ProviderFormEntrypointWrapper({
|
||||
|
||||
{formIsVisible && (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content medium>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={`Setup ${providerName}`}
|
||||
@@ -196,7 +197,7 @@ export function ProviderFormEntrypointWrapper({
|
||||
|
||||
{formIsVisible && (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content medium>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={`${existingLlmProvider ? "Configure" : "Setup"} ${
|
||||
|
||||
@@ -130,7 +130,7 @@ export default function UpgradingPage({
|
||||
{popup}
|
||||
{isCancelling && (
|
||||
<Modal open onOpenChange={() => setIsCancelling(false)}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgX}
|
||||
title="Cancel Embedding Model Switch"
|
||||
|
||||
@@ -81,7 +81,7 @@ export const WebProviderSetupModal = memo(
|
||||
|
||||
return (
|
||||
<Modal open={isOpen} onOpenChange={(open) => !open && onClose()}>
|
||||
<Modal.Content mini preventAccidentalClose>
|
||||
<Modal.Content width="sm" preventAccidentalClose>
|
||||
<Modal.Header
|
||||
icon={LogoArrangement}
|
||||
title={`Set up ${providerLabel}`}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export type WebContentProviderType =
|
||||
| "firecrawl"
|
||||
| "onyx_web_crawler"
|
||||
| "exa"
|
||||
| (string & {});
|
||||
|
||||
export const CONTENT_PROVIDERS_URL = "/api/admin/web-search/content-providers";
|
||||
@@ -23,6 +24,13 @@ export const CONTENT_PROVIDER_DETAILS: Record<
|
||||
"Connect Firecrawl to fetch and summarize page content from search results.",
|
||||
logoSrc: "/firecrawl.svg",
|
||||
},
|
||||
exa: {
|
||||
label: "Exa",
|
||||
subtitle: "Exa.ai",
|
||||
description:
|
||||
"Use Exa to fetch and summarize page content from search results.",
|
||||
logoSrc: "/Exa.svg",
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -64,6 +72,7 @@ const CONTENT_PROVIDER_CAPABILITIES: Record<
|
||||
base_url: ["base_url", "api_base_url"],
|
||||
},
|
||||
},
|
||||
// exa uses default capabilities
|
||||
};
|
||||
|
||||
const DEFAULT_CONTENT_PROVIDER_CAPABILITIES: ContentProviderCapabilities = {
|
||||
|
||||
@@ -136,6 +136,17 @@ export default function Page() {
|
||||
|
||||
const isLoading = isLoadingSearchProviders || isLoadingContentProviders;
|
||||
|
||||
// Exa shares API key between search and content providers
|
||||
const exaSearchProvider = searchProviders.find(
|
||||
(p) => p.provider_type === "exa"
|
||||
);
|
||||
const exaContentProvider = contentProviders.find(
|
||||
(p) => p.provider_type === "exa"
|
||||
);
|
||||
const hasSharedExaKey =
|
||||
(exaSearchProvider?.has_api_key || exaContentProvider?.has_api_key) ??
|
||||
false;
|
||||
|
||||
// Modal form state is owned by reducers
|
||||
|
||||
const openSearchModal = (
|
||||
@@ -145,12 +156,18 @@ export default function Page() {
|
||||
const requiresApiKey = searchProviderRequiresApiKey(providerType);
|
||||
const hasStoredKey = provider?.has_api_key ?? false;
|
||||
|
||||
// For Exa search provider, check if we can use the shared Exa key
|
||||
const isExa = providerType === "exa";
|
||||
const canUseSharedExaKey = isExa && hasSharedExaKey && !hasStoredKey;
|
||||
|
||||
dispatchSearchModal({
|
||||
type: "OPEN",
|
||||
providerType,
|
||||
existingProviderId: provider?.id ?? null,
|
||||
initialApiKeyValue:
|
||||
requiresApiKey && hasStoredKey ? MASKED_API_KEY_PLACEHOLDER : "",
|
||||
requiresApiKey && (hasStoredKey || canUseSharedExaKey)
|
||||
? MASKED_API_KEY_PLACEHOLDER
|
||||
: "",
|
||||
initialConfigValue: getSingleConfigFieldValueForForm(
|
||||
providerType,
|
||||
provider
|
||||
@@ -165,11 +182,16 @@ export default function Page() {
|
||||
const hasStoredKey = provider?.has_api_key ?? false;
|
||||
const defaultFirecrawlBaseUrl = "https://api.firecrawl.dev/v1/scrape";
|
||||
|
||||
// For Exa content provider, check if we can use the shared Exa key
|
||||
const isExa = providerType === "exa";
|
||||
const canUseSharedExaKey = isExa && hasSharedExaKey && !hasStoredKey;
|
||||
|
||||
dispatchContentModal({
|
||||
type: "OPEN",
|
||||
providerType,
|
||||
existingProviderId: provider?.id ?? null,
|
||||
initialApiKeyValue: hasStoredKey ? MASKED_API_KEY_PLACEHOLDER : "",
|
||||
initialApiKeyValue:
|
||||
hasStoredKey || canUseSharedExaKey ? MASKED_API_KEY_PLACEHOLDER : "",
|
||||
initialConfigValue:
|
||||
providerType === "firecrawl"
|
||||
? getSingleContentConfigFieldValueForForm(
|
||||
@@ -339,6 +361,17 @@ export default function Page() {
|
||||
} satisfies WebContentProviderView;
|
||||
}
|
||||
|
||||
if (providerType === "exa") {
|
||||
return {
|
||||
id: -3,
|
||||
name: "Exa",
|
||||
provider_type: "exa",
|
||||
is_active: false,
|
||||
config: null,
|
||||
has_api_key: hasSharedExaKey,
|
||||
} satisfies WebContentProviderView;
|
||||
}
|
||||
|
||||
return null;
|
||||
}).filter(Boolean) as WebContentProviderView[];
|
||||
|
||||
@@ -347,7 +380,7 @@ export default function Page() {
|
||||
);
|
||||
|
||||
return [...ordered, ...additional];
|
||||
}, [contentProviders]);
|
||||
}, [contentProviders, hasSharedExaKey]);
|
||||
|
||||
const currentContentProviderType =
|
||||
getCurrentContentProviderType(contentProviders);
|
||||
@@ -468,7 +501,12 @@ export default function Page() {
|
||||
onClose: () => {
|
||||
dispatchSearchModal({ type: "CLOSE" });
|
||||
},
|
||||
mutate: mutateSearchProviders,
|
||||
mutate: async () => {
|
||||
await mutateSearchProviders();
|
||||
if (selectedProviderType === "exa") {
|
||||
await mutateContentProviders();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -678,6 +716,23 @@ export default function Page() {
|
||||
selectedContentProviderType
|
||||
: "";
|
||||
|
||||
if (selectedContentProviderType === "exa") {
|
||||
return (
|
||||
<>
|
||||
Paste your{" "}
|
||||
<a
|
||||
href="https://dashboard.exa.ai/api-keys"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline"
|
||||
>
|
||||
API key
|
||||
</a>{" "}
|
||||
from Exa to enable crawling.
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return selectedContentProviderType === "firecrawl" ? (
|
||||
<>
|
||||
Paste your <span className="underline">API key</span> from Firecrawl to
|
||||
@@ -730,6 +785,10 @@ export default function Page() {
|
||||
dispatchContentModal({ type: "SET_PHASE", phase: "saving" });
|
||||
dispatchContentModal({ type: "CLEAR_MESSAGE" });
|
||||
|
||||
const apiKeyChangedForContentProvider =
|
||||
contentModal.apiKeyValue !== MASKED_API_KEY_PLACEHOLDER &&
|
||||
contentProviderValues.apiKey.length > 0;
|
||||
|
||||
await connectProviderFlow({
|
||||
category: "content",
|
||||
providerType: selectedContentProviderType,
|
||||
@@ -740,9 +799,7 @@ export default function Page() {
|
||||
CONTENT_PROVIDER_DETAILS[selectedContentProviderType]?.label ??
|
||||
selectedContentProviderType,
|
||||
providerRequiresApiKey: true,
|
||||
apiKeyChangedForProvider:
|
||||
contentModal.apiKeyValue !== MASKED_API_KEY_PLACEHOLDER &&
|
||||
contentProviderValues.apiKey.length > 0,
|
||||
apiKeyChangedForProvider: apiKeyChangedForContentProvider,
|
||||
apiKey: contentProviderValues.apiKey,
|
||||
config,
|
||||
configChanged,
|
||||
@@ -759,7 +816,12 @@ export default function Page() {
|
||||
onClose: () => {
|
||||
dispatchContentModal({ type: "CLOSE" });
|
||||
},
|
||||
mutate: mutateContentProviders,
|
||||
mutate: async () => {
|
||||
await mutateContentProviders();
|
||||
if (selectedContentProviderType === "exa") {
|
||||
await mutateSearchProviders();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1052,7 +1114,8 @@ export default function Page() {
|
||||
|
||||
const canActivate =
|
||||
providerId > 0 ||
|
||||
provider.provider_type === "onyx_web_crawler";
|
||||
provider.provider_type === "onyx_web_crawler" ||
|
||||
isConfigured;
|
||||
|
||||
return {
|
||||
label: "Set as Default",
|
||||
|
||||
@@ -125,7 +125,7 @@ export default function IndexAttemptErrorsModal({
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content large>
|
||||
<Modal.Content width="lg" height="full">
|
||||
<Modal.Header
|
||||
icon={SvgAlertTriangle}
|
||||
title="Indexing Errors"
|
||||
|
||||
@@ -353,7 +353,7 @@ export default function InlineFileManagement({
|
||||
|
||||
{/* Confirmation Modal */}
|
||||
<Modal open={showSaveConfirm} onOpenChange={setShowSaveConfirm}>
|
||||
<Modal.Content mini>
|
||||
<Modal.Content width="sm">
|
||||
<Modal.Header
|
||||
icon={SvgFolderPlus}
|
||||
title="Confirm File Changes"
|
||||
|
||||
@@ -128,7 +128,7 @@ export default function ReIndexModal({
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={hide}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header icon={SvgRefreshCw} title="Run Indexing" onClose={hide} />
|
||||
<Modal.Body>
|
||||
<Text as="p">
|
||||
|
||||
@@ -584,7 +584,7 @@ export default function AddConnector({
|
||||
open
|
||||
onOpenChange={() => setCreateCredentialFormToggle(false)}
|
||||
>
|
||||
<Modal.Content medium>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgKey}
|
||||
title={`Create a ${getSourceDisplayName(
|
||||
|
||||
@@ -9,12 +9,7 @@ import FileInput from "./ConnectorInput/FileInput";
|
||||
import { ConfigurableSources } from "@/lib/types";
|
||||
import { Credential } from "@/lib/connectors/credentials";
|
||||
import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection";
|
||||
import {
|
||||
Tabs,
|
||||
TabsContent,
|
||||
TabsList,
|
||||
TabsTrigger,
|
||||
} from "@/components/ui/fully_wrapped_tabs";
|
||||
import Tabs from "@/refresh-components/Tabs";
|
||||
import { useFormikContext } from "formik";
|
||||
|
||||
// Define a general type for form values
|
||||
@@ -60,7 +55,6 @@ const TabsField: FC<TabsFieldProps> = ({
|
||||
) : (
|
||||
<Tabs
|
||||
defaultValue={tabField.defaultTab || tabField.tabs[0]?.value}
|
||||
className="w-full"
|
||||
onValueChange={(newTab) => {
|
||||
// Clear values from other tabs but preserve defaults
|
||||
tabField.tabs.forEach((tab) => {
|
||||
@@ -75,15 +69,15 @@ const TabsField: FC<TabsFieldProps> = ({
|
||||
});
|
||||
}}
|
||||
>
|
||||
<TabsList>
|
||||
<Tabs.List>
|
||||
{tabField.tabs.map((tab) => (
|
||||
<TabsTrigger key={tab.value} value={tab.value}>
|
||||
<Tabs.Trigger key={tab.value} value={tab.value}>
|
||||
{tab.label}
|
||||
</TabsTrigger>
|
||||
</Tabs.Trigger>
|
||||
))}
|
||||
</TabsList>
|
||||
</Tabs.List>
|
||||
{tabField.tabs.map((tab) => (
|
||||
<TabsContent key={tab.value} value={tab.value} className="">
|
||||
<Tabs.Content key={tab.value} value={tab.value}>
|
||||
{tab.fields.map((subField, index, array) => {
|
||||
// Check visibility condition first
|
||||
if (
|
||||
@@ -112,7 +106,7 @@ const TabsField: FC<TabsFieldProps> = ({
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</TabsContent>
|
||||
</Tabs.Content>
|
||||
))}
|
||||
</Tabs>
|
||||
)}
|
||||
|
||||
@@ -323,7 +323,7 @@ const RerankingDetailsForm = forwardRef<
|
||||
open
|
||||
onOpenChange={() => setShowGpuWarningModalModel(null)}
|
||||
>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgAlertTriangle}
|
||||
title="GPU Not Enabled"
|
||||
@@ -358,7 +358,7 @@ const RerankingDetailsForm = forwardRef<
|
||||
setShowLiteLLMConfigurationModal(false);
|
||||
}}
|
||||
>
|
||||
<Modal.Content medium>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgKey}
|
||||
title="API Key Configuration"
|
||||
@@ -462,7 +462,7 @@ const RerankingDetailsForm = forwardRef<
|
||||
setIsApiKeyModalOpen(false);
|
||||
}}
|
||||
>
|
||||
<Modal.Content medium>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgKey}
|
||||
title="API Key Configuration"
|
||||
|
||||
@@ -14,7 +14,7 @@ export default function AlreadyPickedModal({
|
||||
}: AlreadyPickedModalProps) {
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgCheck}
|
||||
title={`${model.model_name} already chosen`}
|
||||
|
||||
@@ -21,7 +21,7 @@ export default function DeleteCredentialsModal({
|
||||
}: DeleteCredentialsModalProps) {
|
||||
return (
|
||||
<Modal open onOpenChange={onCancel}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgTrash}
|
||||
title={`Delete ${getFormattedProviderName(
|
||||
|
||||
@@ -13,7 +13,7 @@ export default function InstantSwitchConfirmModal({
|
||||
}: InstantSwitchConfirmModalProps) {
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgAlertTriangle}
|
||||
title="Are you sure you want to do an instant switch?"
|
||||
|
||||
@@ -20,7 +20,7 @@ export default function ModelSelectionConfirmationModal({
|
||||
}: ModelSelectionConfirmationModalProps) {
|
||||
return (
|
||||
<Modal open onOpenChange={onCancel}>
|
||||
<Modal.Content tall>
|
||||
<Modal.Content width="sm" height="lg">
|
||||
<Modal.Header
|
||||
icon={SvgServer}
|
||||
title="Update Embedding Model"
|
||||
|
||||
@@ -186,7 +186,7 @@ export default function ProviderCreationModal({
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onCancel}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgSettings}
|
||||
title={`Configure ${getFormattedProviderName(
|
||||
|
||||
@@ -17,7 +17,7 @@ export default function SelectModelModal({
|
||||
}: SelectModelModalProps) {
|
||||
return (
|
||||
<Modal open onOpenChange={onCancel}>
|
||||
<Modal.Content small>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgServer}
|
||||
title={`Select ${model.model_name}`}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user