Compare commits

...

63 Commits

Author SHA1 Message Date
Justin Tahara
f1c30974f5 fix(celery): Guardrail for User File Processing (#8633) 2026-03-01 09:22:43 -08:00
Jamison Lahman
81bf07fb15 chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 16:20:13 -08:00
Jamison Lahman
b565bf8291 chore(mypy): fix mypy cache issues switching between HEAD and release (#7732) 2026-01-27 15:52:40 -08:00
Jamison Lahman
b4da99cbdd fix(citations): enable citation sidebar w/ web_search-only assistants (#7888) 2026-01-27 13:36:44 -08:00
Justin Tahara
f910feea0f fix(llm): Hide private models from Agent Creation (#7873) 2026-01-27 12:20:56 -08:00
Justin Tahara
e3af8c6c8a feat(desktop): Domain Configuration (#7655) 2026-01-26 16:42:58 -08:00
Justin Tahara
d6e46ed792 feat(desktop): Properly Sign Mac App (#7608) 2026-01-26 16:42:47 -08:00
Jamison Lahman
4ce1f4ecdd chore(desktop): make artifact filename version-agnostic (#7679) 2026-01-26 16:24:06 -08:00
Jamison Lahman
a4678884d7 chore(deployments): fix region (#7640) 2026-01-26 16:24:06 -08:00
Jamison Lahman
c861ba68f1 chore(deployments): fetch secrets from AWS (#7584) 2026-01-26 16:24:06 -08:00
Raunak Bhagat
b1d0e0bb0b Fix actions-steps collapsing/opening issue 2026-01-25 12:49:32 -08:00
Raunak Bhagat
0d78bf52e3 Stop header from collapsing over and over again 2026-01-25 12:49:32 -08:00
Yuhong Sun
bd743282e6 fix: LiteLLM Azure models don't stream (#7761) 2026-01-25 12:47:48 -08:00
Raunak Bhagat
d44d1d92b3 2.9 fixes (#7756) 2026-01-24 17:36:20 -08:00
Raunak Bhagat
4cedcfee59 Fix notifications popover some more 2026-01-24 17:30:45 -08:00
Raunak Bhagat
90a721a76e Fix line-items 2026-01-24 17:30:45 -08:00
Raunak Bhagat
3ccd99e931 Fix notifications 2026-01-24 17:30:45 -08:00
Raunak Bhagat
9076bf603f Fix actions popover 2026-01-24 17:30:45 -08:00
Nikolas Garza
8c6e0a70c3 fix(chat): prevent streaming text from appearing in bursts after citations (#7745) 2026-01-24 16:58:12 -08:00
Yuhong Sun
bebe9555d4 fix: Azure OpenAI Tool Calls (#7727) 2026-01-24 16:55:27 -08:00
Nikolas Garza
c530722c9f fix(tests): use crawler-friendly search query in Exa integration test (#7746) 2026-01-24 16:53:40 -08:00
Jamison Lahman
68380b4ddb chore(fe): align assistant icon with chat bar (#7537) 2026-01-24 16:34:57 -08:00
Jamison Lahman
b3380746ab fix(fe): chat header is sticky and transparent (#7487) 2026-01-24 16:34:57 -08:00
Nikolas Garza
56be114c87 fix(fe): show scroll-down button when user scrolls up during streaming (#7562) 2026-01-24 16:34:57 -08:00
Nikolas Garza
54f467da5c fix: improve scroll behavior (#7364) 2026-01-24 16:34:57 -08:00
Nikolas Garza
8726b112fe fix(slack): Extract person names and filter garbage in query expansion (#7632) 2026-01-23 22:59:23 -08:00
Raunak Bhagat
92181d07b2 fix: Fix scrollability issues for modals (#7718) 2026-01-23 22:05:53 -08:00
Raunak Bhagat
3a73f7fab2 fix: Fix layout issues with AgentEditorPage (#7730) 2026-01-23 20:29:21 -08:00
Raunak Bhagat
7dabaca7cd fix: Add back agent sharing (#7731) 2026-01-23 19:13:36 -08:00
Raunak Bhagat
dec4748825 Close modal on success only 2026-01-23 17:39:52 -08:00
Raunak Bhagat
072836cd86 Cherry-pick agent-deletion 2026-01-23 17:39:52 -08:00
Evan Lohn
2705b5fb0e Revert "fix: modal header in index attempt errors (#7601)"
This reverts commit f945ab6b05.
2026-01-23 15:02:41 -08:00
Evan Lohn
37dcde4226 fix: prevent updates from overwriting perm syncing (#7384) 2026-01-23 14:52:44 -08:00
Evan Lohn
a765b5f622 fix(mcp): per-user auth (#7400) 2026-01-23 14:51:56 -08:00
Evan Lohn
5e093368d1 fix: bedrock non-anthropic prompt caching (#7435) 2026-01-23 14:50:13 -08:00
Evan Lohn
f945ab6b05 fix: modal header in index attempt errors (#7601) 2026-01-23 14:48:29 -08:00
Justin Tahara
11b7a22404 fix(ui): Coda Logo (#7656) 2026-01-23 14:45:29 -08:00
Justin Tahara
8e34f944cc fix(ui): First Connector Result (#7657) 2026-01-23 14:45:18 -08:00
Jamison Lahman
32606dc752 revert: "feat: Enable triple click on content in the chat" (#7393) to release v2.9 (#7710) 2026-01-23 14:21:22 -08:00
Jamison Lahman
1f6c4b40bf fix(fe): inline code text wraps (#7574) to release v2.9 (#7707) 2026-01-23 13:40:28 -08:00
Nikolas Garza
1943f1c745 feat(billing): add annual pricing support to subscription checkout (#7506) 2026-01-23 10:40:16 -08:00
Jamison Lahman
82460729a6 fix(db): ensure migrations are atomic (#7474) to release v2.9 (#7648) 2026-01-21 14:58:04 -08:00
Wenxi
c445e6a8c0 fix: delete old notifications first in migration (#7454) 2026-01-20 08:31:00 -08:00
SubashMohan
8d30a03d7f fix(chat): prevent adding chat sessions to recents that belong to a project (#7377) 2026-01-13 17:57:29 +00:00
Raunak Bhagat
277428f579 refactor: consolidate tabs components into single Tabs.tsx (#7370) 2026-01-13 03:51:48 +00:00
acaprau
9f8c0d4237 feat(opensearch): Even more feature parity, more strict tenant ID checks, OpenSearch client test improvements (#7372)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-01-13 03:39:02 +00:00
Jessica Singh
9ccbb6a04b feat(web search): exa crawler (#7326) 2026-01-13 01:42:16 +00:00
Danelegend
58a943f782 fix(tools): Tool name should align with what llm knows (#7352) 2026-01-13 01:04:20 +00:00
roshan
9021c607f2 chore(dr): finer grained tracing for clarification step, research plan step, and orchestration step (#7374)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-01-12 23:58:27 +00:00
Jamison Lahman
c03b0d80fd chore(deps): remove requires-python < 3.13 (#7367) 2026-01-12 23:21:02 +00:00
acaprau
fcf0b316a4 feat(opensearch): More feature parity (#7286) 2026-01-12 23:01:55 +00:00
Jamison Lahman
157f672b4b chore(deps): upgrade numpy, unstructured, unstructured-client (#7369) 2026-01-12 22:58:11 +00:00
dependabot[bot]
51b9484b96 chore(deps): bump actions/upload-artifact from 5.0.0 to 6.0.0 (#6964)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-01-12 21:53:48 +00:00
Danelegend
0c8f55c049 fix(tools): persist enabled tools in ui (#7347) 2026-01-12 21:47:29 +00:00
dependabot[bot]
c7be2571d1 chore(deps): bump tauri-apps/tauri-action from 0.6.0 to 0.6.1 (#7371)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-12 13:48:46 -08:00
dependabot[bot]
4948b6cca9 chore(deps): bump actions/stale from 10.1.0 to 10.1.1 (#6965)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-12 13:12:24 -08:00
Jamison Lahman
638ea5f316 chore(deps): fix uv-lock hook (#7368) 2026-01-12 12:52:17 -08:00
dependabot[bot]
6e3268ca75 chore(deps): bump pypdf from 6.1.3 to 6.6.0 (#7319)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-01-12 20:36:47 +00:00
Wenxi
d8921df60c fix: onboarding modal styling (#7363) 2026-01-12 20:29:23 +00:00
Yuhong Sun
693d9f5f69 fix: Editing First Message (#7366) 2026-01-12 19:45:01 +00:00
Jamison Lahman
02e17871cc chore(devtools): recommend starting dev dockers with --wait (#7365) 2026-01-12 19:13:00 +00:00
Wenxi
209cfd00b0 fix: only show latest release notification for nightly versions (#7362) 2026-01-12 11:10:28 -08:00
Jessica Singh
cd36baa484 fix(web search): removing site: operator from exa query (#7248) 2026-01-12 18:22:18 +00:00
202 changed files with 8429 additions and 4006 deletions

View File

@@ -8,7 +8,9 @@ on:
# Set restrictive default permissions for all jobs. Jobs that need more permissions
# should explicitly declare them.
permissions: {}
permissions:
# Required for OIDC authentication with AWS
id-token: write # zizmor: ignore[excessive-permissions]
env:
EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}
@@ -150,16 +152,30 @@ jobs:
if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
runs-on: ubuntu-slim
timeout-minutes: 10
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: "• check-version-tag"
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
@@ -168,6 +184,7 @@ jobs:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
permissions:
id-token: write
contents: write
actions: read
strategy:
@@ -185,12 +202,33 @@ jobs:
runs-on: ${{ matrix.platform }}
timeout-minutes: 90
environment: release
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]
- name: Configure AWS credentials
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
if: startsWith(matrix.platform, 'macos-')
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
APPLE_ID, deploy/apple-id
APPLE_PASSWORD, deploy/apple-password
APPLE_CERTIFICATE, deploy/apple-certificate
APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
KEYCHAIN_PASSWORD, deploy/keychain-password
APPLE_TEAM_ID, deploy/apple-team-id
parse-json-secrets: true
- name: install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-')
run: |
@@ -285,15 +323,40 @@ jobs:
Write-Host "Versions set to: $VERSION"
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
- name: Import Apple Developer Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
security set-keychain-settings -t 3600 -u build.keychain
security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
security find-identity -v -p codesigning build.keychain
- name: Verify Certificate
if: startsWith(matrix.platform, 'macos-')
run: |
CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
echo "Certificate imported."
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
APPLE_ID: ${{ env.APPLE_ID }}
APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
with:
tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: "See the assets to download this version and install."
releaseDraft: true
prerelease: false
assetNamePattern: "[name]_[arch][ext]"
args: ${{ matrix.args }}
build-web-amd64:
@@ -305,6 +368,7 @@ jobs:
- run-id=${{ github.run_id }}-web-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -317,6 +381,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -331,8 +409,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -363,6 +441,7 @@ jobs:
- run-id=${{ github.run_id }}-web-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -375,6 +454,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -389,8 +482,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -423,19 +516,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -471,6 +579,7 @@ jobs:
- run-id=${{ github.run_id }}-web-cloud-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -483,6 +592,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -497,8 +620,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -537,6 +660,7 @@ jobs:
- run-id=${{ github.run_id }}-web-cloud-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -549,6 +673,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -563,8 +701,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -605,19 +743,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -650,6 +803,7 @@ jobs:
- run-id=${{ github.run_id }}-backend-amd64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -662,6 +816,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -676,8 +844,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -707,6 +875,7 @@ jobs:
- run-id=${{ github.run_id }}-backend-arm64
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -719,6 +888,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -733,8 +916,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -766,19 +949,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -815,6 +1013,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -827,6 +1026,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -843,8 +1056,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
@@ -879,6 +1092,7 @@ jobs:
- volume=40gb
- extras=ecr-cache
timeout-minutes: 90
environment: release
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
@@ -891,6 +1105,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Docker meta
id: meta
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
@@ -907,8 +1135,8 @@ jobs:
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
@@ -944,19 +1172,34 @@ jobs:
- run-id=${{ github.run_id }}-merge-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Docker meta
id: meta
@@ -994,11 +1237,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-web
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1014,8 +1272,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1034,11 +1292,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-web-cloud
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1054,8 +1327,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1074,6 +1347,7 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-backend
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
steps:
@@ -1084,6 +1358,20 @@ jobs:
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1100,8 +1388,8 @@ jobs:
-v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1121,11 +1409,26 @@ jobs:
- run-id=${{ github.run_id }}-trivy-scan-model-server
- extras=ecr-cache
timeout-minutes: 90
environment: release
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Run Trivy vulnerability scanner
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
@@ -1141,8 +1444,8 @@ jobs:
docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
-e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
-e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
-e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
image \
--skip-version-check \
@@ -1170,12 +1473,26 @@ jobs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
environment: release
steps:
- name: Checkout
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
with:
secret-ids: |
MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
parse-json-secrets: true
- name: Determine failed jobs
id: failed-jobs
shell: bash
@@ -1241,7 +1558,7 @@ jobs:
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
title: "🚨 Deployment Workflow Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
with:
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'

View File

@@ -172,7 +172,7 @@ jobs:
- name: Upload Docker logs
if: failure()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-${{ matrix.test-dir }}
path: docker-logs/

View File

@@ -439,7 +439,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
@@ -568,7 +568,7 @@ jobs:
- name: Upload logs (multi-tenant)
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-multitenant
path: ${{ github.workspace }}/docker-compose-multitenant.log

View File

@@ -44,7 +44,7 @@ jobs:
- name: Upload coverage reports
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: jest-coverage-${{ github.run_id }}
path: ./web/coverage

View File

@@ -424,7 +424,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -435,7 +435,7 @@ jobs:
fi
npx playwright test --project ${PROJECT}
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
if: always()
with:
# Includes test results and trace.zip files
@@ -455,7 +455,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -50,8 +50,9 @@ jobs:
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
with:
path: backend/.mypy_cache
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
restore-keys: |
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
mypy-${{ runner.os }}-
- name: Run MyPy

View File

@@ -144,7 +144,7 @@ jobs:
- name: Upload logs
if: always()
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log

1
.gitignore vendored
View File

@@ -21,6 +21,7 @@ backend/tests/regression/search_quality/*.json
backend/onyx/evals/data/
backend/onyx/evals/one_off/*.json
*.log
*.csv
# secret files
.env

View File

@@ -11,7 +11,6 @@ repos:
- id: uv-sync
args: ["--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args:

View File

@@ -225,7 +225,6 @@ def do_run_migrations(
) -> None:
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
connection.execute(text(f'SET search_path TO "{schema_name}"'))
@@ -309,6 +308,7 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:
@@ -346,6 +346,7 @@ async def run_async_migrations() -> None:
schema_name=schema,
create_schema=create_schema,
)
await connection.commit()
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
if not continue_on_error:

View File

@@ -85,103 +85,122 @@ class UserRow(NamedTuple):
def upgrade() -> None:
conn = op.get_bind()
# Start transaction
conn.execute(sa.text("BEGIN"))
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
try:
# Step 1: Create or update the unified assistant (ID 0)
search_assistant = conn.execute(
sa.text("SELECT * FROM persona WHERE id = 0")
).fetchone()
if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
if search_assistant:
# Update existing Search assistant to be the unified assistant
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
SET name = :name,
description = :description,
system_prompt = :system_prompt,
num_chunks = :num_chunks,
is_default_persona = true,
is_visible = true,
deleted = false,
display_priority = :display_priority,
llm_filter_extraction = :llm_filter_extraction,
llm_relevance_filter = :llm_relevance_filter,
recency_bias = :recency_bias,
chunks_above = :chunks_above,
chunks_below = :chunks_below,
datetime_aware = :datetime_aware,
starter_messages = null
WHERE id = 0
"""
)
),
INSERT_DICT,
)
else:
# Create new unified assistant with ID 0
conn.execute(
sa.text(
"""
INSERT INTO persona (
id, name, description, system_prompt, num_chunks,
is_default_persona, is_visible, deleted, display_priority,
llm_filter_extraction, llm_relevance_filter, recency_bias,
chunks_above, chunks_below, datetime_aware, starter_messages,
builtin_persona
) VALUES (
0, :name, :description, :system_prompt, :num_chunks,
true, true, false, :display_priority, :llm_filter_extraction,
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
:datetime_aware, null, true
)
"""
),
INSERT_DICT,
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = true, is_visible = false, is_default_persona = false
WHERE builtin_persona = true AND id != 0
"""
)
)
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
# Step 3: Add all built-in tools to the unified assistant
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
).fetchone()
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
if not search_tool:
raise ValueError(
"SearchTool not found in database. Ensure tools migration has run first."
)
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
image_gen_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
).fetchone()
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
if not image_gen_tool:
raise ValueError(
"ImageGenerationTool not found in database. Ensure tools migration has run first."
)
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# WebSearchTool is optional - may not be configured
web_search_tool = conn.execute(
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
).fetchone()
# Add tools to the unified assistant
# Clear existing tool associations for persona 0
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
# Add tools to the unified assistant
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
)
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": image_gen_tool[0]},
)
if web_search_tool:
conn.execute(
sa.text(
"""
@@ -190,191 +209,148 @@ def upgrade() -> None:
ON CONFLICT DO NOTHING
"""
),
{"tool_id": search_tool[0]},
{"tool_id": web_search_tool[0]},
)
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
),
{"tool_id": image_gen_tool[0]},
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)
if web_search_tool:
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": web_search_tool[0]},
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
conn.execute(
sa.text(
"""
UPDATE chat_session
SET persona_id = 0
WHERE persona_id IN (
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
)
"""
)
)
# Step 5: Migrate user preferences - remove references to all builtin assistants
# First, get all builtin assistant IDs (except 0)
builtin_assistants_result = conn.execute(
sa.text(
"""
SELECT id FROM persona
WHERE builtin_persona = true AND id != 0
"""
)
).fetchall()
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
# Get all users with preferences
users_result = conn.execute(
sa.text(
"""
SELECT id, chosen_assistants, visible_assistants,
hidden_assistants, pinned_assistants
FROM "user"
"""
)
).fetchall()
for user_row in users_result:
user = UserRow(*user_row)
user_id: UUID = user.id
updates: dict[str, Any] = {}
# Remove all builtin assistants from chosen_assistants
if user.chosen_assistants:
new_chosen: list[int] = [
assistant_id
for assistant_id in user.chosen_assistants
if assistant_id not in builtin_assistant_ids
]
if new_chosen != user.chosen_assistants:
updates["chosen_assistants"] = json.dumps(new_chosen)
# Remove all builtin assistants from visible_assistants
if user.visible_assistants:
new_visible: list[int] = [
assistant_id
for assistant_id in user.visible_assistants
if assistant_id not in builtin_assistant_ids
]
if new_visible != user.visible_assistants:
updates["visible_assistants"] = json.dumps(new_visible)
# Add all builtin assistants to hidden_assistants
if user.hidden_assistants:
new_hidden: list[int] = list(user.hidden_assistants)
for old_id in builtin_assistant_ids:
if old_id not in new_hidden:
new_hidden.append(old_id)
if new_hidden != user.hidden_assistants:
updates["hidden_assistants"] = json.dumps(new_hidden)
else:
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
# Remove all builtin assistants from pinned_assistants
if user.pinned_assistants:
new_pinned: list[int] = [
assistant_id
for assistant_id in user.pinned_assistants
if assistant_id not in builtin_assistant_ids
]
if new_pinned != user.pinned_assistants:
updates["pinned_assistants"] = json.dumps(new_pinned)
# Apply updates if any
if updates:
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
conn.execute(
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
updates,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
def downgrade() -> None:
conn = op.get_bind()
# Start transaction
conn.execute(sa.text("BEGIN"))
try:
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
"""
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
# Only restore General (ID -1) and Art (ID -3) assistants
# Step 1: Keep Search assistant (ID 0) as default but restore original state
conn.execute(
sa.text(
"""
)
UPDATE persona
SET is_default_persona = true,
is_visible = true,
deleted = false
WHERE id = 0
"""
)
)
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
# Step 2: Restore General assistant (ID -1)
conn.execute(
sa.text(
"""
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :general_assistant_id
"""
),
{"general_assistant_id": GENERAL_ASSISTANT_ID},
)
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
"""
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
# Step 3: Restore Art assistant (ID -3)
conn.execute(
sa.text(
"""
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
UPDATE persona
SET deleted = false,
is_visible = true,
is_default_persona = true
WHERE id = :art_assistant_id
"""
),
{"art_assistant_id": ART_ASSISTANT_ID},
)
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
# Note: We don't restore the original tool associations, names, or descriptions
# as those would require more complex logic to determine original state.
# We also cannot restore original chat session persona_ids as we don't
# have the original mappings.
# Other builtin assistants remain deleted as per the requirement.

View File

@@ -24,6 +24,9 @@ def upgrade() -> None:
# in unique constraints, but we want NULL == NULL for deduplication).
# The '{}' represents an empty JSONB object as the NULL replacement.
# Clean up legacy notifications first
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
op.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
@@ -40,9 +43,6 @@ def upgrade() -> None:
"""
)
# Clean up legacy 'reindex' notifications that are no longer needed
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")

View File

@@ -42,20 +42,13 @@ TOOL_DESCRIPTIONS = {
def upgrade() -> None:
conn = op.get_bind()
conn.execute(sa.text("BEGIN"))
try:
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
conn.execute(sa.text("COMMIT"))
except Exception as e:
conn.execute(sa.text("ROLLBACK"))
raise e
for tool_id, description in TOOL_DESCRIPTIONS.items():
conn.execute(
sa.text(
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
),
{"description": description, "tool_id": tool_id},
)
def downgrade() -> None:

View File

@@ -7,7 +7,6 @@ Create Date: 2025-12-18 16:00:00.000000
"""
from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa
@@ -19,7 +18,7 @@ depends_on = None
DEEP_RESEARCH_TOOL = {
"name": RESEARCH_AGENT_DB_NAME,
"name": "ResearchAgent",
"display_name": "Research Agent",
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
"in_code_tool_id": "ResearchAgent",

View File

@@ -70,80 +70,66 @@ BUILT_IN_TOOLS = [
def upgrade() -> None:
conn = op.get_bind()
# Start transaction
conn.execute(sa.text("BEGIN"))
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
try:
# Get existing tools to check what already exists
existing_tools = conn.execute(
sa.text(
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
).fetchall()
existing_tool_ids = {row[0] for row in existing_tools}
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
# Insert or update built-in tools
for tool in BUILT_IN_TOOLS:
in_code_id = tool["in_code_tool_id"]
# Handle historical rename: InternetSearchTool -> WebSearchTool
if (
in_code_id == "WebSearchTool"
and "WebSearchTool" not in existing_tool_ids
and "InternetSearchTool" in existing_tool_ids
):
# Rename the existing InternetSearchTool row in place and update fields
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description,
in_code_tool_id = :in_code_tool_id
WHERE in_code_tool_id = 'InternetSearchTool'
"""
),
tool,
)
# Keep the local view of existing ids in sync to avoid duplicate insert
existing_tool_ids.discard("InternetSearchTool")
existing_tool_ids.add("WebSearchTool")
continue
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
# Commit transaction
conn.execute(sa.text("COMMIT"))
except Exception as e:
# Rollback on error
conn.execute(sa.text("ROLLBACK"))
raise e
if in_code_id in existing_tool_ids:
# Update existing tool
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :name,
display_name = :display_name,
description = :description
WHERE in_code_tool_id = :in_code_tool_id
"""
),
tool,
)
else:
# Insert new tool
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id)
VALUES (:name, :display_name, :description, :in_code_tool_id)
"""
),
tool,
)
def downgrade() -> None:

View File

@@ -0,0 +1,64 @@
"""sync_exa_api_key_to_content_provider
Revision ID: d1b637d7050a
Revises: d25168c2beee
Create Date: 2026-01-09 15:54:15.646249
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "d1b637d7050a"
down_revision = "d25168c2beee"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Exa uses a shared API key between search and content providers.
# For existing Exa search providers with API keys, create the corresponding
# content provider if it doesn't exist yet.
connection = op.get_bind()
# Check if Exa search provider exists with an API key
result = connection.execute(
text(
"""
SELECT api_key FROM internet_search_provider
WHERE provider_type = 'exa' AND api_key IS NOT NULL
LIMIT 1
"""
)
)
row = result.fetchone()
if row:
api_key = row[0]
# Create Exa content provider with the shared key
connection.execute(
text(
"""
INSERT INTO internet_content_provider
(name, provider_type, api_key, is_active)
VALUES ('Exa', 'exa', :api_key, false)
ON CONFLICT (name) DO NOTHING
"""
),
{"api_key": api_key},
)
def downgrade() -> None:
# Remove the Exa content provider that was created by this migration
connection = op.get_bind()
connection.execute(
text(
"""
DELETE FROM internet_content_provider
WHERE provider_type = 'exa'
"""
)
)

View File

@@ -0,0 +1,86 @@
"""tool_name_consistency
Revision ID: d25168c2beee
Revises: 8405ca81cc83
Create Date: 2026-01-11 17:54:40.135777
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d25168c2beee"
down_revision = "8405ca81cc83"
branch_labels = None
depends_on = None
# Currently the seeded tools have the in_code_tool_id == name
CURRENT_TOOL_NAME_MAPPING = [
"SearchTool",
"WebSearchTool",
"ImageGenerationTool",
"PythonTool",
"OpenURLTool",
"KnowledgeGraphTool",
"ResearchAgent",
]
# Mapping of in_code_tool_id -> name
# These are the expected names that we want in the database
EXPECTED_TOOL_NAME_MAPPING = {
"SearchTool": "internal_search",
"WebSearchTool": "web_search",
"ImageGenerationTool": "generate_image",
"PythonTool": "python",
"OpenURLTool": "open_url",
"KnowledgeGraphTool": "run_kg_search",
"ResearchAgent": "research_agent",
}
def upgrade() -> None:
conn = op.get_bind()
# Mapping of in_code_tool_id to the NAME constant from each tool class
# These match the .name property of each tool implementation
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
# Update the name column for each tool based on its in_code_tool_id
for in_code_tool_id, expected_name in tool_name_mapping.items():
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :expected_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"expected_name": expected_name,
"in_code_tool_id": in_code_tool_id,
},
)
def downgrade() -> None:
conn = op.get_bind()
# Reverse the migration by setting name back to in_code_tool_id
# This matches the original pattern where name was the class name
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
conn.execute(
sa.text(
"""
UPDATE tool
SET name = :current_name
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{
"current_name": in_code_tool_id,
"in_code_tool_id": in_code_tool_id,
},
)

View File

@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)

View File

@@ -3,30 +3,42 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.configs.constants import NotificationType
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import PersonaSharedNotificationData
def make_persona_private(
def update_persona_access(
persona_id: int,
creator_user_id: UUID | None,
user_ids: list[UUID] | None,
group_ids: list[int] | None,
db_session: Session,
is_public: bool | None = None,
user_ids: list[UUID] | None = None,
group_ids: list[int] | None = None,
) -> None:
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""
"""Updates the access settings for a persona including public status, user shares,
and group shares.
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
NOTE: This function batches all updates. If we don't dedupe the inputs,
the commit will exception.
NOTE: Callers are responsible for committing."""
if is_public is not None:
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
if persona:
persona.is_public = is_public
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
# and a non-empty list means "replace with these shares".
if user_ids is not None:
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
if user_ids:
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
@@ -41,11 +53,13 @@ def make_persona_private(
).model_dump(),
)
if group_ids:
if group_ids is not None:
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
group_ids_set = set(group_ids)
for group_id in group_ids_set:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)
db_session.commit()

View File

@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal
import requests
import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
def fetch_stripe_checkout_session(tenant_id: str) -> str:
def fetch_stripe_checkout_session(
tenant_id: str,
billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
payload = {
"tenant_id": tenant_id,
"billing_period": billing_period,
}
response = requests.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()["sessionId"]
@@ -72,22 +78,24 @@ def fetch_billing_information(
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
"""
Send a request to the control service to register the number of users for a tenant.
Update the number of seats for a tenant's subscription.
Preserves the existing price (monthly, annual, or grandfathered).
"""
if not STRIPE_PRICE_ID:
raise Exception("STRIPE_PRICE_ID is not set")
response = fetch_tenant_stripe_information(tenant_id)
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
subscription_item = subscription["items"]["data"][0]
# Use existing price to preserve the customer's current plan
current_price_id = subscription_item.price.id
updated_subscription = stripe.Subscription.modify(
stripe_subscription_id,
items=[
{
"id": subscription["items"]["data"][0].id,
"price": STRIPE_PRICE_ID,
"id": subscription_item.id,
"price": current_price_id,
"quantity": number_of_users,
}
],

View File

@@ -10,6 +10,7 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
@@ -104,15 +105,18 @@ async def create_customer_portal_session(
@router.post("/create-subscription-session")
async def create_subscription_session(
request: CreateSubscriptionSessionRequest | None = None,
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
billing_period = request.billing_period if request else "monthly"
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
logger.exception("Failed to create subscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Literal
from pydantic import BaseModel
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
sessionId: str
class CreateSubscriptionSessionRequest(BaseModel):
"""Request to create a subscription checkout session."""
billing_period: Literal["monthly", "annual"] = "monthly"
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int

View File

@@ -105,6 +105,8 @@ class DocExternalAccess:
)
# TODO(andrei): First refactor this into a pydantic model, then get rid of
# duplicate fields.
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin

View File

@@ -12,6 +12,7 @@ from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -53,6 +56,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -116,7 +130,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -131,7 +162,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -144,12 +189,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -157,7 +225,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -172,6 +241,12 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -401,7 +401,10 @@ def handle_stream_message_objects(
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
# Auto-place after the latest message in the chain
parent_message = chat_history[-1] if chat_history else root_message
elif new_msg_req.parent_message_id is None:
elif (
new_msg_req.parent_message_id is None
or new_msg_req.parent_message_id == root_message.id
):
# None = regeneration from root
parent_message = root_message
# Truncate history since we're starting from root

View File

@@ -149,6 +149,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -419,6 +430,9 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -97,10 +97,17 @@ def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
def get_experts_stores_representations(
experts: list[BasicExpertInfo] | None,
) -> list[str] | None:
"""Gets string representations of experts supplied.
If an expert cannot be represented as a string, it is omitted from the
result.
"""
if not experts:
return None
reps = [basic_expert_info_representation(owner) for owner in experts]
reps: list[str | None] = [
basic_expert_info_representation(owner) for owner in experts
]
return [owner for owner in reps if owner is not None]

View File

@@ -566,6 +566,23 @@ def extract_content_words_from_recency_query(
return content_words_filtered[:MAX_CONTENT_WORDS]
def _is_valid_keyword_query(line: str) -> bool:
"""Check if a line looks like a valid keyword query vs explanatory text.
Returns False for lines that appear to be LLM explanations rather than keywords.
"""
# Reject lines that start with parentheses (explanatory notes)
if line.startswith("("):
return False
# Reject lines that are too long (likely sentences, not keywords)
# Keywords should be short - reject if > 50 chars or > 6 words
if len(line) > 50 or len(line.split()) > 6:
return False
return True
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
"""Use LLM to expand query into multiple search variations.
@@ -586,10 +603,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
response_clean = _parse_llm_code_block_response(response)
# Split into lines and filter out empty lines
rephrased_queries = [
raw_queries = [
line.strip() for line in response_clean.split("\n") if line.strip()
]
# Filter out lines that look like explanatory text rather than keywords
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
# Log if we filtered out garbage
if len(raw_queries) != len(rephrased_queries):
filtered_out = set(raw_queries) - set(rephrased_queries)
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
# If no queries generated, use empty query
if not rephrased_queries:
logger.debug("No content keywords extracted from query expansion")

View File

@@ -444,6 +444,8 @@ def upsert_documents(
logger.info("No documents to upsert. Skipping.")
return
includes_permissions = any(doc.external_access for doc in seen_documents.values())
insert_stmt = insert(DbDocument).values(
[
model_to_dict(
@@ -479,21 +481,38 @@ def upsert_documents(
]
)
update_set = {
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
"boost": insert_stmt.excluded.boost,
"hidden": insert_stmt.excluded.hidden,
"semantic_id": insert_stmt.excluded.semantic_id,
"link": insert_stmt.excluded.link,
"primary_owners": insert_stmt.excluded.primary_owners,
"secondary_owners": insert_stmt.excluded.secondary_owners,
"doc_metadata": insert_stmt.excluded.doc_metadata,
}
if includes_permissions:
# Use COALESCE to preserve existing permissions when new values are NULL.
# This prevents subsequent indexing runs (which don't fetch permissions)
# from overwriting permissions set by permission sync jobs.
update_set.update(
{
"external_user_emails": func.coalesce(
insert_stmt.excluded.external_user_emails,
DbDocument.external_user_emails,
),
"external_user_group_ids": func.coalesce(
insert_stmt.excluded.external_user_group_ids,
DbDocument.external_user_group_ids,
),
"is_public": func.coalesce(
insert_stmt.excluded.is_public,
DbDocument.is_public,
),
}
)
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
"boost": insert_stmt.excluded.boost,
"hidden": insert_stmt.excluded.hidden,
"semantic_id": insert_stmt.excluded.semantic_id,
"link": insert_stmt.excluded.link,
"primary_owners": insert_stmt.excluded.primary_owners,
"secondary_owners": insert_stmt.excluded.secondary_owners,
"external_user_emails": insert_stmt.excluded.external_user_emails,
"external_user_group_ids": insert_stmt.excluded.external_user_group_ids,
"is_public": insert_stmt.excluded.is_public,
"doc_metadata": insert_stmt.excluded.doc_metadata,
},
index_elements=["id"], set_=update_set # Conflict target
)
db_session.execute(on_conflict_stmt)
db_session.commit()

View File

@@ -2616,6 +2616,7 @@ class Tool(Base):
__tablename__ = "tool"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
# The name of the tool that the LLM will see
name: Mapped[str] = mapped_column(String, nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=True)
# ID of the tool in the codebase, only applies for in-code tools.

View File

@@ -187,13 +187,25 @@ def _get_persona_by_name(
return result
def make_persona_private(
def update_persona_access(
persona_id: int,
creator_user_id: UUID | None,
user_ids: list[UUID] | None,
group_ids: list[int] | None,
db_session: Session,
is_public: bool | None = None,
user_ids: list[UUID] | None = None,
group_ids: list[int] | None = None,
) -> None:
"""Updates the access settings for a persona including public status and user shares.
NOTE: Callers are responsible for committing."""
if is_public is not None:
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
if persona:
persona.is_public = is_public
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
# and a non-empty list means "replace with these shares".
if user_ids is not None:
db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
@@ -212,11 +224,15 @@ def make_persona_private(
).model_dump(),
)
db_session.commit()
# MIT doesn't support group-based sharing, so we allow clearing (no-op since
# there shouldn't be any) but raise an error if trying to add actual groups.
if group_ids is not None:
db_session.query(Persona__UserGroup).filter(
Persona__UserGroup.persona_id == persona_id
).delete(synchronize_session="fetch")
# May cause error if someone switches down to MIT from EE
if group_ids:
raise NotImplementedError("Onyx MIT does not support private Personas")
if group_ids:
raise NotImplementedError("Onyx MIT does not support group-based sharing")
def create_update_persona(
@@ -282,20 +298,21 @@ def create_update_persona(
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
user_file_ids=converted_user_file_ids,
commit=False,
)
versioned_make_persona_private = fetch_versioned_implementation(
"onyx.db.persona", "make_persona_private"
versioned_update_persona_access = fetch_versioned_implementation(
"onyx.db.persona", "update_persona_access"
)
# Privatize Persona
versioned_make_persona_private(
versioned_update_persona_access(
persona_id=persona.id,
creator_user_id=user.id if user else None,
db_session=db_session,
user_ids=create_persona_request.users,
group_ids=create_persona_request.groups,
db_session=db_session,
)
db_session.commit()
except ValueError as e:
logger.exception("Failed to create persona")
@@ -304,11 +321,13 @@ def create_update_persona(
return FullPersonaSnapshot.from_model(persona)
def update_persona_shared_users(
def update_persona_shared(
persona_id: int,
user_ids: list[UUID],
user: User | None,
db_session: Session,
user_ids: list[UUID] | None = None,
group_ids: list[int] | None = None,
is_public: bool | None = None,
) -> None:
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
@@ -317,22 +336,25 @@ def update_persona_shared_users(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if persona.is_public:
raise HTTPException(status_code=400, detail="Cannot share public persona")
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise HTTPException(
status_code=403, detail="You don't have permission to modify this persona"
)
versioned_make_persona_private = fetch_versioned_implementation(
"onyx.db.persona", "make_persona_private"
versioned_update_persona_access = fetch_versioned_implementation(
"onyx.db.persona", "update_persona_access"
)
# Privatize Persona
versioned_make_persona_private(
versioned_update_persona_access(
persona_id=persona_id,
creator_user_id=user.id if user else None,
user_ids=user_ids,
group_ids=None,
db_session=db_session,
is_public=is_public,
user_ids=user_ids,
group_ids=group_ids,
)
db_session.commit()
def update_persona_public_status(
persona_id: int,

View File

@@ -113,7 +113,6 @@ def upsert_web_search_provider(
if activate:
set_active_web_search_provider(provider_id=provider.id, db_session=db_session)
db_session.commit()
db_session.refresh(provider)
return provider
@@ -269,7 +268,6 @@ def upsert_web_content_provider(
if activate:
set_active_web_content_provider(provider_id=provider.id, db_session=db_session)
db_session.commit()
db_session.refresh(provider)
return provider

View File

@@ -21,7 +21,6 @@ from onyx.configs.constants import MessageType
from onyx.db.tools import get_tool_by_name
from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions
from onyx.deep_research.dr_mock_tools import get_orchestrator_tools
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE
from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT
@@ -220,35 +219,90 @@ def run_deep_research_llm_loop(
else ""
)
if not skip_clarification:
clarification_prompt = CLARIFICATION_PROMPT.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
internal_search_clarification_guidance=internal_search_clarification_guidance,
)
with function_span("clarification_step") as span:
clarification_prompt = CLARIFICATION_PROMPT.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
internal_search_clarification_guidance=internal_search_clarification_guidance,
)
system_prompt = ChatMessageSimple(
message=clarification_prompt,
token_count=300, # Skips the exact token count but has enough leeway
message_type=MessageType.SYSTEM,
)
truncated_message_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
)
llm_step_result, _ = run_llm_step(
emitter=emitter,
history=truncated_message_history,
tool_definitions=get_clarification_tool_definitions(),
tool_choice=ToolChoiceOptions.AUTO,
llm=llm,
placement=Placement(turn_index=0),
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=None,
state_container=state_container,
final_documents=None,
user_identity=user_identity,
is_deep_research=True,
)
if not llm_step_result.tool_calls:
# Mark this turn as a clarification question
state_container.set_is_clarification(True)
span.span_data.output = "clarification_required"
emitter.emit(
Packet(
placement=Placement(turn_index=0),
obj=OverallStop(type="stop"),
)
)
# If a clarification is asked, we need to end this turn and wait on user input
return
#########################################################
# RESEARCH PLAN STEP
#########################################################
with function_span("research_plan_step") as span:
system_prompt = ChatMessageSimple(
message=clarification_prompt,
token_count=300, # Skips the exact token count but has enough leeway
message=RESEARCH_PLAN_PROMPT.format(
current_datetime=get_current_llm_day_time(full_sentence=False)
),
token_count=300,
message_type=MessageType.SYSTEM,
)
reminder_message = ChatMessageSimple(
message=RESEARCH_PLAN_REMINDER,
token_count=100,
message_type=MessageType.USER,
)
truncated_message_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
simple_chat_history=simple_chat_history + [reminder_message],
reminder_message=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
)
llm_step_result, _ = run_llm_step(
emitter=emitter,
research_plan_generator = run_llm_step_pkt_generator(
history=truncated_message_history,
tool_definitions=get_clarification_tool_definitions(),
tool_choice=ToolChoiceOptions.AUTO,
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
placement=Placement(turn_index=0),
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=None,
state_container=state_container,
final_documents=None,
@@ -256,301 +310,177 @@ def run_deep_research_llm_loop(
is_deep_research=True,
)
if not llm_step_result.tool_calls:
# Mark this turn as a clarification question
state_container.set_is_clarification(True)
emitter.emit(
Packet(
placement=Placement(turn_index=0), obj=OverallStop(type="stop")
)
)
# If a clarification is asked, we need to end this turn and wait on user input
return
#########################################################
# RESEARCH PLAN STEP
#########################################################
system_prompt = ChatMessageSimple(
message=RESEARCH_PLAN_PROMPT.format(
current_datetime=get_current_llm_day_time(full_sentence=False)
),
token_count=300,
message_type=MessageType.SYSTEM,
)
reminder_message = ChatMessageSimple(
message=RESEARCH_PLAN_REMINDER,
token_count=100,
message_type=MessageType.USER,
)
truncated_message_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history + [reminder_message],
reminder_message=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
)
research_plan_generator = run_llm_step_pkt_generator(
history=truncated_message_history,
tool_definitions=[],
tool_choice=ToolChoiceOptions.NONE,
llm=llm,
placement=Placement(turn_index=0),
citation_processor=None,
state_container=state_container,
final_documents=None,
user_identity=user_identity,
is_deep_research=True,
)
while True:
try:
packet = next(research_plan_generator)
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
# The LLM response from this prompt is the research plan
if isinstance(packet.obj, AgentResponseStart):
while True:
try:
packet = next(research_plan_generator)
# Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta
# The LLM response from this prompt is the research plan
if isinstance(packet.obj, AgentResponseStart):
emitter.emit(
Packet(
placement=packet.placement,
obj=DeepResearchPlanStart(),
)
)
elif isinstance(packet.obj, AgentResponseDelta):
emitter.emit(
Packet(
placement=packet.placement,
obj=DeepResearchPlanDelta(content=packet.obj.content),
)
)
else:
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, reasoned = e.value
emitter.emit(
Packet(
placement=packet.placement,
obj=DeepResearchPlanStart(),
# Marks the last turn end which should be the plan generation
placement=Placement(
turn_index=1 if reasoned else 0,
),
obj=SectionEnd(),
)
)
elif isinstance(packet.obj, AgentResponseDelta):
emitter.emit(
Packet(
placement=packet.placement,
obj=DeepResearchPlanDelta(content=packet.obj.content),
)
)
else:
# Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, reasoned = e.value
emitter.emit(
Packet(
# Marks the last turn end which should be the plan generation
placement=Placement(
turn_index=1 if reasoned else 0,
),
obj=SectionEnd(),
)
)
if reasoned:
orchestrator_start_turn_index += 1
break
llm_step_result = cast(LlmStepResult, llm_step_result)
if reasoned:
orchestrator_start_turn_index += 1
break
llm_step_result = cast(LlmStepResult, llm_step_result)
research_plan = llm_step_result.answer
research_plan = llm_step_result.answer
span.span_data.output = research_plan if research_plan else None
#########################################################
# RESEARCH EXECUTION STEP
#########################################################
is_reasoning_model = model_is_reasoning_model(
llm.config.model_name, llm.config.model_provider
)
with function_span("research_execution_step") as span:
is_reasoning_model = model_is_reasoning_model(
llm.config.model_name, llm.config.model_provider
)
max_orchestrator_cycles = (
MAX_ORCHESTRATOR_CYCLES
if not is_reasoning_model
else MAX_ORCHESTRATOR_CYCLES_REASONING
)
max_orchestrator_cycles = (
MAX_ORCHESTRATOR_CYCLES
if not is_reasoning_model
else MAX_ORCHESTRATOR_CYCLES_REASONING
)
orchestrator_prompt_template = (
ORCHESTRATOR_PROMPT
if not is_reasoning_model
else ORCHESTRATOR_PROMPT_REASONING
)
orchestrator_prompt_template = (
ORCHESTRATOR_PROMPT
if not is_reasoning_model
else ORCHESTRATOR_PROMPT_REASONING
)
internal_search_research_task_guidance = (
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
if include_internal_search_tunings
else ""
)
token_count_prompt = orchestrator_prompt_template.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=1,
max_cycles=max_orchestrator_cycles,
research_plan=research_plan,
internal_search_research_task_guidance=internal_search_research_task_guidance,
)
orchestration_tokens = token_counter(token_count_prompt)
reasoning_cycles = 0
most_recent_reasoning: str | None = None
citation_mapping: CitationMapping = {}
final_turn_index: int = (
orchestrator_start_turn_index # Track the final turn_index for stop packet
)
for cycle in range(max_orchestrator_cycles):
if cycle == max_orchestrator_cycles - 1:
# If it's the last cycle, forcibly generate the final report
report_turn_index = (
orchestrator_start_turn_index + cycle + reasoning_cycles
)
report_reasoned = generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
turn_index=report_turn_index,
citation_mapping=citation_mapping,
user_identity=user_identity,
)
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
research_agent_calls: list[ToolCallKickoff] = []
orchestrator_prompt = orchestrator_prompt_template.format(
internal_search_research_task_guidance = (
INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE
if include_internal_search_tunings
else ""
)
token_count_prompt = orchestrator_prompt_template.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=cycle,
current_cycle_count=1,
max_cycles=max_orchestrator_cycles,
research_plan=research_plan,
internal_search_research_task_guidance=internal_search_research_task_guidance,
)
orchestration_tokens = token_counter(token_count_prompt)
system_prompt = ChatMessageSimple(
message=orchestrator_prompt,
token_count=orchestration_tokens,
message_type=MessageType.SYSTEM,
reasoning_cycles = 0
most_recent_reasoning: str | None = None
citation_mapping: CitationMapping = {}
final_turn_index: int = (
orchestrator_start_turn_index # Track the final turn_index for stop packet
)
for cycle in range(max_orchestrator_cycles):
if cycle == max_orchestrator_cycles - 1:
# If it's the last cycle, forcibly generate the final report
report_turn_index = (
orchestrator_start_turn_index + cycle + reasoning_cycles
)
report_reasoned = generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
turn_index=report_turn_index,
citation_mapping=citation_mapping,
user_identity=user_identity,
)
# Update final_turn_index: base + 1 for the report itself + 1 if reasoning occurred
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
truncated_message_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
)
research_agent_calls: list[ToolCallKickoff] = []
# Use think tool processor for non-reasoning models to convert
# think_tool calls to reasoning content
custom_processor = (
create_think_tool_token_processor() if not is_reasoning_model else None
)
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
history=truncated_message_history,
tool_definitions=get_orchestrator_tools(
include_think_tool=not is_reasoning_model
),
tool_choice=ToolChoiceOptions.REQUIRED,
llm=llm,
placement=Placement(
turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles
),
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
final_documents=None,
user_identity=user_identity,
custom_token_processor=custom_processor,
is_deep_research=True,
)
if has_reasoned:
reasoning_cycles += 1
tool_calls = llm_step_result.tool_calls or []
if not tool_calls and cycle == 0:
raise RuntimeError(
"Deep Research failed to generate any research tasks for the agents."
orchestrator_prompt = orchestrator_prompt_template.format(
current_datetime=get_current_llm_day_time(full_sentence=False),
current_cycle_count=cycle,
max_cycles=max_orchestrator_cycles,
research_plan=research_plan,
internal_search_research_task_guidance=internal_search_research_task_guidance,
)
if not tool_calls:
# Basically hope that this is an infrequent occurence and hopefully multiple research
# cycles have already ran
logger.warning("No tool calls found, this should not happen.")
report_turn_index = (
orchestrator_start_turn_index + cycle + reasoning_cycles
system_prompt = ChatMessageSimple(
message=orchestrator_prompt,
token_count=orchestration_tokens,
message_type=MessageType.SYSTEM,
)
report_reasoned = generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
truncated_message_history = construct_message_history(
system_prompt=system_prompt,
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
)
# Use think tool processor for non-reasoning models to convert
# think_tool calls to reasoning content
custom_processor = (
create_think_tool_token_processor()
if not is_reasoning_model
else None
)
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
turn_index=report_turn_index,
citation_mapping=citation_mapping,
user_identity=user_identity,
)
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
if special_tool_calls.generate_report_tool_call:
report_turn_index = (
special_tool_calls.generate_report_tool_call.placement.turn_index
)
report_reasoned = generate_final_report(
history=simple_chat_history,
history=truncated_message_history,
tool_definitions=get_orchestrator_tools(
include_think_tool=not is_reasoning_model
),
tool_choice=ToolChoiceOptions.REQUIRED,
llm=llm,
token_counter=token_counter,
placement=Placement(
turn_index=orchestrator_start_turn_index
+ cycle
+ reasoning_cycles
),
# No citations in this step, it should just pass through all
# tokens directly so initialized as an empty citation processor
citation_processor=DynamicCitationProcessor(),
state_container=state_container,
emitter=emitter,
turn_index=report_turn_index,
citation_mapping=citation_mapping,
final_documents=None,
user_identity=user_identity,
saved_reasoning=most_recent_reasoning,
custom_token_processor=custom_processor,
is_deep_research=True,
)
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
elif special_tool_calls.think_tool_call:
think_tool_call = special_tool_calls.think_tool_call
# Only process the THINK_TOOL and skip all other tool calls
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
# we will show it as a separate message.
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
# the LLM step to handle this
with function_span("think_tool") as span:
span.span_data.input = str(think_tool_call.tool_args)
most_recent_reasoning = state_container.reasoning_tokens
tool_call_message = think_tool_call.to_msg_str()
if has_reasoned:
reasoning_cycles += 1
think_tool_msg = ChatMessageSimple(
message=tool_call_message,
token_count=token_counter(tool_call_message),
message_type=MessageType.TOOL_CALL,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
tool_calls = llm_step_result.tool_calls or []
if not tool_calls and cycle == 0:
raise RuntimeError(
"Deep Research failed to generate any research tasks for the agents."
)
simple_chat_history.append(think_tool_msg)
think_tool_response_msg = ChatMessageSimple(
message=THINK_TOOL_RESPONSE_MESSAGE,
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(think_tool_response_msg)
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
continue
else:
for tool_call in tool_calls:
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
logger.warning(f"Unexpected tool call: {tool_call.tool_name}")
continue
research_agent_calls.append(tool_call)
if not research_agent_calls:
logger.warning(
"No research agent tool calls found, this should not happen."
)
if not tool_calls:
# Basically hope that this is an infrequent occurence and hopefully multiple research
# cycles have already ran
logger.warning("No tool calls found, this should not happen.")
report_turn_index = (
orchestrator_start_turn_index + cycle + reasoning_cycles
)
@@ -567,91 +497,177 @@ def run_deep_research_llm_loop(
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
if len(research_agent_calls) > 1:
emitter.emit(
Packet(
placement=Placement(
turn_index=research_agent_calls[0].placement.turn_index
),
obj=TopLevelBranching(
num_parallel_branches=len(research_agent_calls)
),
special_tool_calls = check_special_tool_calls(tool_calls=tool_calls)
if special_tool_calls.generate_report_tool_call:
report_turn_index = (
special_tool_calls.generate_report_tool_call.placement.turn_index
)
report_reasoned = generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
turn_index=report_turn_index,
citation_mapping=citation_mapping,
user_identity=user_identity,
saved_reasoning=most_recent_reasoning,
)
final_turn_index = report_turn_index + (1 if report_reasoned else 0)
break
elif special_tool_calls.think_tool_call:
think_tool_call = special_tool_calls.think_tool_call
# Only process the THINK_TOOL and skip all other tool calls
# This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after
# it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps,
# we will show it as a separate message.
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes
# the LLM step to handle this
with function_span("think_tool") as span:
span.span_data.input = str(think_tool_call.tool_args)
most_recent_reasoning = state_container.reasoning_tokens
tool_call_message = think_tool_call.to_msg_str()
think_tool_msg = ChatMessageSimple(
message=tool_call_message,
token_count=token_counter(tool_call_message),
message_type=MessageType.TOOL_CALL,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
)
simple_chat_history.append(think_tool_msg)
research_results = run_research_agent_calls(
# The tool calls here contain the placement information
research_agent_calls=research_agent_calls,
parent_tool_call_ids=[
tool_call.tool_call_id for tool_call in tool_calls
],
tools=allowed_tools,
emitter=emitter,
state_container=state_container,
llm=llm,
is_reasoning_model=is_reasoning_model,
token_counter=token_counter,
citation_mapping=citation_mapping,
user_identity=user_identity,
)
citation_mapping = research_results.citation_mapping
for tab_index, report in enumerate(
research_results.intermediate_reports
):
if report is None:
# The LLM will not see that this research was even attempted, it may try
# something similar again but this is not bad.
logger.error(
f"Research agent call at tab_index {tab_index} failed, skipping"
think_tool_response_msg = ChatMessageSimple(
message=THINK_TOOL_RESPONSE_MESSAGE,
token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=think_tool_call.tool_call_id,
image_files=None,
)
continue
simple_chat_history.append(think_tool_response_msg)
span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE
continue
else:
for tool_call in tool_calls:
if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME:
logger.warning(
f"Unexpected tool call: {tool_call.tool_name}"
)
continue
current_tool_call = research_agent_calls[tab_index]
tool_call_info = ToolCallInfo(
parent_tool_call_id=None,
turn_index=orchestrator_start_turn_index
+ cycle
+ reasoning_cycles,
tab_index=tab_index,
tool_name=current_tool_call.tool_name,
tool_call_id=current_tool_call.tool_call_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_DB_NAME, db_session=db_session
).id,
reasoning_tokens=llm_step_result.reasoning
or most_recent_reasoning,
tool_call_arguments=current_tool_call.tool_args,
tool_call_response=report,
search_docs=None, # Intermediate docs are not saved/shown
generated_images=None,
research_agent_calls.append(tool_call)
if not research_agent_calls:
logger.warning(
"No research agent tool calls found, this should not happen."
)
report_turn_index = (
orchestrator_start_turn_index + cycle + reasoning_cycles
)
report_reasoned = generate_final_report(
history=simple_chat_history,
llm=llm,
token_counter=token_counter,
state_container=state_container,
emitter=emitter,
turn_index=report_turn_index,
citation_mapping=citation_mapping,
user_identity=user_identity,
)
final_turn_index = report_turn_index + (
1 if report_reasoned else 0
)
break
if len(research_agent_calls) > 1:
emitter.emit(
Packet(
placement=Placement(
turn_index=research_agent_calls[
0
].placement.turn_index
),
obj=TopLevelBranching(
num_parallel_branches=len(research_agent_calls)
),
)
)
research_results = run_research_agent_calls(
# The tool calls here contain the placement information
research_agent_calls=research_agent_calls,
parent_tool_call_ids=[
tool_call.tool_call_id for tool_call in tool_calls
],
tools=allowed_tools,
emitter=emitter,
state_container=state_container,
llm=llm,
is_reasoning_model=is_reasoning_model,
token_counter=token_counter,
citation_mapping=citation_mapping,
user_identity=user_identity,
)
state_container.add_tool_call(tool_call_info)
tool_call_message = current_tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
citation_mapping = research_results.citation_mapping
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=current_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
for tab_index, report in enumerate(
research_results.intermediate_reports
):
if report is None:
# The LLM will not see that this research was even attempted, it may try
# something similar again but this is not bad.
logger.error(
f"Research agent call at tab_index {tab_index} failed, skipping"
)
continue
tool_call_response_msg = ChatMessageSimple(
message=report,
token_count=token_counter(report),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=current_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_response_msg)
current_tool_call = research_agent_calls[tab_index]
tool_call_info = ToolCallInfo(
parent_tool_call_id=None,
turn_index=orchestrator_start_turn_index
+ cycle
+ reasoning_cycles,
tab_index=tab_index,
tool_name=current_tool_call.tool_name,
tool_call_id=current_tool_call.tool_call_id,
tool_id=get_tool_by_name(
tool_name=RESEARCH_AGENT_TOOL_NAME,
db_session=db_session,
).id,
reasoning_tokens=llm_step_result.reasoning
or most_recent_reasoning,
tool_call_arguments=current_tool_call.tool_args,
tool_call_response=report,
search_docs=None, # Intermediate docs are not saved/shown
generated_images=None,
)
state_container.add_tool_call(tool_call_info)
# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
most_recent_reasoning = None
tool_call_message = current_tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=current_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
tool_call_response_msg = ChatMessageSimple(
message=report,
token_count=token_counter(report),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=current_tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_response_msg)
# If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns
most_recent_reasoning = None
emitter.emit(
Packet(

View File

@@ -1,6 +1,6 @@
GENERATE_PLAN_TOOL_NAME = "generate_plan"
RESEARCH_AGENT_DB_NAME = "ResearchAgent"
RESEARCH_AGENT_IN_CODE_ID = "ResearchAgent"
RESEARCH_AGENT_TOOL_NAME = "research_agent"
RESEARCH_AGENT_TASK_KEY = "task"

View File

@@ -3,6 +3,9 @@ import json
import httpx
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
@@ -44,6 +47,7 @@ from onyx.document_index.opensearch.search import (
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import Document
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.model_server_models import Embedding
@@ -58,50 +62,36 @@ def _convert_opensearch_chunk_to_inference_chunk_uncleaned(
blurb=chunk.blurb,
content=chunk.content,
source_links=json.loads(chunk.source_links) if chunk.source_links else None,
image_file_id=chunk.image_file_name,
# TODO(andrei) Yuhong says he doesn't think we need that anymore. Used
# if a section needed to be split into diff chunks. A section is a part
# of a doc that a link will take you to. But don't chunks have their own
# links? Look at this in a followup.
image_file_id=chunk.image_file_id,
# Deprecated. Fill in some reasonable default.
section_continuation=False,
document_id=chunk.document_id,
source_type=DocumentSource(chunk.source_type),
semantic_identifier=chunk.semantic_identifier,
title=chunk.title,
# TODO(andrei): Same comment as in
# _convert_onyx_chunk_to_opensearch_document. Yuhong thinks OpenSearch
# has some thing out of the box for this. Just need to look at it in a
# followup.
boost=1,
# TODO(andrei): Do in a followup.
boost=chunk.global_boost,
# TODO(andrei): Do in a followup. We should be able to get this from
# OpenSearch.
recency_bias=1.0,
# TODO(andrei): This is how good the match is, we need this, key insight
# is we can order chunks by this. Should not be hard to plumb this from
# a search result, do that in a followup.
score=None,
hidden=chunk.hidden,
# TODO(andrei): Don't worry about these for now.
# is_relevant
# relevance_explanation
# metadata
# TODO(andrei): Same comment as in
# _convert_onyx_chunk_to_opensearch_document.
metadata={},
metadata=json.loads(chunk.metadata),
# TODO(andrei): The vector DB needs to supply this. I vaguely know
# OpenSearch can from the documentation I've seen till now, look at this
# in a followup.
match_highlights=[],
# TODO(andrei) This content is not queried on, it is only used to clean
# appended content to chunks. Consider storing a chunk content index
# instead of a full string when working on chunk content augmentation.
doc_summary="",
# TODO(andrei) Consider storing a chunk content index instead of a full
# string when working on chunk content augmentation.
doc_summary=chunk.doc_summary,
# TODO(andrei) Same thing as contx ret above, LLM gens context for each
# chunk.
chunk_context="",
chunk_context=chunk.chunk_context,
updated_at=chunk.last_updated,
# primary_owners TODO(andrei)
# secondary_owners TODO(andrei)
# large_chunk_reference_ids TODO(andrei): Don't worry about this one.
primary_owners=chunk.primary_owners,
secondary_owners=chunk.secondary_owners,
# TODO(andrei): This is the suffix appended to the end of the chunk
# content to assist querying. There are better ways we can do this, for
# ex. keeping an index of where to string split from.
@@ -126,44 +116,31 @@ def _convert_onyx_chunk_to_opensearch_document(
title_vector=chunk.title_embedding,
content=chunk.content,
content_vector=chunk.embeddings.full_embedding,
# TODO(andrei): We should know this. Reason to have this is convenience,
# but it could also change when you change your embedding model, maybe
# we can remove it, Yuhong to look at this. Hardcoded to some nonsense
# value for now.
num_tokens=0,
source_type=chunk.source_document.source.value,
# TODO(andrei): This is just represented a bit differently in
# DocumentBase than how we expect it in the schema currently. Look at
# this closer in a followup. Always defaults to None for now.
# metadata=chunk.source_document.metadata,
metadata=json.dumps(chunk.source_document.metadata),
last_updated=chunk.source_document.doc_updated_at,
# TODO(andrei): Don't currently see an easy way of porting this, and
# besides some connectors genuinely don't have this data. Look at this
# closer in a followup. Always defaults to None for now.
# created_at=None,
public=chunk.access.is_public,
# TODO(andrei): Implement ACL in a followup, currently none of the
# methods in OpenSearchDocumentIndex support it anyway. Always defaults
# to None for now.
# access_control_list=chunk.access.to_acl(),
# TODO(andrei): This doesn't work bc global_boost is float, presumably
# between 0.0 and inf (check this) and chunk.boost is an int from -inf
# to +inf. Look at how the scaling compares between these in a followup.
# Always defaults to 1.0 for now.
# global_boost=chunk.boost,
access_control_list=list(chunk.access.to_acl()),
global_boost=chunk.boost,
semantic_identifier=chunk.source_document.semantic_identifier,
# TODO(andrei): Ask Chris more about this later. Always defaults to None
# for now.
# image_file_name=None,
image_file_id=chunk.image_file_id,
source_links=json.dumps(chunk.source_links) if chunk.source_links else None,
blurb=chunk.blurb,
doc_summary=chunk.doc_summary,
chunk_context=chunk.chunk_context,
document_sets=list(chunk.document_sets) if chunk.document_sets else None,
project_ids=list(chunk.user_project) if chunk.user_project else None,
primary_owners=get_experts_stores_representations(
chunk.source_document.primary_owners
),
secondary_owners=get_experts_stores_representations(
chunk.source_document.secondary_owners
),
# TODO(andrei): Consider not even getting this from
# DocMetadataAwareIndexChunk and instead using OpenSearchDocumentIndex's
# instance variable. One source of truth -> less chance of a very bad
# bug in prod.
tenant_id=chunk.tenant_id,
tenant_id=TenantState(tenant_id=chunk.tenant_id, multitenant=MULTI_TENANT),
)

View File

@@ -4,30 +4,35 @@ from typing import Any
from typing import Self
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_serializer
from pydantic import field_validator
from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
from onyx.document_index.opensearch.constants import EF_SEARCH
from onyx.document_index.opensearch.constants import M
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
TITLE_FIELD_NAME = "title"
TITLE_VECTOR_FIELD_NAME = "title_vector"
CONTENT_FIELD_NAME = "content"
CONTENT_VECTOR_FIELD_NAME = "content_vector"
NUM_TOKENS_FIELD_NAME = "num_tokens"
SOURCE_TYPE_FIELD_NAME = "source_type"
METADATA_FIELD_NAME = "metadata"
LAST_UPDATED_FIELD_NAME = "last_updated"
CREATED_AT_FIELD_NAME = "created_at"
PUBLIC_FIELD_NAME = "public"
ACCESS_CONTROL_LIST_FIELD_NAME = "access_control_list"
HIDDEN_FIELD_NAME = "hidden"
GLOBAL_BOOST_FIELD_NAME = "global_boost"
SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier"
IMAGE_FILE_NAME_FIELD_NAME = "image_file_name"
IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
PROJECT_IDS_FIELD_NAME = "project_ids"
@@ -36,6 +41,10 @@ CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
TENANT_ID_FIELD_NAME = "tenant_id"
BLURB_FIELD_NAME = "blurb"
DOC_SUMMARY_FIELD_NAME = "doc_summary"
CHUNK_CONTEXT_FIELD_NAME = "chunk_context"
PRIMARY_OWNERS_FIELD_NAME = "primary_owners"
SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
def get_opensearch_doc_chunk_id(
@@ -52,12 +61,27 @@ def get_opensearch_doc_chunk_id(
return f"{document_id}__{max_chunk_size}__{chunk_index}"
def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
if value.tzinfo is None:
# astimezone will raise if value does not have a timezone set.
value = value.replace(tzinfo=timezone.utc)
else:
# Does appropriate time conversion if value was set in a different
# timezone.
value = value.astimezone(timezone.utc)
return value
class DocumentChunk(BaseModel):
"""
Represents a chunk of a document in the OpenSearch index.
The names of these fields are based on the OpenSearch schema. Changes to the
schema require changes here. See get_document_schema.
WARNING: Relies on MULTI_TENANT which is global state. Also uses
get_current_tenant_id. Generally relying on global state is bad, in this
case we accept it because of the importance of validating tenant logic.
"""
model_config = {"frozen": True}
@@ -75,41 +99,44 @@ class DocumentChunk(BaseModel):
title_vector: list[float] | None = None
content: str
content_vector: list[float]
# The actual number of tokens in the chunk.
num_tokens: int
source_type: str
# Application logic should store these strings the format key:::value.
metadata: list[str] | None = None
# Contains a string representation of a dict which maps string key to either
# string value or list of string values.
# TODO(andrei): When we augment content with metadata this can just be an
# index pointer, and when we support metadata list that will just be a list
# of strings.
metadata: str
# If it exists, time zone should always be UTC.
last_updated: datetime | None = None
created_at: datetime | None = None
public: bool
access_control_list: list[str] | None = None
access_control_list: list[str]
# Defaults to False, currently gets written during update not index.
hidden: bool = False
global_boost: float = 1.0
global_boost: int
semantic_identifier: str
image_file_name: str | None = None
image_file_id: str | None = None
# Contains a string representation of a dict which maps offset into the raw
# chunk text to the link corresponding to that point.
source_links: str | None = None
blurb: str
doc_summary: str
chunk_context: str
document_sets: list[str] | None = None
# User projects.
project_ids: list[int] | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
tenant_id: str | None = None
@model_validator(mode="after")
def check_num_tokens_fits_within_max_chunk_size(self) -> Self:
if self.num_tokens > self.max_chunk_size:
raise ValueError(
"Bug: Num tokens must be less than or equal to max chunk size."
)
return self
tenant_id: TenantState = Field(
default_factory=lambda: TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
)
@model_validator(mode="after")
def check_title_and_title_vector_are_consistent(self) -> Self:
@@ -120,25 +147,116 @@ class DocumentChunk(BaseModel):
raise ValueError("Bug: Title must not be None if title vector is not None.")
return self
@field_serializer("last_updated", "created_at", mode="plain")
@model_serializer(mode="wrap")
def serialize_model(
self, handler: SerializerFunctionWrapHandler
) -> dict[str, object]:
"""Invokes pydantic's serialization logic, then excludes Nones.
We do this because .model_dump(exclude_none=True) does not work after
@field_serializer logic, so for some field serializers which return None
and which we would like to exclude from the final dump, they would be
included without this.
Args:
handler: Callable from pydantic which takes the instance of the
model as an argument and performs standard serialization.
Returns:
The return of handler but with None items excluded.
"""
serialized: dict[str, object] = handler(self)
serialized_exclude_none = {k: v for k, v in serialized.items() if v is not None}
return serialized_exclude_none
@field_serializer("last_updated", mode="wrap")
def serialize_datetime_fields_to_epoch_millis(
self, value: datetime | None
self, value: datetime | None, handler: SerializerFunctionWrapHandler
) -> int | None:
"""
Serializes datetime fields to milliseconds since the Unix epoch.
If there is no datetime, returns None.
"""
if value is None:
return None
if value.tzinfo is None:
# astimezone will raise if value does not have a timezone set.
value = value.replace(tzinfo=timezone.utc)
else:
# Does appropriate time conversion if value was set in a different
# timezone.
value = value.astimezone(timezone.utc)
value = set_or_convert_timezone_to_utc(value)
# timestamp returns a float in seconds so convert to millis.
return int(value.timestamp() * 1000)
@field_validator("last_updated", mode="before")
@classmethod
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
"""Parses milliseconds since the Unix epoch to a datetime object.
If the input is None, returns None.
The datetime returned will be in UTC.
"""
if value is None:
return None
if isinstance(value, datetime):
value = set_or_convert_timezone_to_utc(value)
return value
if not isinstance(value, int):
raise ValueError(
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
)
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
@field_serializer("tenant_id", mode="wrap")
def serialize_tenant_state(
self, value: TenantState, handler: SerializerFunctionWrapHandler
) -> str | None:
"""
Serializes tenant_state to the tenant str if multitenant, or None if
not.
The idea is that in single tenant mode, the schema does not have a
tenant_id field, so we don't want to supply it in our serialized
DocumentChunk. This assumes the final serialized model excludes None
fields, which serialize_model should enforce.
"""
if not value.multitenant:
return None
else:
return value.tenant_id
@field_validator("tenant_id", mode="before")
@classmethod
def parse_tenant_id(cls, value: Any) -> TenantState:
"""
Generates a TenantState from OpenSearch's tenant_id if it exists, or
generates a default state if it does not (implies we are in single
tenant mode).
"""
if value is None:
if MULTI_TENANT:
raise ValueError(
"Bug: No tenant_id was supplied but multi-tenant mode is enabled."
)
return TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
elif isinstance(value, TenantState):
if MULTI_TENANT != value.multitenant:
raise ValueError(
f"Bug: An existing TenantState object was supplied to the DocumentChunk model but its multi-tenant mode "
f"({value.multitenant}) does not match the program's current global tenancy state."
)
return value
elif not isinstance(value, str):
raise ValueError(
f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead."
)
else:
if not MULTI_TENANT:
raise ValueError(
"Bug: Got a non-null str for the tenant_id property from OpenSearch but multi-tenant mode is not enabled. "
"This is unexpected because in single-tenant mode we don't expect to see a tenant_id."
)
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
class DocumentSchema:
"""
@@ -176,13 +294,19 @@ class DocumentSchema:
OpenSearch client. The structure of this dictionary is
determined by OpenSearch documentation.
"""
schema = {
schema: dict[str, Any] = {
# By default OpenSearch allows dynamically adding new properties
# based on indexed documents. This is awful and we disable it here.
# An exception will be raised if you try to index a new doc which
# contains unexpected fields.
"dynamic": "strict",
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
"fields": {
# Subfield accessed as title.keyword. Not indexed for
# values longer than 256 chars.
# TODO(andrei): Ask Yuhong do we want this?
"keyword": {"type": "keyword", "ignore_above": 256}
},
},
@@ -200,6 +324,8 @@ class DocumentSchema:
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
# TODO(andrei): This is a tensor in Vespa. Also look at feature
# parity for these other method fields.
CONTENT_VECTOR_FIELD_NAME: {
"type": "knn_vector",
"dimension": vector_dimension,
@@ -210,14 +336,10 @@ class DocumentSchema:
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
# See TODO in _convert_onyx_chunk_to_opensearch_document. I
# don't want to actually add this to the schema until we know
# for sure we need it. If we decide we don't I will remove this.
# # Number of tokens in the chunk's content.
# NUM_TOKENS_FIELD_NAME: {"type": "integer", "store": True},
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
# Application logic should store in the format key:::value.
METADATA_FIELD_NAME: {"type": "keyword"},
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
# seconds here not millis.
LAST_UPDATED_FIELD_NAME: {
"type": "date",
"format": "epoch_millis",
@@ -225,16 +347,6 @@ class DocumentSchema:
# would make sense to sort by date.
"doc_values": True,
},
# See TODO in _convert_onyx_chunk_to_opensearch_document. I
# don't want to actually add this to the schema until we know
# for sure we need it. If we decide we don't I will remove this.
# CREATED_AT_FIELD_NAME: {
# "type": "date",
# "format": "epoch_millis",
# # For some reason date defaults to False, even though it
# # would make sense to sort by date.
# "doc_values": True,
# },
# Access control fields.
# Whether the doc is public. Could have fallen under access
# control list but is such a broad and critical filter that it
@@ -247,7 +359,7 @@ class DocumentSchema:
# all other search filters; up to search implementations to
# guarantee this.
HIDDEN_FIELD_NAME: {"type": "boolean"},
GLOBAL_BOOST_FIELD_NAME: {"type": "float"},
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
# This field is only used for displaying a useful name for the
# doc in the UI and is not used for searching. Disabling these
# features to increase perf.
@@ -258,7 +370,7 @@ class DocumentSchema:
"store": False,
},
# Same as above; used to display an image along with the doc.
IMAGE_FILE_NAME_FIELD_NAME: {
IMAGE_FILE_ID_FIELD_NAME: {
"type": "keyword",
"index": False,
"doc_values": False,
@@ -278,15 +390,36 @@ class DocumentSchema:
"doc_values": False,
"store": False,
},
# Same as above.
# TODO(andrei): If we want to search on this this needs to be
# changed.
DOC_SUMMARY_FIELD_NAME: {
"type": "keyword",
"index": False,
"doc_values": False,
"store": False,
},
# Same as above.
# TODO(andrei): If we want to search on this this needs to be
# changed.
CHUNK_CONTEXT_FIELD_NAME: {
"type": "keyword",
"index": False,
"doc_values": False,
"store": False,
},
# Product-specific fields.
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
PROJECT_IDS_FIELD_NAME: {"type": "integer"},
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
# OpenSearch metadata fields.
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
# The maximum number of tokens this chunk's content can hold.
# TODO(andrei): Can we generalize this to embedding type?
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
}
},
}
if multitenant:

View File

@@ -24,7 +24,7 @@ from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
# TODO(andrei): Turn all magic dictionaries to pydantic models.
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using min-max",
"phase_results_processors": [
{
@@ -49,7 +49,7 @@ MIN_MAX_NORMALIZATION_PIPELINE_CONFIG = {
}
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
ZSCORE_NORMALIZATION_PIPELINE_CONFIG = {
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
"description": "Normalization for keyword and vector scores using z-score",
"phase_results_processors": [
{
@@ -140,7 +140,7 @@ class DocumentQuery:
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
]
if tenant_state.tenant_id is not None:
if tenant_state.multitenant:
# TODO(andrei): Fix tenant stuff.
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
@@ -199,7 +199,7 @@ class DocumentQuery:
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
]
if tenant_state.tenant_id is not None:
if tenant_state.multitenant:
filter_clauses.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)
@@ -316,6 +316,7 @@ class DocumentQuery:
{
"multi_match": {
"query": query_text,
# TODO(andrei): Ask Yuhong do we want this?
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
"type": "best_fields",
}
@@ -340,7 +341,7 @@ class DocumentQuery:
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
]
if tenant_state.tenant_id is not None:
if tenant_state.multitenant:
hybrid_search_filters.append(
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
)

View File

@@ -9,7 +9,7 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from unstructured_client.models import operations # type: ignore
from unstructured_client.models import operations
logger = setup_logger()
@@ -55,19 +55,19 @@ def _sdk_partition_request(
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
from unstructured.staging.base import dict_to_elements
from unstructured_client import UnstructuredClient # type: ignore
from unstructured_client import UnstructuredClient
logger.debug(f"Starting to read file: {file_name}")
req = _sdk_partition_request(file, file_name, strategy="fast")
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
response = unstructured_client.general.partition(req)
elements = dict_to_elements(response.elements)
response = unstructured_client.general.partition(request=req)
if response.status_code != 200:
err = f"Received unexpected status code {response.status_code} from Unstructured API."
logger.error(err)
raise ValueError(err)
elements = dict_to_elements(response.elements or [])
return "\n\n".join(str(el) for el in elements)

View File

@@ -40,6 +40,7 @@ class BaseChunk(BaseModel):
source_links: dict[int, str] | None
image_file_id: str | None
# True if this Chunk's start is not at the start of a Section
# TODO(andrei): This is deprecated as of the OpenSearch migration. Remove.
section_continuation: bool

View File

@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
# New output item added
output_item = parsed_chunk.get("item", {})
if output_item.get("type") == "function_call":
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.function_call_arguments.delta":
content_part: Optional[str] = parsed_chunk.get("delta", None)
if content_part:
# Track that we've received tool calls via streaming
self._has_streamed_tool_calls = True
return GenericStreamingChunk(
text="",
tool_use=ChatCompletionToolCallChunk(
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
elif event_type == "response.completed":
# Final event signaling all output items (including parallel tool calls) are done
# Check if we already received tool calls via streaming events
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
# response.completed event so we need to throw it out here or there are duplicate tool calls.
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
response_data = parsed_chunk.get("response", {})
# Determine finish reason based on response content
finish_reason = "stop"
if response_data.get("output"):
for item in response_data["output"]:
if isinstance(item, dict) and item.get("type") == "function_call":
finish_reason = "tool_calls"
break
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason=finish_reason,
usage=None,
output_items = response_data.get("output", [])
# Check if there are function_call items in the output
has_function_calls = any(
isinstance(item, dict) and item.get("type") == "function_call"
for item in output_items
)
if has_function_calls and not has_streamed_tool_calls:
# Azure's Responses API returns all tool calls in response.completed
# without streaming them incrementally. Extract them here.
from litellm.types.utils import (
Delta,
ModelResponseStream,
StreamingChoices,
)
tool_calls = []
for idx, item in enumerate(output_items):
if isinstance(item, dict) and item.get("type") == "function_call":
tool_calls.append(
ChatCompletionToolCallChunk(
id=item.get("call_id"),
index=idx,
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=item.get("name"),
arguments=item.get("arguments", ""),
),
)
)
return ModelResponseStream(
choices=[
StreamingChoices(
index=0,
delta=Delta(tool_calls=tool_calls),
finish_reason="tool_calls",
)
]
)
elif has_function_calls:
# Tool calls were already streamed, just signal completion
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="tool_calls",
usage=None,
)
else:
return GenericStreamingChunk(
text="",
tool_use=None,
is_finished=True,
finish_reason="stop",
usage=None,
)
else:
pass
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
def _patch_azure_responses_should_fake_stream() -> None:
"""
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
not in its database. This causes Azure custom model deployments to buffer the entire
response before yielding, resulting in poor time-to-first-token.
Azure's Responses API supports native streaming, so we override this to always use
real streaming (SyncResponsesAPIStreamingIterator).
"""
from litellm.llms.azure.responses.transformation import (
AzureOpenAIResponsesAPIConfig,
)
if (
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
== "_patched_should_fake_stream"
):
return
def _patched_should_fake_stream(
self: Any,
model: Optional[str],
stream: Optional[bool],
custom_llm_provider: Optional[str] = None,
) -> bool:
# Azure Responses API supports native streaming - never fake it
return False
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
def apply_monkey_patches() -> None:
"""
Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
"""
_patch_ollama_transform_request()
_patch_ollama_chunk_parser()
_patch_openai_responses_chunk_parser()
_patch_openai_responses_transform_response()
_patch_azure_responses_should_fake_stream()
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

View File

@@ -63,7 +63,7 @@ def process_with_prompt_cache(
return suffix, None
# Get provider adapter
provider_adapter = get_provider_adapter(llm_config.model_provider)
provider_adapter = get_provider_adapter(llm_config)
# If provider doesn't support caching, combine and return unchanged
if not provider_adapter.supports_caching():

View File

@@ -1,14 +1,17 @@
"""Factory for creating provider-specific prompt cache adapters."""
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLMConfig
from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider
from onyx.llm.prompt_cache.providers.base import PromptCacheProvider
from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
ANTHROPIC_BEDROCK_TAG = "anthropic."
def get_provider_adapter(provider: str) -> PromptCacheProvider:
def get_provider_adapter(llm_config: LLMConfig) -> PromptCacheProvider:
"""Get the appropriate prompt cache provider adapter for a given provider.
Args:
@@ -17,11 +20,14 @@ def get_provider_adapter(provider: str) -> PromptCacheProvider:
Returns:
PromptCacheProvider instance for the given provider
"""
if provider == LlmProviderNames.OPENAI:
if llm_config.model_provider == LlmProviderNames.OPENAI:
return OpenAIPromptCacheProvider()
elif provider in [LlmProviderNames.ANTHROPIC, LlmProviderNames.BEDROCK]:
elif llm_config.model_provider == LlmProviderNames.ANTHROPIC or (
llm_config.model_provider == LlmProviderNames.BEDROCK
and ANTHROPIC_BEDROCK_TAG in llm_config.model_name
):
return AnthropicPromptCacheProvider()
elif provider == LlmProviderNames.VERTEX_AI:
elif llm_config.model_provider == LlmProviderNames.VERTEX_AI:
return VertexAIPromptCacheProvider()
else:
# Default to no-op for providers without caching support

View File

@@ -1,30 +1,39 @@
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS
SLACK_QUERY_EXPANSION_PROMPT = f"""
Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
keyword-only queries, so that Slack's keyword search yields the best matches.
Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.
Keep in mind the Slack's search behavior:
- Pure keyword AND search (no semantics).
- Word order matters.
- More words = fewer matches, so keep each query concise.
- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.
Slack search behavior:
- Pure keyword AND search (no semantics)
- More words = fewer matches, so keep queries concise (1-3 words)
Critical: Extract ONLY keywords that would actually appear in Slack message content.
ALWAYS include:
- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
- Project/product names, technical terms, proper nouns
- Actual content words: "performance", "bug", "deployment", "API", "error"
DO NOT include:
- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
- Channels/Users: "general", "eng-general", "engineering", "@username"
DO include:
- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"
- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
- Temporal: "today", "yesterday", "week", "month", "recent", "last"
- Channel names: "general", "eng-general", "random"
Examples:
Query: "what are the big topics in eng-general this week?"
Output:
Query: "messages with Sarah about the deployment"
Output:
Sarah deployment
Sarah
deployment
Query: "what did Mike say about the budget?"
Output:
Mike budget
Mike
budget
Query: "performance issues in eng-general"
Output:
performance issues
@@ -41,7 +50,7 @@ Now process this query:
{{query}}
Output:
Output (keywords only, one per line, NO explanations or commentary):
"""
SLACK_DATE_EXTRACTION_PROMPT = """

View File

@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -697,7 +697,7 @@ def save_user_credentials(
# TODO: fix and/or type correctly w/base model
config_data = MCPConnectionData(
headers=auth_template.config.get("headers", {}),
header_substitutions=auth_template.config.get(HEADER_SUBSTITUTIONS, {}),
header_substitutions=request.credentials,
)
for oauth_field_key in MCPOAuthKeys:
field_key: Literal["client_info", "tokens", "metadata"] = (

View File

@@ -34,7 +34,7 @@ from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_persona_is_default
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_shared
from onyx.db.persona import update_persona_visibility
from onyx.db.persona import update_personas_display_priority
from onyx.file_store.file_store import get_default_file_store
@@ -366,7 +366,9 @@ def delete_label(
class PersonaShareRequest(BaseModel):
user_ids: list[UUID]
user_ids: list[UUID] | None = None
group_ids: list[int] | None = None
is_public: bool | None = None
# We notify each user when a user is shared with them
@@ -377,11 +379,13 @@ def share_persona(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared_users(
update_persona_shared(
persona_id=persona_id,
user_ids=persona_share_request.user_ids,
user=user,
db_session=db_session,
user_ids=persona_share_request.user_ids,
group_ids=persona_share_request.group_ids,
is_public=persona_share_request.is_public,
)

View File

@@ -87,6 +87,11 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
entries = [
entry for entry in all_entries if is_version_gte(entry.version, __version__)
]
elif "nightly" in __version__:
# Just show the latest entry for nightly versions
entries = sorted(
all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
)[:1]
else:
# If not recognized version
# likely `development` and we should show all entries

View File

@@ -410,26 +410,20 @@ def list_llm_provider_basics(
all_providers = fetch_existing_llm_providers(db_session)
user_group_ids = fetch_user_group_ids(db_session, user) if user else set()
is_admin = user and user.role == UserRole.ADMIN
is_admin = user is not None and user.role == UserRole.ADMIN
accessible_providers = []
for provider in all_providers:
# Include all public providers
if provider.is_public:
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
continue
# Include restricted providers user has access to via groups
if is_admin:
# Admins see all providers
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif provider.groups:
# User must be in at least one of the provider's groups
if user_group_ids.intersection({g.id for g in provider.groups}):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
elif not provider.personas:
# No restrictions = accessible
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes all public providers
# - Includes providers user can access via group membership
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
):
accessible_providers.append(LLMProviderDescriptor.from_model(provider))
end_time = datetime.now(timezone.utc)

View File

@@ -4,10 +4,13 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import InternetContentProvider
from onyx.db.models import InternetSearchProvider
from onyx.db.models import User
from onyx.db.web_search import deactivate_web_content_provider
from onyx.db.web_search import deactivate_web_search_provider
@@ -94,6 +97,28 @@ def upsert_search_provider_endpoint(
db_session=db_session,
)
# Sync Exa key of search engine to content provider
if (
request.provider_type == WebSearchProviderType.EXA
and request.api_key_changed
and request.api_key
):
stmt = (
insert(InternetContentProvider)
.values(
name="Exa",
provider_type=WebContentProviderType.EXA.value,
api_key=request.api_key,
is_active=False,
)
.on_conflict_do_update(
index_elements=["name"],
set_={"api_key": request.api_key},
)
)
db_session.execute(stmt)
db_session.flush()
db_session.commit()
return WebSearchProviderView(
id=provider.id,
@@ -245,6 +270,28 @@ def upsert_content_provider_endpoint(
db_session=db_session,
)
# Sync Exa key of content provider to search provider
if (
request.provider_type == WebContentProviderType.EXA
and request.api_key_changed
and request.api_key
):
stmt = (
insert(InternetSearchProvider)
.values(
name="Exa",
provider_type=WebSearchProviderType.EXA.value,
api_key=request.api_key,
is_active=False,
)
.on_conflict_do_update(
index_elements=["name"],
set_={"api_key": request.api_key},
)
)
db_session.execute(stmt)
db_session.flush()
db_session.commit()
return WebContentProviderView(
id=provider.id,

View File

@@ -11,7 +11,7 @@ from onyx.db.chat import get_db_search_doc_by_id
from onyx.db.chat import translate_db_search_doc_to_saved_search_doc
from onyx.db.models import ChatMessage
from onyx.db.tools import get_tool_by_id
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_IN_CODE_ID
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TASK_KEY
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
@@ -401,7 +401,7 @@ def translate_assistant_message_to_packets(
# Here we do a try because some tools may get deleted before the session is reloaded.
try:
tool = get_tool_by_id(tool_call.tool_id, db_session)
if tool.in_code_tool_id == RESEARCH_AGENT_DB_NAME:
if tool.in_code_tool_id == RESEARCH_AGENT_IN_CODE_ID:
research_agent_count += 1
# Handle different tool types
@@ -457,7 +457,7 @@ def translate_assistant_message_to_packets(
)
)
elif tool.in_code_tool_id == RESEARCH_AGENT_DB_NAME:
elif tool.in_code_tool_id == RESEARCH_AGENT_IN_CODE_ID:
# Not ideal but not a huge issue if the research task is lost.
research_task = cast(
str,

View File

@@ -1,3 +1,4 @@
import re
from collections.abc import Sequence
from exa_py import Exa
@@ -19,6 +20,21 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
def _extract_site_operators(query: str) -> tuple[str, list[str]]:
"""Extract site: operators and return cleaned query + full domains.
Returns (cleaned_query, full_domains) where full_domains contains the full
values after site: (e.g., ["reddit.com/r/leagueoflegends"]).
"""
full_domains = re.findall(r"site:\s*([^\s]+)", query, re.IGNORECASE)
cleaned_query = re.sub(r"site:\s*\S+\s*", "", query, flags=re.IGNORECASE).strip()
if not cleaned_query and full_domains:
cleaned_query = full_domains[0]
return cleaned_query, full_domains
class ExaClient(WebSearchProvider, WebContentProvider):
def __init__(self, api_key: str, num_results: int = 10) -> None:
self.exa = Exa(api_key=api_key)
@@ -28,8 +44,9 @@ class ExaClient(WebSearchProvider, WebContentProvider):
def supports_site_filter(self) -> bool:
return False
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
def _search_exa(
self, query: str, include_domains: list[str] | None = None
) -> list[WebSearchResult]:
response = self.exa.search_and_contents(
query,
type="auto",
@@ -38,6 +55,7 @@ class ExaClient(WebSearchProvider, WebContentProvider):
highlights_per_url=1,
),
num_results=self._num_results,
include_domains=include_domains,
)
results: list[WebSearchResult] = []
@@ -60,6 +78,21 @@ class ExaClient(WebSearchProvider, WebContentProvider):
return results
@retry_builder(tries=3, delay=1, backoff=2)
def search(self, query: str) -> list[WebSearchResult]:
cleaned_query, full_domains = _extract_site_operators(query)
if full_domains:
# Try with include_domains using base domains (e.g., ["reddit.com"])
base_domains = [d.split("/")[0].removeprefix("www.") for d in full_domains]
results = self._search_exa(cleaned_query, include_domains=base_domains)
if results:
return results
# Fallback: add full domains as keywords
query_with_domains = f"{cleaned_query} {' '.join(full_domains)}".strip()
return self._search_exa(query_with_domains)
def test_connection(self) -> dict[str, str]:
try:
test_results = self.search("test")
@@ -113,6 +146,7 @@ class ExaClient(WebSearchProvider, WebContentProvider):
if result.published_date
else None
),
scrape_successful=bool(full_content),
)
)

View File

@@ -98,6 +98,9 @@ def build_content_provider_from_config(
timeout_seconds=config.timeout_seconds,
)
if provider_type == WebContentProviderType.EXA:
return ExaClient(api_key=api_key)
def get_default_provider() -> WebSearchProvider | None:
with get_session_with_current_tenant() as db_session:

View File

@@ -265,13 +265,22 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
)
# Format for LLM
docs_str, citation_mapping = convert_inference_sections_to_llm_string(
top_sections=inference_sections,
citation_start=override_kwargs.starting_citation_num,
limit=None, # Already truncated
include_source_type=False,
include_link=True,
)
if not all_search_results:
docs_str = json.dumps(
{
"results": [],
"message": "The web search completed but returned no results for any of the queries. Do not search again.",
}
)
citation_mapping: dict[int, str] = {}
else:
docs_str, citation_mapping = convert_inference_sections_to_llm_string(
top_sections=inference_sections,
citation_start=override_kwargs.starting_citation_num,
limit=None, # Already truncated
include_source_type=False,
include_link=True,
)
return ToolResponse(
rich_response=SearchDocsResponse(

View File

@@ -1,7 +1,7 @@
[project]
name = "onyx-backend"
version = "0.0.0"
requires-python = ">=3.11,<3.13"
requires-python = ">=3.11"
dependencies = [
"onyx[backend,dev,ee]",
]

View File

@@ -5,7 +5,9 @@ aioboto3==15.1.0
aiobotocore==2.24.0
# via aioboto3
aiofiles==25.1.0
# via aioboto3
# via
# aioboto3
# unstructured-client
aiohappyeyeballs==2.6.1
# via aiohttp
aiohttp==3.13.3
@@ -115,7 +117,6 @@ certifi==2025.11.12
# requests
# sentry-sdk
# trafilatura
# unstructured-client
cffi==2.0.0
# via
# argon2-cffi-bindings
@@ -123,9 +124,7 @@ cffi==2.0.0
# pynacl
# zstandard
chardet==5.2.0
# via
# onyx
# unstructured
# via onyx
charset-normalizer==3.4.4
# via
# htmldate
@@ -133,7 +132,7 @@ charset-normalizer==3.4.4
# pdfminer-six
# requests
# trafilatura
# unstructured-client
# unstructured
chevron==0.14.0
# via braintrust
chonkie==1.0.10
@@ -149,6 +148,7 @@ click==8.3.1
# litellm
# magika
# nltk
# python-oxmsg
# typer
# uvicorn
# zulip
@@ -185,6 +185,7 @@ cryptography==46.0.3
# pyjwt
# secretstorage
# sendgrid
# unstructured-client
cyclopts==4.2.4
# via fastmcp
dask==2023.8.1
@@ -192,17 +193,13 @@ dask==2023.8.1
# distributed
# onyx
dataclasses-json==0.6.7
# via
# unstructured
# unstructured-client
# via unstructured
dateparser==1.2.2
# via htmldate
ddtrace==3.10.0
# via onyx
decorator==5.2.1
# via retry
deepdiff==8.6.1
# via unstructured-client
defusedxml==0.7.1
# via
# jira
@@ -354,7 +351,7 @@ greenlet==3.2.4
# sqlalchemy
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1
grpcio==1.67.1 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-resource-manager
@@ -362,7 +359,17 @@ grpcio==1.67.1
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
# via google-api-core
h11==0.16.0
# via
@@ -374,12 +381,15 @@ hf-xet==1.2.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or
# via huggingface-hub
hpack==4.1.0
# via h2
html5lib==1.1
# via unstructured
htmldate==1.9.1
# via trafilatura
httpcore==1.0.9
# via
# httpx
# onyx
# unstructured-client
httplib2==0.31.0
# via
# google-api-python-client
@@ -420,7 +430,6 @@ idna==3.11
# email-validator
# httpx
# requests
# unstructured-client
# yarl
importlib-metadata==8.7.0
# via
@@ -466,8 +475,6 @@ joblib==1.5.2
# via nltk
jsonpatch==1.33
# via langchain-core
jsonpath-python==1.0.6
# via unstructured-client
jsonpointer==3.0.0
# via jsonpatch
jsonref==1.1.0
@@ -509,6 +516,8 @@ langsmith==0.3.45
# langchain-core
lazy-imports==1.0.1
# via onyx
legacy-cgi==2.6.4 ; python_full_version >= '3.13'
# via ddtrace
litellm==1.80.11
# via onyx
locket==1.0.0
@@ -555,9 +564,7 @@ markupsafe==3.0.3
# mako
# werkzeug
marshmallow==3.26.2
# via
# dataclasses-json
# unstructured-client
# via dataclasses-json
matrix-client==0.3.2
# via zulip
mcp==1.25.0
@@ -598,16 +605,13 @@ mypy-extensions==1.0.0
# via
# mypy
# typing-inspect
# unstructured-client
nest-asyncio==1.6.0
# via
# onyx
# unstructured-client
# via onyx
nltk==3.9.1
# via
# onyx
# unstructured
numpy==1.26.4
numpy==2.4.1
# via
# magika
# onnxruntime
@@ -623,7 +627,9 @@ oauthlib==3.2.2
office365-rest-python-client==2.5.9
# via onyx
olefile==0.47
# via msoffcrypto-tool
# via
# msoffcrypto-tool
# python-oxmsg
onnxruntime==1.20.1
# via magika
openai==2.14.0
@@ -678,8 +684,6 @@ opentelemetry-semantic-conventions==0.60b1
# via
# opentelemetry-instrumentation
# opentelemetry-sdk
orderly-set==5.5.0
# via deepdiff
orjson==3.11.4 ; platform_python_implementation != 'PyPy'
# via langsmith
packaging==24.2
@@ -700,7 +704,6 @@ packaging==24.2
# opentelemetry-instrumentation
# pytest
# pywikibot
# unstructured-client
pandas==2.2.3
# via markitdown
parameterized==0.9.0
@@ -748,7 +751,19 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# ddtrace
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# onnxruntime
# opentelemetry-proto
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
# via
# ddtrace
# google-api-core
@@ -810,6 +825,7 @@ pydantic==2.11.7
# openapi-pydantic
# pyairtable
# pydantic-settings
# unstructured-client
pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
@@ -835,7 +851,7 @@ pynacl==1.6.2
# via pygithub
pyparsing==3.2.5
# via httplib2
pypdf==6.1.3
pypdf==6.6.0
# via
# onyx
# unstructured-client
@@ -867,7 +883,6 @@ python-dateutil==2.8.2
# onyx
# opensearch-py
# pandas
# unstructured-client
python-docx==1.1.2
# via onyx
python-dotenv==1.1.1
@@ -894,6 +909,8 @@ python-multipart==0.0.20
# fastapi-users
# mcp
# onyx
python-oxmsg==0.0.2
# via unstructured
python-pptx==0.6.23
# via
# markitdown
@@ -985,7 +1002,6 @@ requests==2.32.5
# stripe
# tiktoken
# unstructured
# unstructured-client
# voyageai
# zeep
# zulip
@@ -1045,12 +1061,12 @@ six==1.17.0
# atlassian-python-api
# dropbox
# google-auth-httplib2
# html5lib
# hubspot-api-client
# langdetect
# markdownify
# python-dateutil
# stone
# unstructured-client
slack-sdk==3.20.2
# via onyx
smmap==5.0.2
@@ -1089,8 +1105,6 @@ supervisor==4.3.0
# via onyx
sympy==1.13.1
# via onnxruntime
tabulate==0.9.0
# via unstructured
tblib==3.2.2
# via distributed
tenacity==9.1.2
@@ -1158,6 +1172,7 @@ typing-extensions==4.15.0
# fastapi
# google-cloud-aiplatform
# google-genai
# grpcio
# huggingface-hub
# jira
# langchain-core
@@ -1178,6 +1193,7 @@ typing-extensions==4.15.0
# pyee
# pygithub
# python-docx
# python-oxmsg
# referencing
# simple-salesforce
# sqlalchemy
@@ -1187,12 +1203,9 @@ typing-extensions==4.15.0
# typing-inspect
# typing-inspection
# unstructured
# unstructured-client
# zulip
typing-inspect==0.9.0
# via
# dataclasses-json
# unstructured-client
# via dataclasses-json
typing-inspection==0.4.2
# via
# mcp
@@ -1205,9 +1218,9 @@ tzdata==2025.2
# tzlocal
tzlocal==5.3.1
# via dateparser
unstructured==0.15.1
unstructured==0.18.27
# via onyx
unstructured-client==0.25.4
unstructured-client==0.42.6
# via
# onyx
# unstructured
@@ -1229,7 +1242,6 @@ urllib3==2.6.3
# sentry-sdk
# trafilatura
# types-requests
# unstructured-client
uvicorn==0.35.0
# via
# fastmcp
@@ -1244,6 +1256,8 @@ voyageai==0.2.3
# via onyx
wcwidth==0.2.14
# via prompt-toolkit
webencodings==0.5.1
# via html5lib
websockets==15.0.1
# via
# fastmcp

View File

@@ -175,7 +175,7 @@ greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or
# via sqlalchemy
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1
grpcio==1.67.1 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-resource-manager
@@ -183,7 +183,17 @@ grpcio==1.67.1
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
# via google-api-core
h11==0.16.0
# via
@@ -278,14 +288,14 @@ nest-asyncio==1.6.0
# via ipykernel
nodeenv==1.9.1
# via pre-commit
numpy==1.26.4
numpy==2.4.1
# via
# contourpy
# matplotlib
# pandas-stubs
# shapely
# voyageai
onyx-devtools==0.2.0
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via
@@ -347,7 +357,16 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-aiplatform
@@ -546,6 +565,7 @@ typing-extensions==4.15.0
# fastapi
# google-cloud-aiplatform
# google-genai
# grpcio
# huggingface-hub
# ipython
# mypy

View File

@@ -132,7 +132,7 @@ googleapis-common-protos==1.72.0
# grpcio-status
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1
grpcio==1.67.1 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-resource-manager
@@ -140,7 +140,17 @@ grpcio==1.67.1
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
# via google-api-core
h11==0.16.0
# via
@@ -192,7 +202,7 @@ multidict==6.7.0
# aiobotocore
# aiohttp
# yarl
numpy==1.26.4
numpy==2.4.1
# via
# shapely
# voyageai
@@ -224,7 +234,16 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-aiplatform
@@ -329,6 +348,7 @@ typing-extensions==4.15.0
# fastapi
# google-cloud-aiplatform
# google-genai
# grpcio
# huggingface-hub
# openai
# pydantic

View File

@@ -157,7 +157,7 @@ googleapis-common-protos==1.72.0
# grpcio-status
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1
grpcio==1.67.1 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-resource-manager
@@ -165,7 +165,17 @@ grpcio==1.67.1
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# litellm
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
# via google-api-core
h11==0.16.0
# via
@@ -229,7 +239,7 @@ multidict==6.7.0
# yarl
networkx==3.5
# via torch
numpy==1.26.4
numpy==2.4.1
# via
# accelerate
# onyx
@@ -306,7 +316,16 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-aiplatform
@@ -450,6 +469,7 @@ typing-extensions==4.15.0
# fastapi
# google-cloud-aiplatform
# google-genai
# grpcio
# huggingface-hub
# openai
# pydantic

View File

@@ -31,3 +31,4 @@ class WebSearchProviderType(str, Enum):
class WebContentProviderType(str, Enum):
ONYX_WEB_CRAWLER = "onyx_web_crawler"
FIRECRAWL = "firecrawl"
EXA = "exa"

View File

@@ -0,0 +1,281 @@
"""
External dependency unit tests for user file processing queue protections.
Verifies that the three mechanisms added to check_user_file_processing work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in PROCESSING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that process_single_user_file clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_user_file_processing,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file,
)
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
"""Insert a UserFile in PROCESSING status and return it."""
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.PROCESSING,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestQueueDepthBackpressure:
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
def test_no_tasks_enqueued_when_queue_over_limit(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When the queue depth exceeds the limit the beat cycle is skipped."""
user = create_test_user(db_session, "bp_user")
_create_processing_user_file(db_session, user.id)
mock_app = MagicMock()
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
mock_app.send_task.assert_not_called()
class TestPerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "guard_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "guard_set_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
f"Guard key TTL {ttl}s is outside the expected range "
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
)
finally:
redis_client.delete(guard_key)
class TestTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "expires_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires")
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), (
"Task must be submitted with the correct expires value to prevent "
"stale task accumulation"
)
finally:
redis_client.delete(guard_key)
class TestWorkerClearsGuardKey:
"""process_single_user_file removes the guard key when it picks up a task."""
def test_guard_key_deleted_on_pickup(
self,
tenant_context: None, # noqa: ARG002
) -> None:
"""The guard key is deleted before the worker does any real work.
We simulate an already-locked file so process_single_user_file returns
early but crucially, after the guard key deletion.
"""
user_file_id = str(uuid4())
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(user_file_id)
# Simulate the guard key set when the beat enqueued the task
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
# Hold the per-file processing lock so the worker exits early without
# touching the database or file store.
lock_key = _user_file_lock_key(user_file_id)
processing_lock = redis_client.lock(lock_key, timeout=10)
acquired = processing_lock.acquire(blocking=False)
assert acquired, "Should be able to acquire the processing lock for this test"
try:
process_single_user_file.run(
user_file_id=user_file_id,
tenant_id=TEST_TENANT_ID,
)
finally:
if processing_lock.owned():
processing_lock.release()
assert not redis_client.exists(
guard_key
), "Guard key should be deleted when the worker picks up the task"

View File

@@ -4,9 +4,11 @@ These tests assume OpenSearch is running and test all implemented methods
using real schemas, pipelines, and search queries from the codebase.
"""
import json
import uuid
from collections.abc import Generator
from typing import Any
from datetime import datetime
from datetime import timezone
import pytest
@@ -21,18 +23,31 @@ from onyx.document_index.opensearch.search import (
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
from onyx.document_index.opensearch.search import MIN_MAX_NORMALIZATION_PIPELINE_NAME
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
def _patch_global_tenant_state(monkeypatch: pytest.MonkeyPatch, state: bool) -> None:
"""Patches MULTI_TENANT wherever necessary for this test file.
Args:
monkeypatch: The test instance's monkeypatch instance, used for
patching.
state: The intended state of MULTI_TENANT.
"""
monkeypatch.setattr("shared_configs.configs.MULTI_TENANT", state)
monkeypatch.setattr("onyx.document_index.opensearch.schema.MULTI_TENANT", state)
def _create_test_document_chunk(
document_id: str = "test-doc-1",
chunk_index: int = 0,
content: str = "Test content",
document_id: str,
chunk_index: int,
content: str,
tenant_state: TenantState,
content_vector: list[float] | None = None,
title: str | None = None,
title_vector: list[float] | None = None,
public: bool = True,
hidden: bool = False,
**kwargs: Any,
) -> DocumentChunk:
if content_vector is None:
# Generate dummy vector - 128 dimensions for fast testing.
@@ -42,31 +57,51 @@ def _create_test_document_chunk(
if title is not None and title_vector is None:
title_vector = [0.2] * 128
now = datetime.now(timezone.utc)
# We only store millisecond precision, so to make sure asserts work in this
# test file manually lose some precision from datetime.now().
now = now.replace(microsecond=(now.microsecond // 1000) * 1000)
return DocumentChunk(
document_id=document_id,
chunk_index=chunk_index,
content=content,
content_vector=content_vector,
title=title,
title_vector=title_vector,
# This is not how tokenization necessarily works, this is just for quick
# testing.
num_tokens=len(content.split()),
content=content,
content_vector=content_vector,
source_type="test_source",
metadata=json.dumps({}),
last_updated=now,
public=public,
access_control_list=[],
hidden=hidden,
**kwargs,
global_boost=0,
semantic_identifier="Test semantic identifier",
image_file_id=None,
source_links=None,
blurb="Test blurb",
doc_summary="Test doc summary",
chunk_context="Test chunk context",
document_sets=None,
project_ids=None,
primary_owners=None,
secondary_owners=None,
tenant_id=tenant_state,
)
def _generate_test_vector(base_value: float = 0.1, dimension: int = 128) -> list[float]:
"""Generate a test vector with slight variations."""
return [base_value + (i * 0.001) for i in range(dimension)]
"""Generates a test vector with slight variations.
We round to eliminate floating point precision errors when comparing chunks
for equality.
"""
return [round(base_value + (i * 0.001), 5) for i in range(dimension)]
@pytest.fixture(scope="module")
def opensearch_available() -> None:
"""Verify OpenSearch is running, skip all tests if not."""
"""Verifies OpenSearch is running, skips all tests if not."""
client = OpenSearchClient(index_name="test_ping")
try:
if not client.ping():
@@ -228,11 +263,15 @@ class TestOpenSearchClient:
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME
)
def test_index_document(self, test_client: OpenSearchClient) -> None:
def test_index_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a document."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -241,17 +280,22 @@ class TestOpenSearchClient:
document_id="test-doc-1",
chunk_index=0,
content="Test content for indexing",
tenant_state=tenant_state,
)
# Under test and postcondition.
# Should not raise.
test_client.index_document(document=doc)
def test_index_duplicate_document(self, test_client: OpenSearchClient) -> None:
def test_index_duplicate_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests indexing a duplicate document raises an error."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -260,6 +304,7 @@ class TestOpenSearchClient:
document_id="test-doc-duplicate",
chunk_index=0,
content="Duplicate test",
tenant_state=tenant_state,
)
# Index once - should succeed.
@@ -270,11 +315,15 @@ class TestOpenSearchClient:
with pytest.raises(Exception, match="already exists"):
test_client.index_document(document=doc)
def test_get_document(self, test_client: OpenSearchClient) -> None:
def test_get_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a document."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -283,6 +332,7 @@ class TestOpenSearchClient:
document_id="test-doc-get",
chunk_index=0,
content="Content to retrieve",
tenant_state=tenant_state,
)
test_client.index_document(document=original_doc)
@@ -297,11 +347,14 @@ class TestOpenSearchClient:
# Postcondition.
assert retrieved_doc == original_doc
def test_get_nonexistent_document(self, test_client: OpenSearchClient) -> None:
def test_get_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests getting a nonexistent document raises an error."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=False
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -312,11 +365,15 @@ class TestOpenSearchClient:
document_chunk_id="test_source__nonexistent__512__0"
)
def test_delete_existing_document(self, test_client: OpenSearchClient) -> None:
def test_delete_existing_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting an existing document returns True."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -325,6 +382,7 @@ class TestOpenSearchClient:
document_id="test-doc-delete",
chunk_index=0,
content="Content to delete",
tenant_state=tenant_state,
)
test_client.index_document(document=doc)
@@ -342,11 +400,15 @@ class TestOpenSearchClient:
with pytest.raises(Exception, match="404"):
test_client.get_document(document_chunk_id=doc_chunk_id)
def test_delete_nonexistent_document(self, test_client: OpenSearchClient) -> None:
def test_delete_nonexistent_document(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting a nonexistent document returns False."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -359,11 +421,15 @@ class TestOpenSearchClient:
# Postcondition.
assert result is False
def test_delete_by_query(self, test_client: OpenSearchClient) -> None:
def test_delete_by_query(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests deleting documents by query."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -374,7 +440,7 @@ class TestOpenSearchClient:
document_id="delete-me",
chunk_index=i,
content=f"Delete this {i}",
tenant_id="tenant-x",
tenant_state=tenant_state,
)
for i in range(3)
]
@@ -383,7 +449,7 @@ class TestOpenSearchClient:
document_id="keep-me",
chunk_index=0,
content="Keep this",
tenant_id="tenant-x",
tenant_state=tenant_state,
)
]
@@ -393,7 +459,7 @@ class TestOpenSearchClient:
query_body = DocumentQuery.delete_from_document_id_query(
document_id="delete-me",
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
tenant_state=tenant_state,
)
# Under test.
@@ -406,7 +472,7 @@ class TestOpenSearchClient:
test_client.refresh_index()
search_query = DocumentQuery.get_from_document_id_query(
document_id="delete-me",
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
tenant_state=tenant_state,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -418,7 +484,7 @@ class TestOpenSearchClient:
# Verify other documents still exist.
keep_query = DocumentQuery.get_from_document_id_query(
document_id="keep-me",
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
tenant_state=tenant_state,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
@@ -432,37 +498,44 @@ class TestOpenSearchClient:
with pytest.raises(NotImplementedError):
test_client.update_document()
def test_search_basic(self, test_client: OpenSearchClient) -> None:
def test_search_basic(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests basic search functionality."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Index multiple documents with different content and vectors.
docs = [
_create_test_document_chunk(
docs = {
"search-doc-1": _create_test_document_chunk(
document_id="search-doc-1",
chunk_index=0,
content="Python programming language tutorial",
content_vector=_generate_test_vector(0.1),
tenant_state=tenant_state,
),
_create_test_document_chunk(
"search-doc-2": _create_test_document_chunk(
document_id="search-doc-2",
chunk_index=0,
content="How to make cheese",
content_vector=_generate_test_vector(0.2),
tenant_state=tenant_state,
),
_create_test_document_chunk(
"search-doc-3": _create_test_document_chunk(
document_id="search-doc-3",
chunk_index=0,
content="C++ for newborns",
content_vector=_generate_test_vector(0.15),
tenant_state=tenant_state,
),
]
for doc in docs:
}
for doc in docs.values():
test_client.index_document(document=doc)
# Refresh index to make documents searchable.
@@ -476,47 +549,57 @@ class TestOpenSearchClient:
query_vector=query_vector,
num_candidates=10,
num_hits=5,
tenant_state=TenantState(tenant_id="", multitenant=False),
tenant_state=tenant_state,
)
# Under test.
results = test_client.search(body=search_body, search_pipeline_id=None)
# Postcondition.
assert len(results) > 0
assert len(results) == 3
# Assert that all the chunks above are present.
assert all(
chunk.document_id in ["search-doc-1", "search-doc-2", "search-doc-3"]
for chunk in results
)
# Make sure the chunk contents are preserved.
for chunk in results:
assert chunk == docs[chunk.document_id]
def test_search_with_pipeline(
self, test_client: OpenSearchClient, search_pipeline: None
self,
test_client: OpenSearchClient,
search_pipeline: None,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Tests search with a normalization pipeline."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents.
docs = [
_create_test_document_chunk(
docs = {
"pipeline-doc-1": _create_test_document_chunk(
document_id="pipeline-doc-1",
chunk_index=0,
content="Machine learning algorithms for single-celled organisms",
content_vector=_generate_test_vector(0.3),
tenant_state=tenant_state,
),
_create_test_document_chunk(
"pipeline-doc-2": _create_test_document_chunk(
document_id="pipeline-doc-2",
chunk_index=0,
content="Deep learning shallow neural networks",
content_vector=_generate_test_vector(0.35),
tenant_state=tenant_state,
),
]
for doc in docs:
}
for doc in docs.values():
test_client.index_document(document=doc)
# Refresh index to make documents searchable
@@ -530,7 +613,7 @@ class TestOpenSearchClient:
query_vector=query_vector,
num_candidates=10,
num_hits=5,
tenant_state=TenantState(tenant_id="", multitenant=False),
tenant_state=tenant_state,
)
# Under test.
@@ -539,18 +622,25 @@ class TestOpenSearchClient:
)
# Postcondition.
assert len(results) > 0
assert len(results) == 2
# Assert that all the chunks above are present.
assert all(
chunk.document_id in ["pipeline-doc-1", "pipeline-doc-2"]
for chunk in results
)
# Make sure the chunk contents are preserved.
for chunk in results:
assert chunk == docs[chunk.document_id]
def test_search_empty_index(self, test_client: OpenSearchClient) -> None:
def test_search_empty_index(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests search on an empty index returns an empty list."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -564,7 +654,7 @@ class TestOpenSearchClient:
query_vector=query_vector,
num_candidates=10,
num_hits=5,
tenant_state=TenantState(tenant_id="", multitenant=False),
tenant_state=tenant_state,
)
# Under test.
@@ -573,43 +663,60 @@ class TestOpenSearchClient:
# Postcondition.
assert len(results) == 0
def test_search_filters(self, test_client: OpenSearchClient) -> None:
"""Tests search filters for public/hidden documents."""
def test_search_filters(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests search filters for public/hidden documents and tenant isolation.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Index documents with different public/hidden states.
docs = [
_create_test_document_chunk(
# Index documents with different public/hidden and tenant states.
docs = {
"public-doc-1": _create_test_document_chunk(
document_id="public-doc-1",
chunk_index=0,
content="Public document content",
public=True,
hidden=False,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
_create_test_document_chunk(
"hidden-doc-1": _create_test_document_chunk(
document_id="hidden-doc-1",
chunk_index=0,
content="Hidden document content, spooky",
public=True,
hidden=True,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
_create_test_document_chunk(
"private-doc-1": _create_test_document_chunk(
document_id="private-doc-1",
chunk_index=0,
content="Private document content, btw my SSN is 123-45-6789",
public=False,
hidden=False,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
]
for doc in docs:
"should-not-exist-from-tenant-x-pov": _create_test_document_chunk(
document_id="should-not-exist-from-tenant-x-pov",
chunk_index=0,
content="This is an entirely different tenant, x should never see this",
# Make this as permissive as possible to exercise tenant
# isolation.
public=True,
hidden=False,
tenant_state=tenant_y,
),
}
for doc in docs.values():
test_client.index_document(document=doc)
# Refresh index to make documents searchable.
@@ -625,7 +732,7 @@ class TestOpenSearchClient:
query_vector=query_vector,
num_candidates=10,
num_hits=5,
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
tenant_state=tenant_x,
)
# Under test.
@@ -635,19 +742,24 @@ class TestOpenSearchClient:
# Should only get the public, non-hidden document.
assert len(results) == 1
assert results[0].document_id == "public-doc-1"
assert results[0].public is True
assert results[0].hidden is False
# Make sure the chunk contents are preserved.
assert results[0] == docs["public-doc-1"]
def test_search_with_pipeline_and_filters_returns_chunks_with_related_content_first(
self, test_client: OpenSearchClient, search_pipeline: None
self,
test_client: OpenSearchClient,
search_pipeline: None,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Tests search with a normalization pipeline and filters returns chunks
with related content first.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -664,7 +776,7 @@ class TestOpenSearchClient:
), # Very close to query vector.
public=True,
hidden=False,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
_create_test_document_chunk(
document_id="somewhat-relevant-1",
@@ -673,7 +785,7 @@ class TestOpenSearchClient:
content_vector=_generate_test_vector(0.5), # Far from query vector.
public=True,
hidden=False,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
_create_test_document_chunk(
document_id="not-very-relevant-1",
@@ -684,7 +796,7 @@ class TestOpenSearchClient:
), # Very far from query vector.
public=True,
hidden=False,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
# These should be filtered out by public/hidden filters.
_create_test_document_chunk(
@@ -694,7 +806,7 @@ class TestOpenSearchClient:
content_vector=_generate_test_vector(0.05), # Very close but hidden.
public=True,
hidden=True,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
_create_test_document_chunk(
document_id="private-but-relevant-1",
@@ -703,7 +815,7 @@ class TestOpenSearchClient:
content_vector=_generate_test_vector(0.08), # Very close but private.
public=False,
hidden=False,
tenant_id="tenant-x",
tenant_state=tenant_x,
),
]
for doc in docs:
@@ -720,7 +832,7 @@ class TestOpenSearchClient:
query_vector=query_vector,
num_candidates=10,
num_hits=5,
tenant_state=TenantState(tenant_id="tenant-x", multitenant=True),
tenant_state=tenant_x,
)
# Under test.
@@ -742,72 +854,18 @@ class TestOpenSearchClient:
# Most relevant document should be first due to normalization pipeline.
assert results[0].document_id == "highly-relevant-1"
def test_search_for_ids_basic(self, test_client: OpenSearchClient) -> None:
"""Tests search_for_ids method returns correct chunk IDs."""
# Precondition.
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
# Index chunks for two different documents.
doc1_chunks = [
_create_test_document_chunk(
document_id="doc-1", chunk_index=i, content=f"Doc 1 Chunk {i}"
)
for i in range(3)
]
doc2_chunks = [
_create_test_document_chunk(
document_id="doc-2", chunk_index=i, content=f"Doc 2 Chunk {i}"
)
for i in range(2)
]
for chunk in doc1_chunks + doc2_chunks:
test_client.index_document(document=chunk)
test_client.refresh_index()
# Build query for doc-1.
query_body = DocumentQuery.get_from_document_id_query(
document_id="doc-1",
tenant_state=TenantState(tenant_id="", multitenant=False),
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
get_full_document=False,
)
# Under test.
chunk_ids = test_client.search_for_document_ids(body=query_body)
# Postcondition.
# Should get 3 IDs for doc-1.
assert len(chunk_ids) == 3
# Verify IDs match expected chunk IDs.
expected_ids = {
get_opensearch_doc_chunk_id(
document_id=chunk.document_id,
chunk_index=chunk.chunk_index,
max_chunk_size=chunk.max_chunk_size,
)
for chunk in doc1_chunks
}
assert set(chunk_ids) == expected_ids
def test_delete_by_query_multitenant_isolation(
self, test_client: OpenSearchClient
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query respects tenant boundaries in multi-tenant mode.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, True)
tenant_x = TenantState(tenant_id="tenant-x", multitenant=True)
tenant_y = TenantState(tenant_id="tenant-y", multitenant=True)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=True
vector_dimension=128, multitenant=tenant_x.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -815,103 +873,88 @@ class TestOpenSearchClient:
# Index chunks for different doc IDs for different tenants.
# NOTE: Since get_opensearch_doc_chunk_id doesn't include tenant_id yet,
# we use different document IDs to avoid ID conflicts.
tenant_a_chunks = [
tenant_x_chunks = [
_create_test_document_chunk(
document_id="doc-tenant-a",
document_id="doc-tenant-x",
chunk_index=i,
content=f"Tenant A Chunk {i}",
tenant_id="tenant-a",
tenant_state=tenant_x,
)
for i in range(3)
]
tenant_b_chunks = [
tenant_y_chunks = [
_create_test_document_chunk(
document_id="doc-tenant-b",
document_id="doc-tenant-y",
chunk_index=i,
content=f"Tenant B Chunk {i}",
tenant_id="tenant-b",
tenant_state=tenant_y,
)
for i in range(2)
]
for chunk in tenant_a_chunks + tenant_b_chunks:
for chunk in tenant_x_chunks + tenant_y_chunks:
test_client.index_document(document=chunk)
test_client.refresh_index()
# Build deletion query for tenant-a only.
query_body = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-a",
tenant_state=TenantState(tenant_id="tenant-a", multitenant=True),
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
get_full_document=False,
# Build deletion query for tenant-x only.
query_body = DocumentQuery.delete_from_document_id_query(
document_id="doc-tenant-x",
tenant_state=tenant_x,
)
chunk_ids = test_client.search_for_document_ids(body=query_body)
assert len(chunk_ids) == 3
expected_ids = {
get_opensearch_doc_chunk_id(
document_id=chunk.document_id,
chunk_index=chunk.chunk_index,
max_chunk_size=chunk.max_chunk_size,
)
for chunk in tenant_a_chunks
}
assert set(chunk_ids) == expected_ids
# Under test.
# Delete tenant-a chunks.
for chunk_id in chunk_ids:
result = test_client.delete_document(chunk_id)
assert result is True
# Delete tenant-x chunks using delete_by_query.
num_deleted = test_client.delete_by_query(query_body=query_body)
# Postcondition.
# Verify tenant-a chunks are deleted.
assert num_deleted == 3
# Verify tenant-x chunks are deleted.
test_client.refresh_index()
verify_query_a = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-a",
tenant_state=TenantState(tenant_id="tenant-a", multitenant=True),
verify_query_x = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-x",
tenant_state=tenant_x,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
get_full_document=False,
)
remaining_a_ids = test_client.search_for_document_ids(body=verify_query_a)
remaining_a_ids = test_client.search_for_document_ids(body=verify_query_x)
assert len(remaining_a_ids) == 0
# Verify tenant-b chunks still exist.
verify_query_b = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-b",
tenant_state=TenantState(tenant_id="tenant-b", multitenant=True),
# Verify tenant-y chunks still exist.
verify_query_y = DocumentQuery.get_from_document_id_query(
document_id="doc-tenant-y",
tenant_state=tenant_y,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
get_full_document=False,
)
remaining_b_ids = test_client.search_for_document_ids(body=verify_query_b)
assert len(remaining_b_ids) == 2
expected_b_ids = {
remaining_y_ids = test_client.search_for_document_ids(body=verify_query_y)
assert len(remaining_y_ids) == 2
expected_y_ids = {
get_opensearch_doc_chunk_id(
document_id=chunk.document_id,
chunk_index=chunk.chunk_index,
max_chunk_size=chunk.max_chunk_size,
)
for chunk in tenant_b_chunks
for chunk in tenant_y_chunks
}
assert set(remaining_b_ids) == expected_b_ids
assert set(remaining_y_ids) == expected_y_ids
def test_delete_by_query_nonexistent_document(
self, test_client: OpenSearchClient
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""
Tests delete_by_query for non-existent document returns 0 deleted.
"""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -919,26 +962,26 @@ class TestOpenSearchClient:
# Don't index any documents.
# Build deletion query.
query_body = DocumentQuery.get_from_document_id_query(
query_body = DocumentQuery.delete_from_document_id_query(
document_id="nonexistent-doc",
tenant_state=TenantState(tenant_id="", multitenant=False),
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,
get_full_document=False,
tenant_state=tenant_state,
)
# Under test.
chunk_ids = test_client.search_for_document_ids(body=query_body)
num_deleted = test_client.delete_by_query(query_body=query_body)
# Postcondition.
assert len(chunk_ids) == 0
assert num_deleted == 0
def test_search_for_document_ids(self, test_client: OpenSearchClient) -> None:
def test_search_for_document_ids(
self, test_client: OpenSearchClient, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Tests search_for_document_ids method returns correct chunk IDs."""
# Precondition.
_patch_global_tenant_state(monkeypatch, False)
tenant_state = TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False)
mappings = DocumentSchema.get_document_schema(
vector_dimension=128, multitenant=False
vector_dimension=128, multitenant=tenant_state.multitenant
)
settings = DocumentSchema.get_index_settings()
test_client.create_index(mappings=mappings, settings=settings)
@@ -946,13 +989,19 @@ class TestOpenSearchClient:
# Index chunks for two different documents.
doc1_chunks = [
_create_test_document_chunk(
document_id="doc-1", chunk_index=i, content=f"Doc 1 Chunk {i}"
document_id="doc-1",
chunk_index=i,
content=f"Doc 1 Chunk {i}",
tenant_state=tenant_state,
)
for i in range(3)
]
doc2_chunks = [
_create_test_document_chunk(
document_id="doc-2", chunk_index=i, content=f"Doc 2 Chunk {i}"
document_id="doc-2",
chunk_index=i,
content=f"Doc 2 Chunk {i}",
tenant_state=tenant_state,
)
for i in range(2)
]
@@ -964,7 +1013,7 @@ class TestOpenSearchClient:
# Build query for doc-1.
query_body = DocumentQuery.get_from_document_id_query(
document_id="doc-1",
tenant_state=TenantState(tenant_id="", multitenant=False),
tenant_state=tenant_state,
max_chunk_size=DEFAULT_MAX_CHUNK_SIZE,
min_chunk_index=None,
max_chunk_index=None,

View File

@@ -41,6 +41,12 @@ API_KEY_RECORDS: Dict[str, Dict[str, Any]] = {
},
}
# These are inferrable from the file anyways, no need to obfuscate.
# use them to test your auth with this server
#
# mcp_live-kid_alice_001-S3cr3tAlice
# mcp_live-kid_bob_001-S3cr3tBob
# ---- verifier ---------------------------------------------------------------
class ApiKeyVerifier(TokenVerifier):

View File

@@ -309,6 +309,63 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
assert fallback_llm.config.model_name == default_provider.default_model_name
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
users: tuple[DATestUser, DATestUser],
) -> None:
"""Test that the /llm/provider endpoint correctly excludes non-public providers
with no group/persona restrictions.
This tests the fix for the bug where non-public providers with no restrictions
were incorrectly shown to all users instead of being admin-only.
"""
admin_user, basic_user = users
# Create a public provider (should be visible to all)
public_provider = LLMProviderManager.create(
name="public-provider",
is_public=True,
set_as_default=True,
user_performing_action=admin_user,
)
# Create a non-public provider with no restrictions (should be admin-only)
non_public_provider = LLMProviderManager.create(
name="non-public-unrestricted",
is_public=False,
groups=[],
personas=[],
set_as_default=False,
user_performing_action=admin_user,
)
# Non-admin user calls the /llm/provider endpoint
response = requests.get(
f"{API_SERVER_URL}/llm/provider",
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()
provider_names = [p["name"] for p in providers]
# Public provider should be visible
assert public_provider.name in provider_names
# Non-public provider with no restrictions should NOT be visible to non-admin
assert non_public_provider.name not in provider_names
# Admin user should see both providers
admin_response = requests.get(
f"{API_SERVER_URL}/llm/provider",
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names
assert non_public_provider.name in admin_provider_names
def test_provider_delete_clears_persona_references(reset: None) -> None:
"""Test that deleting a provider automatically clears persona references."""
admin_user = UserManager.create(name="admin_user")

View File

@@ -61,13 +61,13 @@ def test_cold_startup_default_assistant() -> None:
# Verify all three main tools are attached
assert (
"SearchTool" in tool_names
"internal_search" in tool_names
), "Default assistant should have SearchTool attached"
assert (
"ImageGenerationTool" in tool_names
"generate_image" in tool_names
), "Default assistant should have ImageGenerationTool attached"
assert (
"WebSearchTool" in tool_names
"web_search" in tool_names
), "Default assistant should have WebSearchTool attached"
# Also verify by display names for clarity

View File

@@ -1,3 +1,4 @@
from pydantic import BaseModel
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -5,6 +6,53 @@ from tests.integration.common_utils.reset import downgrade_postgres
from tests.integration.common_utils.reset import upgrade_postgres
class ToolSeedingExpectedResult(BaseModel):
name: str
display_name: str
in_code_tool_id: str
user_id: str | None
EXPECTED_TOOLS = {
"SearchTool": ToolSeedingExpectedResult(
name="internal_search",
display_name="Internal Search",
in_code_tool_id="SearchTool",
user_id=None,
),
"ImageGenerationTool": ToolSeedingExpectedResult(
name="generate_image",
display_name="Image Generation",
in_code_tool_id="ImageGenerationTool",
user_id=None,
),
"WebSearchTool": ToolSeedingExpectedResult(
name="web_search",
display_name="Web Search",
in_code_tool_id="WebSearchTool",
user_id=None,
),
"KnowledgeGraphTool": ToolSeedingExpectedResult(
name="run_kg_search",
display_name="Knowledge Graph Search",
in_code_tool_id="KnowledgeGraphTool",
user_id=None,
),
"PythonTool": ToolSeedingExpectedResult(
name="python",
display_name="Code Interpreter",
in_code_tool_id="PythonTool",
user_id=None,
),
"ResearchAgent": ToolSeedingExpectedResult(
name="research_agent",
display_name="Research Agent",
in_code_tool_id="ResearchAgent",
user_id=None,
),
}
def test_tool_seeding_migration() -> None:
"""Test that migration from base to head correctly seeds builtin tools."""
# Start from base and upgrade to just before tool seeding
@@ -49,56 +97,33 @@ def test_tool_seeding_migration() -> None:
len(tools) == 8
), f"Should have created exactly 8 builtin tools, got {len(tools)}"
def validate_tool(expected: ToolSeedingExpectedResult) -> None:
tool = next((t for t in tools if t[1] == expected.name), None)
assert tool is not None, f"{expected.name} should exist"
assert (
tool[2] == expected.display_name
), f"{expected.name} display name should be '{expected.display_name}'"
assert (
tool[4] == expected.in_code_tool_id
), f"{expected.name} in_code_tool_id should be '{expected.in_code_tool_id}'"
assert (
tool[5] is None
), f"{expected.name} should not have a user_id (builtin)"
# Check SearchTool
search_tool = next((t for t in tools if t[1] == "SearchTool"), None)
assert search_tool is not None, "SearchTool should exist"
assert (
search_tool[2] == "Internal Search"
), "SearchTool display name should be 'Internal Search'"
assert search_tool[5] is None, "SearchTool should not have a user_id (builtin)"
validate_tool(EXPECTED_TOOLS["SearchTool"])
# Check ImageGenerationTool
img_tool = next((t for t in tools if t[1] == "ImageGenerationTool"), None)
assert img_tool is not None, "ImageGenerationTool should exist"
assert (
img_tool[2] == "Image Generation"
), "ImageGenerationTool display name should be 'Image Generation'"
assert (
img_tool[5] is None
), "ImageGenerationTool should not have a user_id (builtin)"
validate_tool(EXPECTED_TOOLS["ImageGenerationTool"])
# Check WebSearchTool
web_tool = next((t for t in tools if t[1] == "WebSearchTool"), None)
assert web_tool is not None, "WebSearchTool should exist"
assert (
web_tool[2] == "Web Search"
), "WebSearchTool display name should be 'Web Search'"
assert web_tool[5] is None, "WebSearchTool should not have a user_id (builtin)"
validate_tool(EXPECTED_TOOLS["WebSearchTool"])
# Check KnowledgeGraphTool
kg_tool = next((t for t in tools if t[1] == "KnowledgeGraphTool"), None)
assert kg_tool is not None, "KnowledgeGraphTool should exist"
assert (
kg_tool[2] == "Knowledge Graph Search"
), "KnowledgeGraphTool display name should be 'Knowledge Graph Search'"
assert (
kg_tool[5] is None
), "KnowledgeGraphTool should not have a user_id (builtin)"
validate_tool(EXPECTED_TOOLS["KnowledgeGraphTool"])
# Check PythonTool
python_tool = next((t for t in tools if t[1] == "PythonTool"), None)
assert python_tool is not None, "PythonTool should exist"
assert (
python_tool[2] == "Code Interpreter"
), "PythonTool display name should be 'Code Interpreter'"
assert python_tool[5] is None, "PythonTool should not have a user_id (builtin)"
validate_tool(EXPECTED_TOOLS["PythonTool"])
# Check ResearchAgent (Deep Research as a tool)
research_agent = next((t for t in tools if t[1] == "ResearchAgent"), None)
assert research_agent is not None, "ResearchAgent should exist"
assert (
research_agent[2] == "Research Agent"
), "ResearchAgent display name should be 'Research Agent'"
assert (
research_agent[5] is None
), "ResearchAgent should not have a user_id (builtin)"
validate_tool(EXPECTED_TOOLS["ResearchAgent"])

View File

@@ -38,11 +38,11 @@ def test_unified_assistant(reset: None, admin_user: DATestUser) -> None:
# Verify tools
tools = unified_assistant.tools
tool_names = [tool.name for tool in tools]
assert "SearchTool" in tool_names, "SearchTool not found in unified assistant"
assert "internal_search" in tool_names, "SearchTool not found in unified assistant"
assert (
"ImageGenerationTool" in tool_names
"generate_image" in tool_names
), "ImageGenerationTool not found in unified assistant"
assert "WebSearchTool" in tool_names, "WebSearchTool not found in unified assistant"
assert "web_search" in tool_names, "WebSearchTool not found in unified assistant"
# Verify no starter messages
starter_messages = unified_assistant.starter_messages or []

View File

@@ -270,7 +270,7 @@ def test_web_search_endpoints_with_exa(
provider_id = _activate_exa_provider(admin_user)
assert isinstance(provider_id, int)
search_request = {"queries": ["latest ai research news"], "max_results": 3}
search_request = {"queries": ["wikipedia python programming"], "max_results": 3}
lite_response = requests.post(
f"{API_SERVER_URL}/web-search/search-lite",

View File

@@ -2,11 +2,11 @@
# This file exposes service ports for development and testing purposes
#
# Usage:
# docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
# docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
#
# Or set COMPOSE_FILE environment variable:
# export COMPOSE_FILE=docker-compose.yml:docker-compose.dev.yml
# docker compose up -d
# docker compose up -d --wait
services:
api_server:

View File

@@ -58,7 +58,7 @@ services:
- minio
restart: unless-stopped
# DEV: To expose ports, either:
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
# 2. Uncomment the ports below
# ports:
# - "8080:8080"
@@ -83,7 +83,13 @@ services:
max-size: "50m"
max-file: "6"
healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8080/health')"]
test:
[
"CMD",
"python",
"-c",
"import urllib.request; urllib.request.urlopen('http://localhost:8080/health')",
]
interval: 30s
timeout: 20s
retries: 3
@@ -299,7 +305,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
# DEV: To expose ports, either:
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
# 2. Uncomment the ports below
# ports:
# - "5432:5432"
@@ -321,7 +327,7 @@ services:
environment:
- VESPA_SKIP_UPGRADE_CHECK=${VESPA_SKIP_UPGRADE_CHECK:-true}
# DEV: To expose ports, either:
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
# 2. Uncomment the ports below
# ports:
# - "19071:19071"
@@ -378,7 +384,7 @@ services:
image: redis:7.4-alpine
restart: unless-stopped
# DEV: To expose ports, either:
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
# 2. Uncomment the ports below
# ports:
# - "6379:6379"
@@ -396,7 +402,7 @@ services:
image: minio/minio:RELEASE.2025-07-23T15-54-02Z-cpuv1
restart: unless-stopped
# DEV: To expose ports, either:
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d
# 1. Use docker-compose.dev.yml: docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait
# 2. Uncomment the ports below
# ports:
# - "9004:9000"

View File

@@ -21,9 +21,9 @@ use tauri::{
webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder,
};
use tauri_plugin_global_shortcut::{Code, GlobalShortcutExt, Modifiers, Shortcut};
use url::Url;
#[cfg(target_os = "macos")]
use tokio::time::sleep;
use url::Url;
#[cfg(target_os = "macos")]
use window_vibrancy::{apply_vibrancy, NSVisualEffectMaterial};
@@ -76,39 +76,25 @@ fn get_config_path() -> Option<PathBuf> {
}
/// Load config from file, or create default if it doesn't exist
fn load_config() -> AppConfig {
fn load_config() -> (AppConfig, bool) {
let config_path = match get_config_path() {
Some(path) => path,
None => {
eprintln!("Could not determine config directory, using defaults");
return AppConfig::default();
return (AppConfig::default(), false);
}
};
if config_path.exists() {
match fs::read_to_string(&config_path) {
Ok(contents) => match serde_json::from_str(&contents) {
Ok(config) => {
return config;
}
Err(e) => {
eprintln!("Failed to parse config: {}, using defaults", e);
}
},
Err(e) => {
eprintln!("Failed to read config: {}, using defaults", e);
}
}
} else {
// Create default config file
if let Err(e) = save_config(&AppConfig::default()) {
eprintln!("Failed to create default config: {}", e);
} else {
println!("Created default config at {:?}", config_path);
}
if !config_path.exists() {
return (AppConfig::default(), false);
}
AppConfig::default()
match fs::read_to_string(&config_path) {
Ok(contents) => match serde_json::from_str(&contents) {
Ok(config) => (config, true),
Err(_) => (AppConfig::default(), false),
},
Err(_) => (AppConfig::default(), false),
}
}
/// Save config to file
@@ -128,7 +114,11 @@ fn save_config(config: &AppConfig) -> Result<(), String> {
}
// Global config state
struct ConfigState(RwLock<AppConfig>);
struct ConfigState {
config: RwLock<AppConfig>,
config_initialized: RwLock<bool>,
app_base_url: RwLock<Option<Url>>,
}
fn focus_main_window(app: &AppHandle) {
if let Some(window) = app.get_webview_window("main") {
@@ -142,7 +132,7 @@ fn focus_main_window(app: &AppHandle) {
fn trigger_new_chat(app: &AppHandle) {
let state = app.state::<ConfigState>();
let server_url = state.0.read().unwrap().server_url.clone();
let server_url = state.config.read().unwrap().server_url.clone();
if let Some(window) = app.get_webview_window("main") {
let url = format!("{}/chat", server_url);
@@ -152,7 +142,7 @@ fn trigger_new_chat(app: &AppHandle) {
fn trigger_new_window(app: &AppHandle) {
let state = app.state::<ConfigState>();
let server_url = state.0.read().unwrap().server_url.clone();
let server_url = state.config.read().unwrap().server_url.clone();
let handle = app.clone();
tauri::async_runtime::spawn(async move {
@@ -206,6 +196,30 @@ fn open_docs() {
}
}
fn open_settings(app: &AppHandle) {
// Navigate main window to the settings page (index.html) with settings flag
let state = app.state::<ConfigState>();
let settings_url = state
.app_base_url
.read()
.unwrap()
.as_ref()
.cloned()
.and_then(|mut url| {
url.set_query(None);
url.set_fragment(Some("settings"));
url.set_path("/");
Some(url)
})
.or_else(|| Url::parse("tauri://localhost/#settings").ok());
if let Some(window) = app.get_webview_window("main") {
if let Some(url) = settings_url {
let _ = window.navigate(url);
}
}
}
// ============================================================================
// Tauri Commands
// ============================================================================
@@ -213,7 +227,27 @@ fn open_docs() {
/// Get the current server URL
#[tauri::command]
fn get_server_url(state: tauri::State<ConfigState>) -> String {
state.0.read().unwrap().server_url.clone()
state.config.read().unwrap().server_url.clone()
}
#[derive(Serialize)]
struct BootstrapState {
server_url: String,
config_exists: bool,
}
/// Get the server URL plus whether a config file exists
#[tauri::command]
fn get_bootstrap_state(state: tauri::State<ConfigState>) -> BootstrapState {
let server_url = state.config.read().unwrap().server_url.clone();
let config_initialized = *state.config_initialized.read().unwrap();
let config_exists = config_initialized
&& get_config_path().map(|path| path.exists()).unwrap_or(false);
BootstrapState {
server_url,
config_exists,
}
}
/// Set a new server URL and save to config
@@ -224,9 +258,10 @@ fn set_server_url(state: tauri::State<ConfigState>, url: String) -> Result<Strin
return Err("URL must start with http:// or https://".to_string());
}
let mut config = state.0.write().unwrap();
let mut config = state.config.write().unwrap();
config.server_url = url.trim_end_matches('/').to_string();
save_config(&config)?;
*state.config_initialized.write().unwrap() = true;
Ok(config.server_url.clone())
}
@@ -315,7 +350,7 @@ fn open_config_directory() -> Result<(), String> {
/// Navigate to a specific path on the configured server
#[tauri::command]
fn navigate_to(window: tauri::WebviewWindow, state: tauri::State<ConfigState>, path: &str) {
let base_url = state.0.read().unwrap().server_url.clone();
let base_url = state.config.read().unwrap().server_url.clone();
let url = format!("{}{}", base_url, path);
let _ = window.eval(&format!("window.location.href = '{}'", url));
}
@@ -341,7 +376,7 @@ fn go_forward(window: tauri::WebviewWindow) {
/// Open a new window
#[tauri::command]
async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Result<(), String> {
let server_url = state.0.read().unwrap().server_url.clone();
let server_url = state.config.read().unwrap().server_url.clone();
let window_label = format!("onyx-{}", uuid::Uuid::new_v4());
let builder = WebviewWindowBuilder::new(
@@ -385,9 +420,10 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
/// Reset config to defaults
#[tauri::command]
fn reset_config(state: tauri::State<ConfigState>) -> Result<(), String> {
let mut config = state.0.write().unwrap();
let mut config = state.config.write().unwrap();
*config = AppConfig::default();
save_config(&config)?;
*state.config_initialized.write().unwrap() = true;
Ok(())
}
@@ -423,7 +459,7 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
let forward = Shortcut::new(Some(Modifiers::SUPER), Code::BracketRight);
let new_window_shortcut = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::KeyN);
let show_app = Shortcut::new(Some(Modifiers::SUPER | Modifiers::SHIFT), Code::Space);
let open_settings = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
let open_settings_shortcut = Shortcut::new(Some(Modifiers::SUPER), Code::Comma);
let app_handle = app.clone();
@@ -435,7 +471,7 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
forward,
new_window_shortcut,
show_app,
open_settings,
open_settings_shortcut,
];
#[cfg(not(target_os = "macos"))]
@@ -446,7 +482,7 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
forward,
new_window_shortcut,
show_app,
open_settings,
open_settings_shortcut,
];
app.global_shortcut().on_shortcuts(
@@ -463,9 +499,8 @@ fn setup_shortcuts(app: &AppHandle) -> Result<(), Box<dyn std::error::Error>> {
let _ = window.eval("window.history.back()");
} else if shortcut == &forward {
let _ = window.eval("window.history.forward()");
} else if shortcut == &open_settings {
// Open config file for editing
let _ = open_config_file();
} else if shortcut == &open_settings_shortcut {
open_settings(&app_handle);
}
}
@@ -495,6 +530,7 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
true,
Some("CmdOrCtrl+Shift+N"),
)?;
let settings_item = MenuItem::with_id(app, "open_settings", "Settings...", true, Some("CmdOrCtrl+Comma"))?;
let docs_item = MenuItem::with_id(app, "open_docs", "Onyx Documentation", true, None::<&str>)?;
if let Some(file_menu) = menu
@@ -503,12 +539,13 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
.filter_map(|item| item.as_submenu().cloned())
.find(|submenu| submenu.text().ok().as_deref() == Some("File"))
{
file_menu.insert_items(&[&new_chat_item, &new_window_item], 0)?;
file_menu.insert_items(&[&new_chat_item, &new_window_item, &settings_item], 0)?;
} else {
let file_menu = SubmenuBuilder::new(app, "File")
.items(&[
&new_chat_item,
&new_window_item,
&settings_item,
&PredefinedMenuItem::close_window(app, None)?,
])
.build()?;
@@ -625,22 +662,20 @@ fn setup_tray_icon(app: &AppHandle) -> tauri::Result<()> {
fn main() {
// Load config at startup
let config = load_config();
let server_url = config.server_url.clone();
println!("Starting Onyx Desktop");
println!("Server URL: {}", server_url);
if let Some(path) = get_config_path() {
println!("Config file: {:?}", path);
}
let (config, config_initialized) = load_config();
tauri::Builder::default()
.plugin(tauri_plugin_shell::init())
.plugin(tauri_plugin_global_shortcut::Builder::new().build())
.plugin(tauri_plugin_window_state::Builder::default().build())
.manage(ConfigState(RwLock::new(config)))
.manage(ConfigState {
config: RwLock::new(config),
config_initialized: RwLock::new(config_initialized),
app_base_url: RwLock::new(None),
})
.invoke_handler(tauri::generate_handler![
get_server_url,
get_bootstrap_state,
set_server_url,
get_config_path_cmd,
open_config_file,
@@ -657,6 +692,7 @@ fn main() {
"open_docs" => open_docs(),
"new_chat" => trigger_new_chat(app),
"new_window" => trigger_new_window(app),
"open_settings" => open_settings(app),
_ => {}
})
.setup(move |app| {
@@ -675,7 +711,7 @@ fn main() {
eprintln!("Failed to setup tray icon: {}", e);
}
// Update main window URL to configured server and inject title bar
// Setup main window with vibrancy effect
if let Some(window) = app.get_webview_window("main") {
// Apply vibrancy effect for translucent glass look
#[cfg(target_os = "macos")]
@@ -683,14 +719,12 @@ fn main() {
let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None);
}
if let Ok(target) = Url::parse(&server_url) {
if let Ok(current) = window.url() {
if current != target {
let _ = window.navigate(target);
}
} else {
let _ = window.navigate(target);
}
if let Ok(url) = window.url() {
let mut base_url = url;
base_url.set_query(None);
base_url.set_fragment(None);
base_url.set_path("/");
*app.state::<ConfigState>().app_base_url.write().unwrap() = Some(base_url);
}
#[cfg(target_os = "macos")]

View File

@@ -14,7 +14,7 @@
{
"title": "Onyx",
"label": "main",
"url": "https://cloud.onyx.app",
"url": "index.html",
"width": 1200,
"height": 800,
"minWidth": 800,
@@ -52,7 +52,7 @@
"entitlements": null,
"exceptionDomain": "cloud.onyx.app",
"minimumSystemVersion": "10.15",
"signingIdentity": "-",
"signingIdentity": null,
"dmg": {
"windowSize": {
"width": 660,

View File

@@ -4,28 +4,43 @@
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Onyx</title>
<link
href="https://fonts.googleapis.com/css2?family=Hanken+Grotesk:wght@400;500;600;700&display=swap"
rel="stylesheet"
/>
<style>
:root {
--background-900: #f5f5f5;
--background-800: #ffffff;
--text-light-05: rgba(0, 0, 0, 0.95);
--text-light-03: rgba(0, 0, 0, 0.6);
--white-10: rgba(0, 0, 0, 0.1);
--white-15: rgba(0, 0, 0, 0.15);
--white-20: rgba(0, 0, 0, 0.2);
--white-30: rgba(0, 0, 0, 0.3);
--font-hanken-grotesk: "Hanken Grotesk", -apple-system,
BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
}
* {
box-sizing: border-box;
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto,
Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);
color: #fff;
font-family: var(--font-hanken-grotesk);
background: linear-gradient(135deg, #f5f5f5 0%, #ffffff 100%);
min-height: 100vh;
color: var(--text-light-05);
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 20px;
-webkit-user-select: none;
user-select: none;
}
/* Draggable titlebar area for macOS */
.titlebar {
position: fixed;
top: 0;
@@ -33,198 +48,451 @@
right: 0;
height: 28px;
-webkit-app-region: drag;
z-index: 10000;
}
.container {
text-align: center;
padding: 2rem;
.settings-container {
max-width: 500px;
width: 100%;
opacity: 0;
transform: translateY(8px);
pointer-events: none;
transition:
opacity 0.18s ease,
transform 0.18s ease;
}
.logo {
width: 80px;
height: 80px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 20px;
margin: 0 auto 1.5rem;
body.show-settings .settings-container {
opacity: 1;
transform: translateY(0);
pointer-events: auto;
}
.settings-panel {
background: linear-gradient(
to bottom,
rgba(255, 255, 255, 0.95),
rgba(245, 245, 245, 0.95)
);
backdrop-filter: blur(24px);
border-radius: 16px;
border: 1px solid var(--white-10);
overflow: hidden;
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
}
.settings-header {
padding: 24px;
border-bottom: 1px solid var(--white-10);
display: flex;
align-items: center;
gap: 12px;
}
.settings-icon {
width: 40px;
height: 40px;
border-radius: 12px;
background: white;
display: flex;
align-items: center;
justify-content: center;
font-size: 2.5rem;
font-weight: bold;
overflow: hidden;
}
h1 {
font-size: 2rem;
margin-bottom: 0.5rem;
.settings-icon svg {
width: 24px;
height: 24px;
color: #000;
}
.settings-title {
font-size: 20px;
font-weight: 600;
color: var(--text-light-05);
}
p {
color: #a0a0a0;
margin-bottom: 2rem;
.settings-content {
padding: 24px;
}
.loading {
.settings-section {
margin-bottom: 32px;
}
.settings-section:last-child {
margin-bottom: 0;
}
.section-title {
font-size: 11px;
font-weight: 600;
text-transform: uppercase;
letter-spacing: 0.05em;
color: var(--text-light-03);
margin-bottom: 12px;
}
.settings-group {
background: rgba(0, 0, 0, 0.03);
border-radius: 16px;
padding: 4px;
}
.setting-row {
display: flex;
gap: 0.5rem;
justify-content: center;
margin-bottom: 2rem;
justify-content: space-between;
align-items: center;
padding: 12px;
}
.loading span {
width: 10px;
height: 10px;
background: #667eea;
border-radius: 50%;
animation: bounce 1.4s ease-in-out infinite;
.setting-row-content {
display: flex;
flex-direction: column;
gap: 4px;
flex: 1;
}
.loading span:nth-child(1) {
animation-delay: 0s;
}
.loading span:nth-child(2) {
animation-delay: 0.2s;
}
.loading span:nth-child(3) {
animation-delay: 0.4s;
.setting-label {
font-size: 14px;
font-weight: 400;
color: var(--text-light-05);
}
@keyframes bounce {
0%,
80%,
100% {
transform: scale(0.8);
opacity: 0.5;
}
40% {
transform: scale(1.2);
opacity: 1;
}
.setting-description {
font-size: 12px;
color: var(--text-light-03);
}
.btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 0.75rem 2rem;
.setting-divider {
height: 1px;
background: var(--white-10);
margin: 0 4px;
}
.input-field {
width: 100%;
padding: 10px 12px;
border: 1px solid var(--white-10);
border-radius: 8px;
font-size: 1rem;
cursor: pointer;
transition:
transform 0.2s,
box-shadow 0.2s;
font-size: 14px;
background: rgba(0, 0, 0, 0.05);
color: var(--text-light-05);
font-family: var(--font-hanken-grotesk);
transition: all 0.2s;
-webkit-app-region: no-drag;
}
.btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 20px rgba(102, 126, 234, 0.4);
.input-field:focus {
outline: none;
border-color: var(--white-30);
background: rgba(0, 0, 0, 0.08);
box-shadow: 0 0 0 2px rgba(0, 0, 0, 0.05);
}
.shortcuts {
margin-top: 3rem;
padding: 1.5rem;
background: rgba(255, 255, 255, 0.05);
border-radius: 12px;
text-align: left;
.input-field::placeholder {
color: var(--text-light-03);
}
.shortcuts h3 {
font-size: 0.875rem;
text-transform: uppercase;
letter-spacing: 0.05em;
color: #a0a0a0;
margin-bottom: 1rem;
.input-field.error {
border-color: #ef4444;
}
.shortcut {
display: flex;
justify-content: space-between;
padding: 0.5rem 0;
border-bottom: 1px solid rgba(255, 255, 255, 0.1);
.error-message {
color: #ef4444;
font-size: 12px;
margin-top: 4px;
padding-left: 12px;
display: none;
}
.shortcut:last-child {
border-bottom: none;
.error-message.visible {
display: block;
}
.shortcut-key {
font-family:
SF Mono,
Monaco,
monospace;
background: rgba(255, 255, 255, 0.1);
padding: 0.25rem 0.5rem;
.toggle-switch {
position: relative;
display: inline-block;
width: 44px;
height: 24px;
flex-shrink: 0;
}
.toggle-switch input {
opacity: 0;
width: 0;
height: 0;
}
.toggle-slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: rgba(0, 0, 0, 0.15);
transition: 0.3s;
border-radius: 24px;
}
.toggle-slider:before {
position: absolute;
content: "";
height: 18px;
width: 18px;
left: 3px;
bottom: 3px;
background-color: white;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
transition: 0.3s;
border-radius: 50%;
}
input:checked + .toggle-slider {
background-color: rgba(0, 0, 0, 0.3);
}
input:checked + .toggle-slider:before {
transform: translateX(20px);
}
.button {
padding: 12px 24px;
border-radius: 8px;
border: none;
cursor: pointer;
font-size: 14px;
font-weight: 600;
transition: all 0.2s;
font-family: var(--font-hanken-grotesk);
width: 100%;
margin-top: 24px;
-webkit-app-region: no-drag;
}
.button.primary {
background: #286df8;
color: white;
}
.button.primary:hover {
background: #1e5cd6;
box-shadow: 0 4px 12px rgba(40, 109, 248, 0.3);
}
.button.primary:disabled {
opacity: 0.5;
cursor: not-allowed;
box-shadow: none;
}
kbd {
background: rgba(0, 0, 0, 0.1);
border: 1px solid var(--white-10);
border-radius: 4px;
font-size: 0.75rem;
padding: 2px 6px;
font-family: monospace;
font-weight: 500;
color: var(--text-light-05);
font-size: 11px;
}
</style>
</head>
<body>
<div class="titlebar"></div>
<div class="container">
<div class="logo">O</div>
<h1>Onyx</h1>
<p>Connecting to Onyx Cloud...</p>
<div class="settings-container">
<div class="settings-panel">
<div class="settings-header">
<div class="settings-icon">
<svg
viewBox="0 0 56 56"
xmlns="http://www.w3.org/2000/svg"
fill="currentColor"
>
<path
fill-rule="evenodd"
clip-rule="evenodd"
d="M28 0 10.869 7.77 28 15.539l17.131-7.77L28 0Zm0 40.461-17.131 7.77L28 56l17.131-7.77L28 40.461Zm20.231-29.592L56 28.001l-7.769 17.131L40.462 28l7.769-17.131ZM15.538 28 7.77 10.869 0 28l7.769 17.131L15.538 28Z"
/>
</svg>
</div>
<h1 class="settings-title">Settings</h1>
</div>
<div class="loading">
<span></span>
<span></span>
<span></span>
</div>
<div class="settings-content">
<section class="settings-section">
<div class="section-title">GENERAL</div>
<div class="settings-group">
<div class="setting-row">
<div class="setting-row-content">
<label class="setting-label" for="onyxDomain"
>Root Domain</label
>
<div class="setting-description">
The root URL for your Onyx instance
</div>
</div>
</div>
<div class="setting-divider"></div>
<div class="setting-row" style="padding: 12px">
<input
type="text"
id="onyxDomain"
class="input-field"
placeholder="https://cloud.onyx.app"
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"
/>
</div>
<div class="error-message" id="errorMessage">
Please enter a valid URL starting with http:// or https://
</div>
</div>
</section>
<button
class="btn"
onclick="window.location.href='https://cloud.onyx.app'"
>
Open Onyx Cloud
</button>
<p style="margin-top: 1.5rem; font-size: 0.875rem; color: #666">
Self-hosted? Press
<span
class="shortcut-key"
style="display: inline; padding: 0.15rem 0.4rem"
>⌘ ,</span
>
to configure your server URL.
</p>
<div class="shortcuts">
<h3>Keyboard Shortcuts</h3>
<div class="shortcut">
<span>New Chat</span>
<span class="shortcut-key">⌘ N</span>
</div>
<div class="shortcut">
<span>New Window</span>
<span class="shortcut-key">⌘ ⇧ N</span>
</div>
<div class="shortcut">
<span>Reload</span>
<span class="shortcut-key">⌘ R</span>
</div>
<div class="shortcut">
<span>Back</span>
<span class="shortcut-key">⌘ [</span>
</div>
<div class="shortcut">
<span>Forward</span>
<span class="shortcut-key">⌘ ]</span>
</div>
<div class="shortcut">
<span>Settings / Config</span>
<span class="shortcut-key">⌘ ,</span>
<button class="button primary" id="saveBtn">Save & Connect</button>
</div>
</div>
</div>
<script>
// Auto-redirect to Onyx Cloud after a short delay
setTimeout(() => {
window.location.href = "https://cloud.onyx.app";
}, 1500);
// Import Tauri API
const { invoke } = window.__TAURI__.core;
// Configuration
const DEFAULT_DOMAIN = "https://cloud.onyx.app";
let currentServerUrl = "";
// DOM elements
const domainInput = document.getElementById("onyxDomain");
const errorMessage = document.getElementById("errorMessage");
const saveBtn = document.getElementById("saveBtn");
function showSettings() {
document.body.classList.add("show-settings");
}
// Initialize the app
async function init() {
try {
const bootstrap = await invoke("get_bootstrap_state");
currentServerUrl = bootstrap.server_url;
// Set the input value
domainInput.value = currentServerUrl || DEFAULT_DOMAIN;
// Check if user came here explicitly (via Settings menu/shortcut)
const urlParams = new URLSearchParams(window.location.search);
const isExplicitSettings =
window.location.hash === "#settings" ||
urlParams.get("settings") === "true";
// If user explicitly opened settings, show modal
if (isExplicitSettings) {
// Modal is already visible, user can edit and save
showSettings();
return;
}
// Otherwise, check if this is first launch
// First launch = config doesn't exist
if (!bootstrap.config_exists || !currentServerUrl) {
// First launch - show modal, require user to configure
showSettings();
return;
}
// Not first launch and not explicit settings
// Auto-redirect to configured domain
window.location.href = currentServerUrl;
} catch (error) {
// On error, default to cloud
domainInput.value = DEFAULT_DOMAIN;
showSettings();
}
}
// Validate URL
function validateUrl(url) {
const trimmedUrl = url.trim();
if (!trimmedUrl) {
return { valid: false, error: "URL cannot be empty" };
}
if (
!trimmedUrl.startsWith("http://") &&
!trimmedUrl.startsWith("https://")
) {
return {
valid: false,
error: "URL must start with http:// or https://",
};
}
try {
new URL(trimmedUrl);
return { valid: true, url: trimmedUrl };
} catch {
return { valid: false, error: "Please enter a valid URL" };
}
}
// Show error
function showError(message) {
domainInput.classList.add("error");
errorMessage.textContent = message;
errorMessage.classList.add("visible");
}
// Clear error
function clearError() {
domainInput.classList.remove("error");
errorMessage.classList.remove("visible");
}
// Save configuration
async function saveConfiguration() {
clearError();
const validation = validateUrl(domainInput.value);
if (!validation.valid) {
showError(validation.error);
return;
}
try {
saveBtn.disabled = true;
saveBtn.textContent = "Saving...";
// Call Tauri command to save the URL
await invoke("set_server_url", { url: validation.url });
// Success - redirect to the new URL (login page)
window.location.href = validation.url;
} catch (error) {
showError(error || "Failed to save configuration");
saveBtn.disabled = false;
saveBtn.textContent = "Save & Connect";
}
}
// Event listeners
domainInput.addEventListener("input", clearError);
domainInput.addEventListener("keypress", (e) => {
if (e.key === "Enter") {
saveConfiguration();
}
});
saveBtn.addEventListener("click", saveConfiguration);
// Initialize when DOM is ready
if (document.readyState === "loading") {
document.addEventListener("DOMContentLoaded", init);
} else {
init();
}
</script>
</body>
</html>

View File

@@ -2,8 +2,6 @@
// This script injects a draggable title bar that matches Onyx design system
(function () {
console.log("[Onyx Desktop] Title bar script loaded");
const TITLEBAR_ID = "onyx-desktop-titlebar";
const TITLEBAR_HEIGHT = 36;
const STYLE_ID = "onyx-desktop-titlebar-style";
@@ -31,12 +29,7 @@
try {
await invoke("start_drag_window");
return;
} catch (err) {
console.error(
"[Onyx Desktop] Failed to start dragging via invoke:",
err,
);
}
} catch (err) {}
}
const appWindow =
@@ -46,14 +39,7 @@
if (appWindow?.startDragging) {
try {
await appWindow.startDragging();
} catch (err) {
console.error(
"[Onyx Desktop] Failed to start dragging via appWindow:",
err,
);
}
} else {
console.error("[Onyx Desktop] No Tauri drag API available.");
} catch (err) {}
}
}
@@ -177,7 +163,6 @@
function mountTitleBar() {
if (!document.body) {
console.error("[Onyx Desktop] document.body not found");
return;
}
@@ -193,7 +178,6 @@
const titleBar = buildTitleBar();
document.body.insertBefore(titleBar, document.body.firstChild);
injectStyles();
console.log("[Onyx Desktop] Title bar injected");
}
function syncViewportHeight() {

View File

@@ -5,8 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "onyx"
version = "0.0.0"
# TODO(jamison): Upgrade dependencies until they're compatible with python >3.13.
requires-python = ">=3.11,<3.13"
requires-python = ">=3.11"
# Shared dependencies between backend and model_server
dependencies = [
"aioboto3==15.1.0",
@@ -91,7 +90,7 @@ backend = [
"python-dateutil==2.8.2",
"python-gitlab==5.6.0",
"python-pptx==0.6.23",
"pypdf==6.1.3",
"pypdf==6.6.0",
"pytest-mock==3.12.0",
"pytest-playwright==0.7.0",
"python-docx==1.1.2",
@@ -111,8 +110,8 @@ backend = [
"tiktoken==0.7.0",
"timeago==1.0.16",
"types-openpyxl==3.0.4.7",
"unstructured==0.15.1",
"unstructured-client==0.25.4",
"unstructured==0.18.27",
"unstructured-client==0.42.6",
"zulip==0.8.2",
"hubspot-api-client==11.1.0",
"asana==5.0.8",
@@ -143,7 +142,7 @@ dev = [
"matplotlib==3.10.8",
"mypy-extensions==1.0.0",
"mypy==1.13.0",
"onyx-devtools==0.2.0",
"onyx-devtools==0.6.2",
"openapi-generator-cli==7.17.0",
"pandas-stubs==2.2.3.241009",
"pre-commit==3.2.2",
@@ -181,7 +180,7 @@ ee = [
model_server = [
"accelerate==1.6.0",
"einops==0.8.1",
"numpy==1.26.4",
"numpy==2.4.1",
"safetensors==0.5.3",
"sentence-transformers==4.0.2",
"torch==2.6.0",

1785
uv.lock generated

File diff suppressed because it is too large Load Diff

Binary file not shown.

Before

Width:  |  Height:  |  Size: 548 B

After

Width:  |  Height:  |  Size: 581 B

View File

@@ -25,7 +25,7 @@ export default function OnyxApiKeyForm({
return (
<Modal open onOpenChange={onClose}>
<Modal.Content tall>
<Modal.Content width="sm" height="lg">
<Modal.Header
icon={SvgKey}
title={isUpdate ? "Update API Key" : "Create a new API Key"}

View File

@@ -105,7 +105,7 @@ function Main() {
{popup}
<Modal open={!!fullApiKey}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
title="New API Key"
icon={SvgKey}

View File

@@ -10,10 +10,7 @@ import {
} from "@/lib/types";
import BackButton from "@/refresh-components/buttons/BackButton";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import {
FetchAssistantsResponse,
fetchAssistantsSS,
} from "@/lib/assistants/fetchAssistantsSS";
import { FetchAssistantsResponse, fetchAssistantsSS } from "@/lib/agentsSS";
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
async function EditslackChannelConfigPage(props: {

View File

@@ -4,7 +4,7 @@ import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSetSummary, ValidSources } from "@/lib/types";
import BackButton from "@/refresh-components/buttons/BackButton";
import { fetchAssistantsSS } from "@/lib/assistants/fetchAssistantsSS";
import { fetchAssistantsSS } from "@/lib/agentsSS";
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { redirect } from "next/navigation";
import { SourceIcon } from "@/components/SourceIcon";

View File

@@ -1,3 +1,5 @@
"use client";
import { useState, useEffect } from "react";
import { Form, Formik, FormikProps } from "formik";
import { SelectorFormField, TextFormField } from "@/components/Field";
@@ -28,13 +30,7 @@ import { DisplayModels } from "./components/DisplayModels";
import { fetchBedrockModels } from "../utils";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import {
Tabs,
TabsList,
TabsTrigger,
TabsContent,
} from "@/refresh-components/tabs/tabs";
import { cn } from "@/lib/utils";
import Tabs from "@/refresh-components/Tabs";
export const BEDROCK_PROVIDER_NAME = "bedrock";
const BEDROCK_DISPLAY_NAME = "AWS Bedrock";
@@ -161,33 +157,25 @@ function BedrockFormInternals({
onValueChange={(value) =>
formikProps.setFieldValue(FIELD_BEDROCK_AUTH_METHOD, value)
}
className="mt-2"
>
<TabsList>
<TabsTrigger value={AUTH_METHOD_IAM}>IAM Role</TabsTrigger>
<TabsTrigger value={AUTH_METHOD_ACCESS_KEY}>Access Key</TabsTrigger>
<TabsTrigger value={AUTH_METHOD_LONG_TERM_API_KEY}>
<Tabs.List>
<Tabs.Trigger value={AUTH_METHOD_IAM}>IAM Role</Tabs.Trigger>
<Tabs.Trigger value={AUTH_METHOD_ACCESS_KEY}>
Access Key
</Tabs.Trigger>
<Tabs.Trigger value={AUTH_METHOD_LONG_TERM_API_KEY}>
Long-term API Key
</TabsTrigger>
</TabsList>
</Tabs.Trigger>
</Tabs.List>
<TabsContent
value={AUTH_METHOD_IAM}
className="data-[state=active]:animate-fade-in-scale"
>
<Tabs.Content value={AUTH_METHOD_IAM}>
<Text as="p" text03>
Uses the IAM role attached to your AWS environment. Recommended
for EC2, ECS, Lambda, or other AWS services.
</Text>
</TabsContent>
</Tabs.Content>
<TabsContent
value={AUTH_METHOD_ACCESS_KEY}
className={cn(
"data-[state=active]:animate-fade-in-scale",
"mt-4 ml-2"
)}
>
<Tabs.Content value={AUTH_METHOD_ACCESS_KEY}>
<div className="flex flex-col gap-4">
<TextFormField
name={FIELD_AWS_ACCESS_KEY_ID}
@@ -200,15 +188,9 @@ function BedrockFormInternals({
placeholder="wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"
/>
</div>
</TabsContent>
</Tabs.Content>
<TabsContent
value={AUTH_METHOD_LONG_TERM_API_KEY}
className={cn(
"data-[state=active]:animate-fade-in-scale",
"mt-4 ml-2"
)}
>
<Tabs.Content value={AUTH_METHOD_LONG_TERM_API_KEY}>
<div className="flex flex-col gap-4">
<PasswordInputTypeInField
name={FIELD_AWS_BEARER_TOKEN_BEDROCK}
@@ -216,7 +198,7 @@ function BedrockFormInternals({
placeholder="Your long-term API key"
/>
</div>
</TabsContent>
</Tabs.Content>
</Tabs>
</div>

View File

@@ -1,9 +1,10 @@
"use client";
import { useState, ReactNode } from "react";
import useSWR, { useSWRConfig, KeyedMutator } from "swr";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import {
LLMProviderView,
ModelConfiguration,
WellKnownLLMProviderDescriptor,
} from "../../interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -114,7 +115,7 @@ export function ProviderFormEntrypointWrapper({
{formIsVisible && (
<Modal open onOpenChange={onClose}>
<Modal.Content medium>
<Modal.Content>
<Modal.Header
icon={SvgSettings}
title={`Setup ${providerName}`}
@@ -196,7 +197,7 @@ export function ProviderFormEntrypointWrapper({
{formIsVisible && (
<Modal open onOpenChange={onClose}>
<Modal.Content medium>
<Modal.Content>
<Modal.Header
icon={SvgSettings}
title={`${existingLlmProvider ? "Configure" : "Setup"} ${

View File

@@ -130,7 +130,7 @@ export default function UpgradingPage({
{popup}
{isCancelling && (
<Modal open onOpenChange={() => setIsCancelling(false)}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgX}
title="Cancel Embedding Model Switch"

View File

@@ -81,7 +81,7 @@ export const WebProviderSetupModal = memo(
return (
<Modal open={isOpen} onOpenChange={(open) => !open && onClose()}>
<Modal.Content mini preventAccidentalClose>
<Modal.Content width="sm" preventAccidentalClose>
<Modal.Header
icon={LogoArrangement}
title={`Set up ${providerLabel}`}

View File

@@ -1,6 +1,7 @@
export type WebContentProviderType =
| "firecrawl"
| "onyx_web_crawler"
| "exa"
| (string & {});
export const CONTENT_PROVIDERS_URL = "/api/admin/web-search/content-providers";
@@ -23,6 +24,13 @@ export const CONTENT_PROVIDER_DETAILS: Record<
"Connect Firecrawl to fetch and summarize page content from search results.",
logoSrc: "/firecrawl.svg",
},
exa: {
label: "Exa",
subtitle: "Exa.ai",
description:
"Use Exa to fetch and summarize page content from search results.",
logoSrc: "/Exa.svg",
},
};
/**
@@ -64,6 +72,7 @@ const CONTENT_PROVIDER_CAPABILITIES: Record<
base_url: ["base_url", "api_base_url"],
},
},
// exa uses default capabilities
};
const DEFAULT_CONTENT_PROVIDER_CAPABILITIES: ContentProviderCapabilities = {

View File

@@ -136,6 +136,17 @@ export default function Page() {
const isLoading = isLoadingSearchProviders || isLoadingContentProviders;
// Exa shares API key between search and content providers
const exaSearchProvider = searchProviders.find(
(p) => p.provider_type === "exa"
);
const exaContentProvider = contentProviders.find(
(p) => p.provider_type === "exa"
);
const hasSharedExaKey =
(exaSearchProvider?.has_api_key || exaContentProvider?.has_api_key) ??
false;
// Modal form state is owned by reducers
const openSearchModal = (
@@ -145,12 +156,18 @@ export default function Page() {
const requiresApiKey = searchProviderRequiresApiKey(providerType);
const hasStoredKey = provider?.has_api_key ?? false;
// For Exa search provider, check if we can use the shared Exa key
const isExa = providerType === "exa";
const canUseSharedExaKey = isExa && hasSharedExaKey && !hasStoredKey;
dispatchSearchModal({
type: "OPEN",
providerType,
existingProviderId: provider?.id ?? null,
initialApiKeyValue:
requiresApiKey && hasStoredKey ? MASKED_API_KEY_PLACEHOLDER : "",
requiresApiKey && (hasStoredKey || canUseSharedExaKey)
? MASKED_API_KEY_PLACEHOLDER
: "",
initialConfigValue: getSingleConfigFieldValueForForm(
providerType,
provider
@@ -165,11 +182,16 @@ export default function Page() {
const hasStoredKey = provider?.has_api_key ?? false;
const defaultFirecrawlBaseUrl = "https://api.firecrawl.dev/v1/scrape";
// For Exa content provider, check if we can use the shared Exa key
const isExa = providerType === "exa";
const canUseSharedExaKey = isExa && hasSharedExaKey && !hasStoredKey;
dispatchContentModal({
type: "OPEN",
providerType,
existingProviderId: provider?.id ?? null,
initialApiKeyValue: hasStoredKey ? MASKED_API_KEY_PLACEHOLDER : "",
initialApiKeyValue:
hasStoredKey || canUseSharedExaKey ? MASKED_API_KEY_PLACEHOLDER : "",
initialConfigValue:
providerType === "firecrawl"
? getSingleContentConfigFieldValueForForm(
@@ -339,6 +361,17 @@ export default function Page() {
} satisfies WebContentProviderView;
}
if (providerType === "exa") {
return {
id: -3,
name: "Exa",
provider_type: "exa",
is_active: false,
config: null,
has_api_key: hasSharedExaKey,
} satisfies WebContentProviderView;
}
return null;
}).filter(Boolean) as WebContentProviderView[];
@@ -347,7 +380,7 @@ export default function Page() {
);
return [...ordered, ...additional];
}, [contentProviders]);
}, [contentProviders, hasSharedExaKey]);
const currentContentProviderType =
getCurrentContentProviderType(contentProviders);
@@ -468,7 +501,12 @@ export default function Page() {
onClose: () => {
dispatchSearchModal({ type: "CLOSE" });
},
mutate: mutateSearchProviders,
mutate: async () => {
await mutateSearchProviders();
if (selectedProviderType === "exa") {
await mutateContentProviders();
}
},
});
};
@@ -678,6 +716,23 @@ export default function Page() {
selectedContentProviderType
: "";
if (selectedContentProviderType === "exa") {
return (
<>
Paste your{" "}
<a
href="https://dashboard.exa.ai/api-keys"
target="_blank"
rel="noopener noreferrer"
className="underline"
>
API key
</a>{" "}
from Exa to enable crawling.
</>
);
}
return selectedContentProviderType === "firecrawl" ? (
<>
Paste your <span className="underline">API key</span> from Firecrawl to
@@ -730,6 +785,10 @@ export default function Page() {
dispatchContentModal({ type: "SET_PHASE", phase: "saving" });
dispatchContentModal({ type: "CLEAR_MESSAGE" });
const apiKeyChangedForContentProvider =
contentModal.apiKeyValue !== MASKED_API_KEY_PLACEHOLDER &&
contentProviderValues.apiKey.length > 0;
await connectProviderFlow({
category: "content",
providerType: selectedContentProviderType,
@@ -740,9 +799,7 @@ export default function Page() {
CONTENT_PROVIDER_DETAILS[selectedContentProviderType]?.label ??
selectedContentProviderType,
providerRequiresApiKey: true,
apiKeyChangedForProvider:
contentModal.apiKeyValue !== MASKED_API_KEY_PLACEHOLDER &&
contentProviderValues.apiKey.length > 0,
apiKeyChangedForProvider: apiKeyChangedForContentProvider,
apiKey: contentProviderValues.apiKey,
config,
configChanged,
@@ -759,7 +816,12 @@ export default function Page() {
onClose: () => {
dispatchContentModal({ type: "CLOSE" });
},
mutate: mutateContentProviders,
mutate: async () => {
await mutateContentProviders();
if (selectedContentProviderType === "exa") {
await mutateSearchProviders();
}
},
});
};
@@ -1052,7 +1114,8 @@ export default function Page() {
const canActivate =
providerId > 0 ||
provider.provider_type === "onyx_web_crawler";
provider.provider_type === "onyx_web_crawler" ||
isConfigured;
return {
label: "Set as Default",

View File

@@ -125,7 +125,7 @@ export default function IndexAttemptErrorsModal({
return (
<Modal open onOpenChange={onClose}>
<Modal.Content large>
<Modal.Content width="lg" height="full">
<Modal.Header
icon={SvgAlertTriangle}
title="Indexing Errors"

View File

@@ -353,7 +353,7 @@ export default function InlineFileManagement({
{/* Confirmation Modal */}
<Modal open={showSaveConfirm} onOpenChange={setShowSaveConfirm}>
<Modal.Content mini>
<Modal.Content width="sm">
<Modal.Header
icon={SvgFolderPlus}
title="Confirm File Changes"

View File

@@ -128,7 +128,7 @@ export default function ReIndexModal({
return (
<Modal open onOpenChange={hide}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header icon={SvgRefreshCw} title="Run Indexing" onClose={hide} />
<Modal.Body>
<Text as="p">

View File

@@ -584,7 +584,7 @@ export default function AddConnector({
open
onOpenChange={() => setCreateCredentialFormToggle(false)}
>
<Modal.Content medium>
<Modal.Content>
<Modal.Header
icon={SvgKey}
title={`Create a ${getSourceDisplayName(

View File

@@ -9,12 +9,7 @@ import FileInput from "./ConnectorInput/FileInput";
import { ConfigurableSources } from "@/lib/types";
import { Credential } from "@/lib/connectors/credentials";
import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection";
import {
Tabs,
TabsContent,
TabsList,
TabsTrigger,
} from "@/components/ui/fully_wrapped_tabs";
import Tabs from "@/refresh-components/Tabs";
import { useFormikContext } from "formik";
// Define a general type for form values
@@ -60,7 +55,6 @@ const TabsField: FC<TabsFieldProps> = ({
) : (
<Tabs
defaultValue={tabField.defaultTab || tabField.tabs[0]?.value}
className="w-full"
onValueChange={(newTab) => {
// Clear values from other tabs but preserve defaults
tabField.tabs.forEach((tab) => {
@@ -75,15 +69,15 @@ const TabsField: FC<TabsFieldProps> = ({
});
}}
>
<TabsList>
<Tabs.List>
{tabField.tabs.map((tab) => (
<TabsTrigger key={tab.value} value={tab.value}>
<Tabs.Trigger key={tab.value} value={tab.value}>
{tab.label}
</TabsTrigger>
</Tabs.Trigger>
))}
</TabsList>
</Tabs.List>
{tabField.tabs.map((tab) => (
<TabsContent key={tab.value} value={tab.value} className="">
<Tabs.Content key={tab.value} value={tab.value}>
{tab.fields.map((subField, index, array) => {
// Check visibility condition first
if (
@@ -112,7 +106,7 @@ const TabsField: FC<TabsFieldProps> = ({
</div>
);
})}
</TabsContent>
</Tabs.Content>
))}
</Tabs>
)}

View File

@@ -323,7 +323,7 @@ const RerankingDetailsForm = forwardRef<
open
onOpenChange={() => setShowGpuWarningModalModel(null)}
>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgAlertTriangle}
title="GPU Not Enabled"
@@ -358,7 +358,7 @@ const RerankingDetailsForm = forwardRef<
setShowLiteLLMConfigurationModal(false);
}}
>
<Modal.Content medium>
<Modal.Content>
<Modal.Header
icon={SvgKey}
title="API Key Configuration"
@@ -462,7 +462,7 @@ const RerankingDetailsForm = forwardRef<
setIsApiKeyModalOpen(false);
}}
>
<Modal.Content medium>
<Modal.Content>
<Modal.Header
icon={SvgKey}
title="API Key Configuration"

View File

@@ -14,7 +14,7 @@ export default function AlreadyPickedModal({
}: AlreadyPickedModalProps) {
return (
<Modal open onOpenChange={onClose}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgCheck}
title={`${model.model_name} already chosen`}

View File

@@ -21,7 +21,7 @@ export default function DeleteCredentialsModal({
}: DeleteCredentialsModalProps) {
return (
<Modal open onOpenChange={onCancel}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgTrash}
title={`Delete ${getFormattedProviderName(

View File

@@ -13,7 +13,7 @@ export default function InstantSwitchConfirmModal({
}: InstantSwitchConfirmModalProps) {
return (
<Modal open onOpenChange={onClose}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgAlertTriangle}
title="Are you sure you want to do an instant switch?"

View File

@@ -20,7 +20,7 @@ export default function ModelSelectionConfirmationModal({
}: ModelSelectionConfirmationModalProps) {
return (
<Modal open onOpenChange={onCancel}>
<Modal.Content tall>
<Modal.Content width="sm" height="lg">
<Modal.Header
icon={SvgServer}
title="Update Embedding Model"

View File

@@ -186,7 +186,7 @@ export default function ProviderCreationModal({
return (
<Modal open onOpenChange={onCancel}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgSettings}
title={`Configure ${getFormattedProviderName(

View File

@@ -17,7 +17,7 @@ export default function SelectModelModal({
}: SelectModelModalProps) {
return (
<Modal open onOpenChange={onCancel}>
<Modal.Content small>
<Modal.Content width="sm" height="sm">
<Modal.Header
icon={SvgServer}
title={`Select ${model.model_name}`}

Some files were not shown because too many files have changed in this diff Show More