Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-26 01:52:45 +00:00)

Compare commits: 86 commits
2129e77bdf
40dec09e35
b30737b6b2
caf8b85ec2
1d13580b63
00390c53e0
66656df9e6
51d26d7e4c
198ac8ccbc
ee6d33f484
7bcb72d055
876894e097
7215f56b25
0fd1c34014
9e24b41b7b
ab3853578b
7db969d36a
6cdeb71656
2c4b2c68b4
5301ee7cef
f8e6716875
755c65fd8a
90cf5f49e3
d4068c2b07
fd6fa43fe1
8d5013bf01
dabd7c6263
c8c0389675
9cfcfb12e1
786a0c2bd0
0cd8d3402b
3fa397b24d
e0a97230b8
7f1272117a
79302f19be
4a91e644d4
ca0318f16e
be8e0b3a98
49c4814c70
2f945613a2
e9242ca3a8
a150de761a
0e792ca6c9
6be467a4ac
dd91bfcfe6
8a72291781
b2d71da4eb
6e2f851c62
be078edcb4
194c54aca3
9fa7221e24
3a5c7ef8ee
84458aa0bf
de57bfa35f
386f8f31ed
376f04caea
4b0a3c2b04
1bd9f9d9a6
4ac10abaea
a66a283af4
bf5da04166
693487f855
d02a76d7d1
28e05c6e90
a18f546921
e98dea149e
027c165794
14ebe912c8
a63b906789
92a68a3c22
95db4ed9c7
5134d60d48
651a54470d
269d243b67
0286dd7da9
f3a0710d69
037c2aee3a
9b2f3d234d
7646399cd4
d913b93d10
8a0ce4c294
862c140763
47487f1940
e3471df940
fb33c815b3
5c6594be73
387 changes: .github/workflows/deployment.yml (vendored)
@@ -8,9 +8,7 @@ on:
 # Set restrictive default permissions for all jobs. Jobs that need more permissions
 # should explicitly declare them.
-permissions:
-  # Required for OIDC authentication with AWS
-  id-token: write # zizmor: ignore[excessive-permissions]
+permissions: {}

 env:
   EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }}

@@ -152,30 +150,16 @@ jobs:
     if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch'
     runs-on: ubuntu-slim
     timeout-minutes: 10
-    environment: release
     steps:
       - name: Checkout
        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
-          parse-json-secrets: true
-
       - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
-          webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
+          webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
          failed-jobs: "• check-version-tag"
          title: "🚨 Version Tag Check Failed"
          ref-name: ${{ github.ref_name }}

@@ -184,7 +168,6 @@ jobs:
     needs: determine-builds
     if: needs.determine-builds.outputs.build-desktop == 'true'
     permissions:
-      id-token: write
       contents: write
       actions: read
     strategy:

@@ -202,33 +185,12 @@ jobs:

     runs-on: ${{ matrix.platform }}
     timeout-minutes: 90
-    environment: release
     steps:
       - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
        with:
          # NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
          persist-credentials: true # zizmor: ignore[artipacked]

-      - name: Configure AWS credentials
-        if: startsWith(matrix.platform, 'macos-')
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        if: startsWith(matrix.platform, 'macos-')
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            APPLE_ID, deploy/apple-id
-            APPLE_PASSWORD, deploy/apple-password
-            APPLE_CERTIFICATE, deploy/apple-certificate
-            APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password
-            KEYCHAIN_PASSWORD, deploy/keychain-password
-            APPLE_TEAM_ID, deploy/apple-team-id
-          parse-json-secrets: true
-
       - name: install dependencies (ubuntu only)
        if: startsWith(matrix.platform, 'ubuntu-')
        run: |

@@ -323,40 +285,15 @@ jobs:

           Write-Host "Versions set to: $VERSION"

-      - name: Import Apple Developer Certificate
-        if: startsWith(matrix.platform, 'macos-')
-        run: |
-          echo $APPLE_CERTIFICATE | base64 --decode > certificate.p12
-          security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
-          security default-keychain -s build.keychain
-          security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain
-          security set-keychain-settings -t 3600 -u build.keychain
-          security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign
-          security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain
-          security find-identity -v -p codesigning build.keychain
-
-      - name: Verify Certificate
-        if: startsWith(matrix.platform, 'macos-')
-        run: |
-          CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1)
-          CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}')
-          echo "CERT_ID=$CERT_ID" >> $GITHUB_ENV
-          echo "Certificate imported."
-
       - uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-          APPLE_ID: ${{ env.APPLE_ID }}
-          APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }}
-          APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
-          APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
        with:
          tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
          releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
          releaseBody: "See the assets to download this version and install."
          releaseDraft: true
          prerelease: false
          assetNamePattern: "[name]_[arch][ext]"
          args: ${{ matrix.args }}

   build-web-amd64:

@@ -368,7 +305,6 @@ jobs:
       - run-id=${{ github.run_id }}-web-amd64
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -381,20 +317,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -409,8 +331,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push AMD64
        id: build

@@ -441,7 +363,6 @@ jobs:
       - run-id=${{ github.run_id }}-web-arm64
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -454,20 +375,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -482,8 +389,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push ARM64
        id: build

@@ -516,34 +423,19 @@ jobs:
       - run-id=${{ github.run_id }}-merge-web
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Docker meta
        id: meta

@@ -579,7 +471,6 @@ jobs:
       - run-id=${{ github.run_id }}-web-cloud-amd64
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -592,20 +483,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -620,8 +497,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push AMD64
        id: build

@@ -660,7 +537,6 @@ jobs:
       - run-id=${{ github.run_id }}-web-cloud-arm64
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -673,20 +549,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -701,8 +563,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push ARM64
        id: build

@@ -743,34 +605,19 @@ jobs:
       - run-id=${{ github.run_id }}-merge-web-cloud
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Docker meta
        id: meta

@@ -803,7 +650,6 @@ jobs:
       - run-id=${{ github.run_id }}-backend-amd64
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -816,20 +662,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -844,8 +676,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push AMD64
        id: build

@@ -875,7 +707,6 @@ jobs:
       - run-id=${{ github.run_id }}-backend-arm64
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -888,20 +719,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -916,8 +733,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push ARM64
        id: build

@@ -949,34 +766,19 @@ jobs:
       - run-id=${{ github.run_id }}-merge-backend
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Docker meta
        id: meta

@@ -1013,7 +815,6 @@ jobs:
       - volume=40gb
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -1026,20 +827,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -1056,8 +843,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push AMD64
        id: build

@@ -1092,7 +879,6 @@ jobs:
       - volume=40gb
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     outputs:
       digest: ${{ steps.build.outputs.digest }}
     env:

@@ -1105,20 +891,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Docker meta
        id: meta
        uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5

@@ -1135,8 +907,8 @@ jobs:
       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Build and push ARM64
        id: build

@@ -1172,34 +944,19 @@ jobs:
       - run-id=${{ github.run_id }}-merge-model-server
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # ratchet:docker/setup-buildx-action@v3

       - name: Login to Docker Hub
        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
        with:
-          username: ${{ env.DOCKER_USERNAME }}
-          password: ${{ env.DOCKER_TOKEN }}
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}

       - name: Docker meta
        id: meta

@@ -1237,26 +994,11 @@ jobs:
       - run-id=${{ github.run_id }}-trivy-scan-web
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:

@@ -1272,8 +1014,8 @@ jobs:
           docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
             -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
             -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-            -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-            -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
+            -e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
+            -e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
             aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
             image \
             --skip-version-check \

@@ -1292,26 +1034,11 @@ jobs:
       - run-id=${{ github.run_id }}-trivy-scan-web-cloud
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:

@@ -1327,8 +1054,8 @@ jobs:
           docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
             -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
             -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-            -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-            -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
+            -e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
+            -e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
             aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
             image \
             --skip-version-check \

@@ -1347,7 +1074,6 @@ jobs:
       - run-id=${{ github.run_id }}-trivy-scan-backend
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
     steps:

@@ -1358,20 +1084,6 @@ jobs:
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:

@@ -1388,8 +1100,8 @@ jobs:
             -v ${{ github.workspace }}/backend/.trivyignore:/tmp/.trivyignore:ro \
             -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
             -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-            -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-            -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
+            -e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
+            -e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
             aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
             image \
             --skip-version-check \

@@ -1409,26 +1121,11 @@ jobs:
       - run-id=${{ github.run_id }}-trivy-scan-model-server
       - extras=ecr-cache
     timeout-minutes: 90
-    environment: release
     env:
       REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
     steps:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            DOCKER_USERNAME, deploy/docker-username
-            DOCKER_TOKEN, deploy/docker-token
-          parse-json-secrets: true
-
       - name: Run Trivy vulnerability scanner
        uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
        with:

@@ -1444,8 +1141,8 @@ jobs:
           docker run --rm -v $HOME/.cache/trivy:/root/.cache/trivy \
             -e TRIVY_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-db:2" \
             -e TRIVY_JAVA_DB_REPOSITORY="public.ecr.aws/aquasecurity/trivy-java-db:1" \
-            -e TRIVY_USERNAME="${{ env.DOCKER_USERNAME }}" \
-            -e TRIVY_PASSWORD="${{ env.DOCKER_TOKEN }}" \
+            -e TRIVY_USERNAME="${{ secrets.DOCKER_USERNAME }}" \
+            -e TRIVY_PASSWORD="${{ secrets.DOCKER_TOKEN }}" \
             aquasec/trivy@sha256:a22415a38938a56c379387a8163fcb0ce38b10ace73e593475d3658d578b2436 \
             image \
             --skip-version-check \

@@ -1473,26 +1170,12 @@ jobs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
     timeout-minutes: 90
-    environment: release
     steps:
       - name: Checkout
        uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

-      - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
-        with:
-          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
-          aws-region: us-east-2
-
-      - name: Get AWS Secrets
-        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
-        with:
-          secret-ids: |
-            MONITOR_DEPLOYMENTS_WEBHOOK, deploy/monitor-deployments-webhook
-          parse-json-secrets: true
-
       - name: Determine failed jobs
        id: failed-jobs
        shell: bash

@@ -1558,7 +1241,7 @@ jobs:
       - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
-          webhook-url: ${{ env.MONITOR_DEPLOYMENTS_WEBHOOK }}
+          webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
          failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
          title: "🚨 Deployment Workflow Failed"
          ref-name: ${{ github.ref_name }}
3 changes: .github/workflows/pr-python-checks.yml (vendored)
@@ -50,9 +50,8 @@ jobs:
        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
        with:
          path: backend/.mypy_cache
-          key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
+          key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
          restore-keys: |
-            mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
            mypy-${{ runner.os }}-

       - name: Run MyPy
138 changes: .github/workflows/pr-python-model-tests.yml (vendored)
@@ -5,11 +5,6 @@ on:
     # This cron expression runs the job daily at 16:00 UTC (9am PT)
     - cron: "0 16 * * *"
   workflow_dispatch:
-    inputs:
-      branch:
-        description: 'Branch to run the workflow on'
-        required: false
-        default: 'main'

 permissions:
   contents: read

@@ -31,7 +26,11 @@ env:
 jobs:
   model-check:
     # See https://runs-on.com/runners/linux/
-    runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
+    runs-on:
+      - runs-on
+      - runner=4cpu-linux-arm64
+      - "run-id=${{ github.run_id }}-model-check"
+      - "extras=ecr-cache"
     timeout-minutes: 45

     env:

@@ -43,104 +42,83 @@ jobs:
        with:
          persist-credentials: false

+      - name: Setup Python and Install Dependencies
+        uses: ./.github/actions/setup-python-and-install-dependencies
+        with:
+          requirements: |
+            backend/requirements/default.txt
+            backend/requirements/dev.txt
+
+      - name: Format branch name for cache
+        id: format-branch
+        env:
+          PR_NUMBER: ${{ github.event.pull_request.number }}
+          REF_NAME: ${{ github.ref_name }}
+        run: |
+          if [ -n "${PR_NUMBER}" ]; then
+            CACHE_SUFFIX="${PR_NUMBER}"
+          else
+            # shellcheck disable=SC2001
+            CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
+          fi
+          echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
+
       - name: Login to Docker Hub
-        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
+        uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

-      # tag every docker image with "test" so that we can spin up the correct set
-      # of images during testing
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435

-      # We don't need to build the Web Docker image since it's not yet used
-      # in the integration tests. We have a separate action to verify that it builds
-      # successfully.
-      - name: Pull Model Server Docker image
-        run: |
-          docker pull onyxdotapp/onyx-model-server:latest
-          docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
-
-      - name: Set up Python
-        uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
+      - name: Build and load
+        uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
+        env:
+          TAG: model-server-${{ github.run_id }}
        with:
-          python-version: "3.11"
-          cache: "pip"
-          cache-dependency-path: |
-            backend/requirements/default.txt
-            backend/requirements/dev.txt
-
-      - name: Install Dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
-          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
+          load: true
+          targets: model-server
+          set: |
+            model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
+            model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
+            model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
+            model-server.cache-from=type=registry,ref=onyxdotapp/onyx-model-server:latest
+            model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
+            model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
+            model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max

       - name: Start Docker containers
+        id: start_docker
+        env:
+          IMAGE_TAG: model-server-${{ github.run_id }}
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
          AUTH_TYPE=basic \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
-          IMAGE_TAG=test \
-          docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
-        id: start_docker
-
-      - name: Wait for service to be ready
-        run: |
-          echo "Starting wait-for-service script..."
-
-          start_time=$(date +%s)
-          timeout=300 # 5 minutes in seconds
-
-          while true; do
-            current_time=$(date +%s)
-            elapsed_time=$((current_time - start_time))
-
-            if [ $elapsed_time -ge $timeout ]; then
-              echo "Timeout reached. Service did not become ready in 5 minutes."
-              exit 1
-            fi
-
-            # Use curl with error handling to ignore specific exit code 56
-            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
-
-            if [ "$response" = "200" ]; then
-              echo "Service is ready!"
-              break
-            elif [ "$response" = "curl_error" ]; then
-              echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
-            else
-              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
-            fi
-
-            sleep 5
-          done
-          echo "Finished waiting for service."
+          docker compose \
+            -f docker-compose.yml \
+            -f docker-compose.dev.yml \
+            up -d --wait \
+            inference_model_server

       - name: Run Tests
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
          py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding

       - name: Alert on Failure
        if: failure() && github.event_name == 'schedule'
-        env:
-          SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
-          REPO: ${{ github.repository }}
-          RUN_ID: ${{ github.run_id }}
-        run: |
-          curl -X POST \
-            -H 'Content-type: application/json' \
-            --data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
-            $SLACK_WEBHOOK
+        uses: ./.github/actions/slack-notify
+        with:
+          webhook-url: ${{ secrets.SLACK_WEBHOOK }}
+          failed-jobs: model-check
+          title: "🚨 Scheduled Model Tests failed!"
+          ref-name: ${{ github.ref_name }}

       - name: Dump all-container logs (optional)
        if: always()
        run: |
          cd deployment/docker_compose
-          docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
+          docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

       - name: Upload logs
        if: always()
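The `Format branch name for cache` step above collapses any character outside `[A-Za-z0-9._-]` to `-` so the branch name is safe to embed in a registry cache tag (PR builds use the PR number instead). Here is a quick sketch of the same transform in Python, for illustration only; this helper does not exist in the repo:

```python
import re


def cache_suffix(pr_number: str | None, ref_name: str) -> str:
    # Mirrors the workflow step: a PR number wins outright; otherwise the
    # ref name is sanitized so it is a valid image-tag component.
    if pr_number:
        return pr_number
    return re.sub(r"[^A-Za-z0-9._-]", "-", ref_name)


assert cache_suffix("1234", "ignored") == "1234"
assert cache_suffix(None, "feature/foo bar") == "feature-foo-bar"
```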
@@ -146,6 +146,22 @@ repos:
       pass_filenames: false
       files: \.tf$

+  - id: npm-install
+    name: npm install
+    description: "Automatically run 'npm install' after a checkout, pull or rebase"
+    language: system
+    entry: bash -c 'cd web && npm install --no-save'
+    pass_filenames: false
+    files: ^web/package(-lock)?\.json$
+    stages: [post-checkout, post-merge, post-rewrite]
+  - id: npm-install-check
+    name: npm install --package-lock-only
+    description: "Check the 'web/package-lock.json' is updated"
+    language: system
+    entry: bash -c 'cd web && npm install --package-lock-only'
+    pass_filenames: false
+    files: ^web/package(-lock)?\.json$
+
   # Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
   # This is a preview package - if it breaks:
   # 1. Try updating: cd web && npm update @typescript/native-preview
6 changes: .vscode/env_template.txt (vendored)
@@ -17,12 +17,6 @@ LOG_ONYX_MODEL_INTERACTIONS=True
 LOG_LEVEL=debug


-# This passes top N results to LLM an additional time for reranking prior to
-# answer generation.
-# This step is quite heavy on token usage so we disable it for dev generally.
-DISABLE_LLM_DOC_RELEVANCE=False
-
-
 # Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
 OAUTH_CLIENT_ID=<REPLACE THIS>
 OAUTH_CLIENT_SECRET=<REPLACE THIS>
259 changes: CONTRIBUTING.md
@@ -1,262 +1,31 @@
 <!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->

 # Contributing to Onyx

 Hey there! We are so excited that you're interested in Onyx.

 As an open source project in a rapidly changing space, we welcome all contributions.

-## 💃 Guidelines
+## Contribution Opportunities
+The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.

-### Contribution Opportunities
+If you have your own feature that you would like to build please create an issue and community members can provide feedback and
+thumb it up if they feel a common need.

-The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
-
+To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
+via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
+
+## Contributing Code
+Please reference the documents in contributing_guides folder to ensure that the code base is kept to a high standard.
+1. dev_setup.md (start here): gives you a guide to setting up a local development environment.
+2. contribution_process.md: how to ensure you are building valuable features that will get reviewed and merged.
+3. best_practices.md: before asking for reviews, ensure your changes meet the repo code quality standards.
+
-Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
-will be marked with the `approved by maintainers` label.
-Issues marked `good first issue` are an especially great place to start.
-
-**Connectors** to other tools are another great place to contribute. For details on how, refer to this
-[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).
-
-If you have a new/different contribution in mind, we'd love to hear about it!
-Your input is vital to making sure that Onyx moves in the right direction.
-Before starting on implementation, please raise a GitHub issue.
-
-Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
-[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
-
-### Contributing Code
-
-To contribute to this project, please follow the
+To contribute, please follow the
 ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
 When opening a pull request, mention related issues and feel free to tag relevant maintainers.

-Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
-See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
-
-### Getting Help 🙋
+## Getting Help 🙋
+We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).

 Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
 That way we can help future contributors and users can avoid the same issue.
+See you there!

-We also have support channels and generally interesting discussions on our
-[Discord](https://discord.gg/4NA5SbzrWb).
-
-We would love to see you there!
-
-## Get Started 🚀
-
-Onyx being a fully functional app, relies on some external software, specifically:
-
-- [Postgres](https://www.postgresql.org/) (Relational DB)
-- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
-- [Redis](https://redis.io/) (Cache)
-- [MinIO](https://min.io/) (File Store)
-- [Nginx](https://nginx.org/) (Not needed for development flows generally)
-
-> **Note:**
-> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
-> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.
-
-### Local Set Up
-
-Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
-
-If using a lower version, modifications will have to be made to the code.
-If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
-
-#### Backend: Python requirements
-
-Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
-
-For convenience here's a command for it:
-
-```bash
-uv venv .venv --python 3.11
-source .venv/bin/activate
-```
-
-_For Windows, activate the virtual environment using Command Prompt:_
-
-```bash
-.venv\Scripts\activate
-```
-
-If using PowerShell, the command slightly differs:
-
-```powershell
-.venv\Scripts\Activate.ps1
-```
-
-Install the required python dependencies:
-
-```bash
-uv sync --all-extras
-```
-
-Install Playwright for Python (headless browser required by the Web Connector):
-
-```bash
-uv run playwright install
-```
-
-#### Frontend: Node dependencies
-
-Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
-to manage your Node installations. Once installed, you can run
-
-```bash
-nvm install 22 && nvm use 22
-node -v # verify your active version
-```
-
-Navigate to `onyx/web` and run:
-
-```bash
-npm i
-```
-
-## Formatting and Linting
-
-### Backend
-
-For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
-
-Then run:
-
-```bash
-uv run pre-commit install
-```
-
-Additionally, we use `mypy` for static type checking.
-Onyx is fully type-annotated, and we want to keep it that way!
-To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
-
-### Web
-
-We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
-To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
-
-Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
-Re-stage your changes and commit again.
-
-# Running the application for development
-
-## Developing using VSCode Debugger (recommended)
-
-**We highly recommend using VSCode debugger for development.**
-See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
-
-Otherwise, you can follow the instructions below to run the application for development.
-
-## Manually running the application for development
-### Docker containers for external software
-
-You will need Docker installed to run these containers.
-
-First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
-
-```bash
-docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
-```
-
-(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
-
-### Running Onyx locally
-
-To start the frontend, navigate to `onyx/web` and run:
-
-```bash
-npm run dev
-```
-
-Next, start the model server which runs the local NLP models.
-Navigate to `onyx/backend` and run:
-
-```bash
-uvicorn model_server.main:app --reload --port 9000
-```
-
-_For Windows (for compatibility with both PowerShell and Command Prompt):_
-
-```bash
-powershell -Command "uvicorn model_server.main:app --reload --port 9000"
-```
-
-The first time running Onyx, you will need to run the DB migrations for Postgres.
-After the first time, this is no longer required unless the DB models change.
-
-Navigate to `onyx/backend` and with the venv active, run:
-
-```bash
-alembic upgrade head
-```
-
-Next, start the task queue which orchestrates the background jobs.
-Jobs that take more time are run async from the API server.
-
-Still in `onyx/backend`, run:
-
-```bash
-python ./scripts/dev_run_background_jobs.py
-```
-
-To run the backend API server, navigate back to `onyx/backend` and run:
-
-```bash
-AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
-```
-
-_For Windows (for compatibility with both PowerShell and Command Prompt):_
-
-```bash
-powershell -Command "
-$env:AUTH_TYPE='disabled'
-uvicorn onyx.main:app --reload --port 8080
-"
-```
-
-> **Note:**
-> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
-
-#### Wrapping up
-
-You should now have 4 servers running:
-
-- Web server
-- Backend API
-- Model server
-- Background jobs
-
-Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
-
-You've successfully set up a local Onyx instance! 🏁
-
-#### Running the Onyx application in a container
-
-You can run the full Onyx application stack from pre-built images including all external software dependencies.
-
-Navigate to `onyx/deployment/docker_compose` and run:
-
-```bash
-docker compose up -d
-```
-
-After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
-
-If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
-
-```bash
-docker compose up -d --build
-```
-
-### Release Process
-
+## Release Process
 Onyx loosely follows the SemVer versioning standard.
 Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
 A set of Docker containers will be pushed automatically to DockerHub with every tag.
@@ -42,9 +42,7 @@ RUN apt-get update && \
         pkg-config \
         gcc \
         nano \
-        vim \
-        libjemalloc2 \
-        && \
+        vim && \
     rm -rf /var/lib/apt/lists/* && \
     apt-get clean
@@ -132,13 +130,6 @@ ENV PYTHONPATH=/app
 ARG ONYX_VERSION=0.0.0-dev
 ENV ONYX_VERSION=${ONYX_VERSION}

-# Use jemalloc instead of glibc malloc to reduce memory fragmentation
-# in long-running Python processes (API server, Celery workers).
-# The soname is architecture-independent; the dynamic linker resolves
-# the correct path from standard library directories.
-# Placed after all RUN steps so build-time processes are unaffected.
-ENV LD_PRELOAD=libjemalloc.so.2
-
 # Default command which does nothing
 # This container is used by api server and background which specify their own CMD
 CMD ["tail", "-f", "/dev/null"]
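The dropped `LD_PRELOAD=libjemalloc.so.2` line worked because the dynamic linker resolves the soname from the standard library directories at process start. As an illustration (not code from this repo), a Python process can confirm whether jemalloc was actually preloaded by inspecting its own memory map on Linux:

```python
def jemalloc_loaded() -> bool:
    # On Linux, /proc/self/maps lists every shared object mapped into this
    # process; jemalloc shows up here when LD_PRELOAD took effect.
    with open("/proc/self/maps") as maps:
        return any("jemalloc" in line for line in maps)


if __name__ == "__main__":
    print(jemalloc_loaded())  # True under the old image, False after this change
```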
@@ -0,0 +1,47 @@
+"""add_search_query_table
+
+Revision ID: 73e9983e5091
+Revises: d1b637d7050a
+Create Date: 2026-01-14 14:16:52.837489
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "73e9983e5091"
+down_revision = "d1b637d7050a"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "search_query",
+        sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
+        sa.Column(
+            "user_id",
+            postgresql.UUID(as_uuid=True),
+            sa.ForeignKey("user.id"),
+            nullable=False,
+        ),
+        sa.Column("query", sa.String(), nullable=False),
+        sa.Column("query_expansions", postgresql.ARRAY(sa.String()), nullable=True),
+        sa.Column(
+            "created_at",
+            sa.DateTime(timezone=True),
+            nullable=False,
+            server_default=sa.func.now(),
+        ),
+    )
+
+    op.create_index("ix_search_query_user_id", "search_query", ["user_id"])
+    op.create_index("ix_search_query_created_at", "search_query", ["created_at"])
+
+
+def downgrade() -> None:
+    op.drop_index("ix_search_query_created_at", table_name="search_query")
+    op.drop_index("ix_search_query_user_id", table_name="search_query")
+    op.drop_table("search_query")
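For orientation, here is a minimal sketch of the ORM model this migration implies. This is an assumption for illustration; the actual `onyx.db.models.SearchQuery` definition is not part of this diff:

```python
import uuid
from datetime import datetime

from sqlalchemy import DateTime, ForeignKey, String, func
from sqlalchemy.dialects.postgresql import ARRAY, UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class SearchQuery(Base):  # hypothetical mirror of the migration above
    __tablename__ = "search_query"

    # UUID primary key with no server-side default, so callers generate it.
    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
    user_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True), ForeignKey("user.id"), nullable=False, index=True
    )
    query: Mapped[str] = mapped_column(String, nullable=False)
    query_expansions: Mapped[list[str] | None] = mapped_column(
        ARRAY(String), nullable=True
    )
    # Populated by the database via server_default=now().
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, server_default=func.now(), index=True
    )
```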
@@ -10,8 +10,7 @@ from alembic import op
 import sqlalchemy as sa

 from onyx.db.models import IndexModelStatus
-from onyx.context.search.enums import RecencyBiasSetting
-from onyx.context.search.enums import SearchType
+from onyx.context.search.enums import RecencyBiasSetting, SearchType

 # revision identifiers, used by Alembic.
 revision = "776b3bbe9092"
@@ -0,0 +1,42 @@
"""Add SET NULL cascade to chat_session.persona_id foreign key

Revision ID: ac9c7b76419b
Revises: 73e9983e5091
Create Date: 2026-01-17 18:10:00.000000

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "ac9c7b76419b"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing foreign key constraint (no cascade behavior)
    op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
    # Recreate with SET NULL on delete, so deleting a persona sets
    # chat_session.persona_id to NULL instead of blocking the delete
    op.create_foreign_key(
        "fk_chat_session_persona_id",
        "chat_session",
        "persona",
        ["persona_id"],
        ["id"],
        ondelete="SET NULL",
    )


def downgrade() -> None:
    # Revert to original constraint without cascade behavior
    op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
    op.create_foreign_key(
        "fk_chat_session_persona_id",
        "chat_session",
        "persona",
        ["persona_id"],
        ["id"],
    )
@@ -109,6 +109,7 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(


 STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
+STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")

 # JWT Public Key URL
 JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
64  backend/ee/onyx/db/search.py  Normal file

@@ -0,0 +1,64 @@
import uuid
from datetime import timedelta
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import SearchQuery


def create_search_query(
    db_session: Session,
    user_id: UUID,
    query: str,
    query_expansions: list[str] | None = None,
) -> SearchQuery:
    """Create and persist a `SearchQuery` row.

    Notes:
    - `SearchQuery.id` is a UUID PK without a server-side default, so we generate it.
    - `created_at` is filled by the DB (server_default=now()).
    """
    search_query = SearchQuery(
        id=uuid.uuid4(),
        user_id=user_id,
        query=query,
        query_expansions=query_expansions,
    )
    db_session.add(search_query)
    db_session.commit()
    db_session.refresh(search_query)
    return search_query


def fetch_search_queries_for_user(
    db_session: Session,
    user_id: UUID,
    filter_days: int | None = None,
    limit: int | None = None,
) -> list[SearchQuery]:
    """Fetch `SearchQuery` rows for a user.

    Args:
        user_id: User UUID.
        filter_days: Optional time filter. If provided, only rows created within
            the last `filter_days` days are returned.
        limit: Optional max number of rows to return.
    """
    if filter_days is not None and filter_days <= 0:
        raise ValueError("filter_days must be > 0")

    stmt = select(SearchQuery).where(SearchQuery.user_id == user_id)

    if filter_days is not None and filter_days > 0:
        cutoff = get_db_current_time(db_session) - timedelta(days=filter_days)
        stmt = stmt.where(SearchQuery.created_at >= cutoff)

    stmt = stmt.order_by(SearchQuery.created_at.desc())

    if limit is not None:
        stmt = stmt.limit(limit)

    return list(db_session.scalars(stmt).all())
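A minimal usage sketch of this module; the session setup is elided and `SessionLocal` is a stand-in for however the application builds `Session` objects:

```python
# Hypothetical usage, not part of the diff.
from uuid import UUID

from ee.onyx.db.search import create_search_query, fetch_search_queries_for_user

with SessionLocal() as db_session:  # SessionLocal: assumed session factory
    user_id = UUID("00000000-0000-0000-0000-000000000001")
    create_search_query(
        db_session, user_id, "onboarding runbook", ["onboarding", "new hire runbook"]
    )
    # Most recent 10 queries from the last 7 days, newest first
    recent = fetch_search_queries_for_user(db_session, user_id, filter_days=7, limit=10)
```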
@@ -20,12 +20,10 @@ from ee.onyx.server.middleware.tenant_tracking import (
     add_api_server_tenant_id_middleware,
 )
 from ee.onyx.server.oauth.api import router as ee_oauth_router
-from ee.onyx.server.query_and_chat.chat_backend import (
-    router as chat_router,
-)
 from ee.onyx.server.query_and_chat.query_backend import (
     basic_router as ee_query_router,
 )
+from ee.onyx.server.query_and_chat.search_backend import router as search_router
 from ee.onyx.server.query_history.api import router as query_history_router
 from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
 from ee.onyx.server.seeding import seed_db

@@ -124,7 +122,7 @@ def get_application() -> FastAPI:
     # EE only backend APIs
     include_router_with_global_prefix_prepended(application, query_router)
     include_router_with_global_prefix_prepended(application, ee_query_router)
-    include_router_with_global_prefix_prepended(application, chat_router)
+    include_router_with_global_prefix_prepended(application, search_router)
     include_router_with_global_prefix_prepended(application, standard_answer_router)
     include_router_with_global_prefix_prepended(application, ee_oauth_router)
     include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
0  backend/ee/onyx/prompts/__init__.py  Normal file

27  backend/ee/onyx/prompts/query_expansion.py  Normal file

@@ -0,0 +1,27 @@
# Single message is likely most reliable and generally better for this task.
# No final reminders at the end since the user query is expected to be short.
# If it is not short, it should go into the chat flow, so we do not need to account for this.
KEYWORD_EXPANSION_PROMPT = """
Generate a set of keyword-only queries to help find relevant documents for the provided query. \
These queries will be passed to a bm25-based keyword search engine. \
Provide a single query per line (where each query consists of one or more keywords). \
The queries must be purely keywords and not contain any filler natural language. \
Each query should have as few keywords as necessary to represent the user's search intent. \
If there are no useful expansions, simply return the original query with no additional keyword queries. \
CRITICAL: Do not include any additional formatting, comments, or anything aside from the keyword queries.

The user query is:
{user_query}
""".strip()


QUERY_TYPE_PROMPT = """
Determine if the provided query is better suited for a keyword search or a semantic search.
Respond with "keyword" or "semantic" literally and nothing else.
Do not provide any additional text or reasoning to your response.

CRITICAL: It must only be 1 single word - EITHER "keyword" or "semantic".

The user query is:
{user_query}
""".strip()
42  backend/ee/onyx/prompts/search_flow_classification.py  Normal file

@@ -0,0 +1,42 @@
# ruff: noqa: E501, W605 start
SEARCH_CLASS = "search"
CHAT_CLASS = "chat"

# Note that with many larger LLMs, the latency of running this prompt via third-party APIs
# can be as high as 2 seconds, which is too slow for many use cases.
SEARCH_CHAT_PROMPT = f"""
Determine if the following query is better suited for a search UI or a chat UI. Respond with "{SEARCH_CLASS}" or "{CHAT_CLASS}" literally and nothing else. \
Do not provide any additional text or reasoning to your response. CRITICAL, IT MUST ONLY BE 1 SINGLE WORD - EITHER "{SEARCH_CLASS}" or "{CHAT_CLASS}".

# Classification Guidelines:
## {SEARCH_CLASS}
- If the query consists entirely of keywords or the query doesn't require any answer from the AI
- If the query is a short statement that seems like a search query rather than a question
- If the query feels nonsensical or is a short phrase that possibly describes a document or information that could be found in an internal document

### Examples of {SEARCH_CLASS} queries:
- Find me the document that goes over the onboarding process for a new hire
- Pull requests since last week
- Sales Runbook AMEA Region
- Procurement process
- Retrieve the PRD for project X

## {CHAT_CLASS}
- If the query is asking a question that requires an answer rather than a document
- If the query is asking for a solution, suggestion, or general help
- If the query is seeking information that is on the web and likely not in a company internal document
- If the query should be answered without any context from additional documents or searches

### Examples of {CHAT_CLASS} queries:
- What led us to win the deal with company X? (seeking answer)
- Google Drive not sync-ing files to my computer (seeking solution)
- Review my email: <whatever the email is> (general help)
- Write me a script to... (general help)
- Cheap flights Europe to Tokyo (information likely found on the web, not internal)

# User Query:
{{user_query}}

REMEMBER TO ONLY RESPOND WITH "{SEARCH_CLASS}" OR "{CHAT_CLASS}" AND NOTHING ELSE.
""".strip()
# ruff: noqa: E501, W605 end
270  backend/ee/onyx/search/process_search_query.py  Normal file

@@ -0,0 +1,270 @@
from collections.abc import Generator

from sqlalchemy.orm import Session

from ee.onyx.db.search import create_search_query
from ee.onyx.secondary_llm_flows.query_expansion import expand_keywords
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import LLMSelectedDocsPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchDocsPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchQueriesPacket
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
from onyx.db.models import User
from onyx.document_index.factory import get_current_primary_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.factory import get_default_llm
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
from onyx.tools.tool_implementations.search.search_utils import (
    weighted_reciprocal_rank_fusion,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


# This is just a heuristic that also happens to work well for the UI/UX.
# Users would not find it useful to see a huge list of suggested docs,
# but more than 1 is also likely good as many questions may target more than 1 doc.
TARGET_NUM_SECTIONS_FOR_LLM_SELECTION = 3


def _run_single_search(
    query: str,
    filters: BaseFilters | None,
    document_index: DocumentIndex,
    user: User | None,
    db_session: Session,
) -> list[InferenceChunk]:
    """Execute a single search query and return chunks."""
    chunk_search_request = ChunkSearchRequest(
        query=query,
        user_selected_filters=filters,
    )

    return search_pipeline(
        chunk_search_request=chunk_search_request,
        document_index=document_index,
        user=user,
        persona=None,  # No persona for direct search
        db_session=db_session,
    )


def stream_search_query(
    request: SendSearchQueryRequest,
    user: User | None,
    db_session: Session,
) -> Generator[
    SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
    None,
    None,
]:
    """
    Core search function that yields streaming packets.
    Used by both streaming and non-streaming endpoints.
    """
    # Get document index
    document_index = get_current_primary_default_document_index(db_session)

    # Determine queries to execute
    original_query = request.search_query
    keyword_expansions: list[str] = []

    if request.run_query_expansion:
        try:
            llm = get_default_llm()
            keyword_expansions = expand_keywords(
                user_query=original_query,
                llm=llm,
            )
            if keyword_expansions:
                logger.debug(
                    f"Query expansion generated {len(keyword_expansions)} keyword queries"
                )
        except Exception as e:
            logger.warning(f"Query expansion failed: {e}; using original query only.")
            keyword_expansions = []

    # Build list of all executed queries for tracking
    all_executed_queries = [original_query] + keyword_expansions

    # TODO remove this check, user should not be None
    if user is not None:
        create_search_query(
            db_session=db_session,
            user_id=user.id,
            query=request.search_query,
            query_expansions=keyword_expansions if keyword_expansions else None,
        )

    # Execute search(es)
    if not keyword_expansions:
        # Single query (original only) - no threading needed
        chunks = _run_single_search(
            query=original_query,
            filters=request.filters,
            document_index=document_index,
            user=user,
            db_session=db_session,
        )
    else:
        # Multiple queries - run in parallel and merge with RRF
        # First query is the original (semantic), rest are keyword expansions
        search_functions = [
            (
                _run_single_search,
                (query, request.filters, document_index, user, db_session),
            )
            for query in all_executed_queries
        ]

        # Run all searches in parallel
        all_search_results: list[list[InferenceChunk]] = (
            run_functions_tuples_in_parallel(
                search_functions,
                allow_failures=True,
            )
        )

        # Separate original query results from keyword expansion results.
        # Note that in rare cases, the original query may have failed, so we may
        # just be overweighting one set of keyword results; that should not be a big deal though.
        original_result = all_search_results[0] if all_search_results else []
        keyword_results = all_search_results[1:] if len(all_search_results) > 1 else []

        # Build valid results and weights
        # Original query (semantic): weight 2.0
        # Keyword expansions: weight 1.0 each
        valid_results: list[list[InferenceChunk]] = []
        weights: list[float] = []

        if original_result:
            valid_results.append(original_result)
            weights.append(2.0)

        for keyword_result in keyword_results:
            if keyword_result:
                valid_results.append(keyword_result)
                weights.append(1.0)

        if not valid_results:
            logger.warning("All parallel searches returned empty results")
            chunks = []
        else:
            chunks = weighted_reciprocal_rank_fusion(
                ranked_results=valid_results,
                weights=weights,
                id_extractor=lambda chunk: f"{chunk.document_id}_{chunk.chunk_id}",
            )

    # Merge chunks into sections
    sections = merge_individual_chunks(chunks)

    # Apply LLM document selection if requested.
    # num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection.
    # The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it.
    # llm_selected_doc_ids will be:
    # - None if LLM selection was not requested or failed
    # - Empty list if LLM selection ran but selected nothing
    # - List of doc IDs if LLM selection succeeded
    run_llm_selection = (
        request.num_docs_fed_to_llm_selection is not None
        and request.num_docs_fed_to_llm_selection >= 1
    )
    llm_selected_doc_ids: list[str] | None = None
    llm_selection_failed = False
    if run_llm_selection and sections:
        try:
            llm = get_default_llm()
            sections_to_evaluate = sections[: request.num_docs_fed_to_llm_selection]
            selected_sections, _ = select_sections_for_expansion(
                sections=sections_to_evaluate,
                user_query=original_query,
                llm=llm,
                max_sections=TARGET_NUM_SECTIONS_FOR_LLM_SELECTION,
                try_to_fill_to_max=True,
            )
            # Extract unique document IDs from selected sections (may be empty)
            llm_selected_doc_ids = list(
                dict.fromkeys(
                    section.center_chunk.document_id for section in selected_sections
                )
            )
            logger.debug(
                f"LLM document selection evaluated {len(sections_to_evaluate)} sections, "
                f"selected {len(selected_sections)} sections with doc IDs: {llm_selected_doc_ids}"
            )
        except Exception as e:
            # Allowing a blanket exception here as this step is not critical and the rest of the results are still valid
            logger.warning(f"LLM document selection failed: {e}")
            llm_selection_failed = True
    elif run_llm_selection and not sections:
        # LLM selection requested but no sections to evaluate
        llm_selected_doc_ids = []

    # Convert to SearchDocWithContent list, optionally including content
    search_docs = SearchDocWithContent.from_inference_sections(
        sections,
        include_content=request.include_content,
        is_internet=False,
    )

    # Yield queries packet
    yield SearchQueriesPacket(all_executed_queries=all_executed_queries)

    # Yield docs packet
    yield SearchDocsPacket(search_docs=search_docs)

    # Yield LLM selected docs packet if LLM selection was requested
    # - llm_selected_doc_ids is None if selection failed
    # - llm_selected_doc_ids is empty list if no docs were selected
    # - llm_selected_doc_ids is list of IDs if docs were selected
    if run_llm_selection:
        yield LLMSelectedDocsPacket(
            llm_selected_doc_ids=None if llm_selection_failed else llm_selected_doc_ids
        )


def gather_search_stream(
    packets: Generator[
        SearchQueriesPacket
        | SearchDocsPacket
        | LLMSelectedDocsPacket
        | SearchErrorPacket,
        None,
        None,
    ],
) -> SearchFullResponse:
    """
    Aggregate all streaming packets into SearchFullResponse.
    """
    all_executed_queries: list[str] = []
    search_docs: list[SearchDocWithContent] = []
    llm_selected_doc_ids: list[str] | None = None
    error: str | None = None

    for packet in packets:
        if isinstance(packet, SearchQueriesPacket):
            all_executed_queries = packet.all_executed_queries
        elif isinstance(packet, SearchDocsPacket):
            search_docs = packet.search_docs
        elif isinstance(packet, LLMSelectedDocsPacket):
            llm_selected_doc_ids = packet.llm_selected_doc_ids
        elif isinstance(packet, SearchErrorPacket):
            error = packet.error

    return SearchFullResponse(
        all_executed_queries=all_executed_queries,
        search_docs=search_docs,
        doc_selection_reasoning=None,
        llm_selected_doc_ids=llm_selected_doc_ids,
        error=error,
    )
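For reference, weighted reciprocal rank fusion (imported above from `onyx.tools.tool_implementations.search.search_utils`) can be sketched roughly as follows. This is an illustrative re-derivation, not the actual implementation; `k = 60` is the conventional RRF smoothing constant and is assumed here:

```python
from collections import defaultdict
from collections.abc import Callable, Sequence
from typing import TypeVar

T = TypeVar("T")


def weighted_rrf_sketch(
    ranked_results: Sequence[Sequence[T]],
    weights: Sequence[float],
    id_extractor: Callable[[T], str],
    k: int = 60,  # conventional RRF constant (assumption)
) -> list[T]:
    """Each item scores sum(weight / (k + rank)) across the ranked lists."""
    scores: dict[str, float] = defaultdict(float)
    items: dict[str, T] = {}
    for results, weight in zip(ranked_results, weights):
        for rank, item in enumerate(results, start=1):
            item_id = id_extractor(item)
            scores[item_id] += weight / (k + rank)
            items.setdefault(item_id, item)
    return [items[i] for i in sorted(scores, key=lambda i: scores[i], reverse=True)]
```

With the weights used above, a chunk ranked first by the original semantic query contributes 2.0 / 61 while the same rank in a keyword expansion contributes 1.0 / 61, so agreement across lists is what pushes a chunk to the top.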
0  backend/ee/onyx/secondary_llm_flows/__init__.py  Normal file

92  backend/ee/onyx/secondary_llm_flows/query_expansion.py  Normal file

@@ -0,0 +1,92 @@
import re

from ee.onyx.prompts.query_expansion import KEYWORD_EXPANSION_PROMPT
from onyx.llm.interfaces import LLM
from onyx.llm.models import LanguageModelInput
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.utils.logger import setup_logger

logger = setup_logger()

# Pattern to remove common LLM artifacts: brackets, quotes, backticks, etc.
CLEANUP_PATTERN = re.compile(r'[\[\]"\'`]')


def _clean_keyword_line(line: str) -> str:
    """Clean a keyword line by removing common LLM artifacts.

    Removes brackets, quotes, and other characters that LLMs may accidentally
    include in their output.
    """
    # Remove common artifacts
    cleaned = CLEANUP_PATTERN.sub("", line)
    # Remove leading list markers like "1.", "2.", "-", "*"
    cleaned = re.sub(r"^\s*(?:\d+[\.\)]\s*|[-*]\s*)", "", cleaned)
    return cleaned.strip()


def expand_keywords(
    user_query: str,
    llm: LLM,
) -> list[str]:
    """Expand a user query into multiple keyword-only queries for BM25 search.

    Uses an LLM to generate keyword-based search queries that capture different
    aspects of the user's search intent. Returns only the expanded queries,
    not the original query.

    Args:
        user_query: The original search query from the user
        llm: Language model to use for keyword expansion

    Returns:
        List of expanded keyword queries (excluding the original query).
        Returns empty list if expansion fails or produces no useful expansions.
    """
    messages: LanguageModelInput = [
        UserMessage(content=KEYWORD_EXPANSION_PROMPT.format(user_query=user_query))
    ]

    try:
        response = llm.invoke(
            prompt=messages,
            reasoning_effort=ReasoningEffort.OFF,
            # Limit output - we only expect a few short keyword queries
            max_tokens=150,
        )

        content = llm_response_to_string(response).strip()

        if not content:
            logger.warning("Keyword expansion returned empty response.")
            return []

        # Parse response - each line is a separate keyword query.
        # Clean each line to remove LLM artifacts and drop empty lines.
        parsed_queries = []
        for line in content.strip().split("\n"):
            cleaned = _clean_keyword_line(line)
            if cleaned:
                parsed_queries.append(cleaned)

        if not parsed_queries:
            logger.warning("Keyword expansion parsing returned no queries.")
            return []

        # Filter out duplicates and queries that match the original
        expanded_queries: list[str] = []
        seen_lower: set[str] = {user_query.lower()}
        for query in parsed_queries:
            query_lower = query.lower()
            if query_lower not in seen_lower:
                seen_lower.add(query_lower)
                expanded_queries.append(query)

        logger.debug(f"Keyword expansion generated {len(expanded_queries)} queries")
        return expanded_queries

    except Exception as e:
        logger.warning(f"Keyword expansion failed: {e}")
        return []
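To make the parsing concrete, here is how `_clean_keyword_line` behaves on a few hypothetical LLM output lines (examples invented to match the regexes above):

```python
# Hypothetical inputs/outputs, not part of the diff.
assert _clean_keyword_line('1. "onboarding checklist"') == "onboarding checklist"
assert _clean_keyword_line("- [sales runbook]") == "sales runbook"
assert _clean_keyword_line("   ") == ""  # blank lines are then dropped by the caller
```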
@@ -0,0 +1,50 @@
from ee.onyx.prompts.search_flow_classification import CHAT_CLASS
from ee.onyx.prompts.search_flow_classification import SEARCH_CHAT_PROMPT
from ee.onyx.prompts.search_flow_classification import SEARCH_CLASS
from onyx.llm.interfaces import LLM
from onyx.llm.models import LanguageModelInput
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time

logger = setup_logger()


@log_function_time(print_only=True)
def classify_is_search_flow(
    query: str,
    llm: LLM,
) -> bool:
    messages: LanguageModelInput = [
        UserMessage(content=SEARCH_CHAT_PROMPT.format(user_query=query))
    ]
    response = llm.invoke(
        prompt=messages,
        reasoning_effort=ReasoningEffort.OFF,
        # Nothing can happen in the UI until this call finishes, so we need to be aggressive with the timeout
        timeout_override=2,
        # Well more than necessary, but just to ensure completion in case it succeeds at classifying
        # but ends up rambling
        max_tokens=20,
    )

    content = llm_response_to_string(response).strip().lower()
    if not content:
        logger.warning(
            "Search flow classification returned empty response; defaulting to chat flow."
        )
        return False

    # Prefer chat if both appear.
    if CHAT_CLASS in content:
        return False
    if SEARCH_CLASS in content:
        return True

    logger.warning(
        "Search flow classification returned unexpected response; defaulting to chat flow. Response=%r",
        content,
    )
    return False
@@ -19,9 +19,9 @@ from ee.onyx.db.analytics import fetch_query_analytics
 from ee.onyx.db.analytics import user_can_view_assistant_stats
 from onyx.auth.users import current_admin_user
 from onyx.auth.users import current_user
+from onyx.configs.constants import PUBLIC_API_TAGS
 from onyx.db.engine.sql_engine import get_session
 from onyx.db.models import User
-from onyx.server.utils import PUBLIC_API_TAGS

 router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
@@ -1,217 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
    BasicCreateChatMessageWithHistoryRequest,
)
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.models import ChatBasicResponse
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llm_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.utils.logger import setup_logger

logger = setup_logger()

router = APIRouter(prefix="/chat")


@router.post("/send-message-simple-api")
def handle_simplified_chat_message(
    chat_message_req: BasicCreateChatMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
    """This is a Non-Streaming version that only gives back a minimal set of information."""
    logger.notice(f"Received new simple api chat message: {chat_message_req.message}")

    if not chat_message_req.message:
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    # Handle chat session creation if chat_session_id is not provided
    if chat_message_req.chat_session_id is None:
        if chat_message_req.persona_id is None:
            raise HTTPException(
                status_code=400,
                detail="Either chat_session_id or persona_id must be provided",
            )

        # Create a new chat session with the provided persona_id
        try:
            new_chat_session = create_chat_session(
                db_session=db_session,
                description="",  # Leave empty for simple API
                user_id=user.id if user else None,
                persona_id=chat_message_req.persona_id,
            )
            chat_session_id = new_chat_session.id
        except Exception as e:
            logger.exception(e)
            raise HTTPException(status_code=400, detail="Invalid Persona provided.")
    else:
        chat_session_id = chat_message_req.chat_session_id

    try:
        parent_message = create_chat_history_chain(
            chat_session_id=chat_session_id, db_session=db_session
        )[-1]
    except Exception:
        parent_message = get_or_create_root_message(
            chat_session_id=chat_session_id, db_session=db_session
        )

    if (
        chat_message_req.retrieval_options is None
        and chat_message_req.search_doc_ids is None
    ):
        retrieval_options: RetrievalDetails | None = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
        )
    else:
        retrieval_options = chat_message_req.retrieval_options

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_session_id,
        parent_message_id=parent_message.id,
        message=chat_message_req.message,
        file_descriptors=[],
        search_doc_ids=chat_message_req.search_doc_ids,
        retrieval_options=retrieval_options,
        # Simple API does not support reranking, hide complexity from user
        rerank_settings=None,
        query_override=chat_message_req.query_override,
        # Currently only applies to search flow not chat
        chunks_above=0,
        chunks_below=0,
        full_doc=chat_message_req.full_doc,
        structured_response_format=chat_message_req.structured_response_format,
        origin=MessageOrigin.API,
    )

    packets = stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
    )

    return gather_stream(packets)


@router.post("/send-message-simple-with-history")
def handle_send_message_simple_with_history(
    req: BasicCreateChatMessageWithHistoryRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
    """This is a Non-Streaming version that only gives back a minimal set of information.
    Takes in chat history maintained by the caller
    and does query rephrasing similar to answer-with-quote."""

    if len(req.messages) == 0:
        raise HTTPException(status_code=400, detail="Messages cannot be zero length")

    # This is a sanity check to make sure the chat history is valid:
    # it must start with a user message and alternate between user and assistant
    expected_role = MessageType.USER
    for msg in req.messages:
        if not msg.message:
            raise HTTPException(
                status_code=400, detail="One or more chat messages were empty"
            )

        if msg.role != expected_role:
            raise HTTPException(
                status_code=400,
                detail="Message roles must start and end with MessageType.USER and alternate in-between.",
            )
        if expected_role == MessageType.USER:
            expected_role = MessageType.ASSISTANT
        else:
            expected_role = MessageType.USER

    query = req.messages[-1].message
    msg_history = req.messages[:-1]

    logger.notice(f"Received new simple with history chat message: {query}")

    user_id = user.id if user is not None else None
    chat_session = create_chat_session(
        db_session=db_session,
        description="handle_send_message_simple_with_history",
        user_id=user_id,
        persona_id=req.persona_id,
    )

    llm = get_llm_for_persona(persona=chat_session.persona, user=user)

    llm_tokenizer = get_tokenizer(
        model_name=llm.config.model_name,
        provider_type=llm.config.model_provider,
    )

    # Every chat session begins with an empty root message
    root_message = get_or_create_root_message(
        chat_session_id=chat_session.id, db_session=db_session
    )

    chat_message = root_message
    for msg in msg_history:
        chat_message = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=chat_message,
            message=msg.message,
            token_count=len(llm_tokenizer.encode(msg.message)),
            message_type=msg.role,
            db_session=db_session,
            commit=False,
        )
    db_session.commit()

    if req.retrieval_options is None and req.search_doc_ids is None:
        retrieval_options: RetrievalDetails | None = RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=False,
        )
    else:
        retrieval_options = req.retrieval_options

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_session.id,
        parent_message_id=chat_message.id,
        message=query,
        file_descriptors=[],
        search_doc_ids=req.search_doc_ids,
        retrieval_options=retrieval_options,
        # Simple API does not support reranking, hide complexity from user
        rerank_settings=None,
        query_override=None,
        chunks_above=0,
        chunks_below=0,
        full_doc=req.full_doc,
        structured_response_format=req.structured_response_format,
        origin=MessageOrigin.API,
    )

    packets = stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
    )

    return gather_stream(packets)
@@ -1,18 +1,12 @@
-from collections import OrderedDict
-from typing import Literal
-from uuid import UUID
+from collections.abc import Sequence
+from datetime import datetime

 from pydantic import BaseModel
 from pydantic import Field
-from pydantic import model_validator

-from onyx.chat.models import ThreadMessage
-from onyx.configs.constants import DocumentSource
 from onyx.context.search.models import BaseFilters
-from onyx.context.search.models import BasicChunkRequest
-from onyx.context.search.models import ChunkContext
-from onyx.context.search.models import InferenceChunk
-from onyx.context.search.models import RetrievalDetails
+from onyx.context.search.models import InferenceSection
+from onyx.context.search.models import SearchDoc
 from onyx.server.manage.models import StandardAnswer

@@ -25,119 +19,88 @@ class StandardAnswerResponse(BaseModel):
     standard_answers: list[StandardAnswer] = Field(default_factory=list)


-class DocumentSearchRequest(BasicChunkRequest):
-    user_selected_filters: BaseFilters | None = None
+class SearchFlowClassificationRequest(BaseModel):
+    user_query: str


-class DocumentSearchResponse(BaseModel):
-    top_documents: list[InferenceChunk]
+class SearchFlowClassificationResponse(BaseModel):
+    is_search_flow: bool


-class BasicCreateChatMessageRequest(ChunkContext):
-    """If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
-    Note, for simplicity this option only allows for a single linear chain of messages
-    """
+class SendSearchQueryRequest(BaseModel):
+    search_query: str
+    filters: BaseFilters | None = None
+    num_docs_fed_to_llm_selection: int | None = None
+    run_query_expansion: bool = False

-    chat_session_id: UUID | None = None
-    # Optional persona_id to create a new chat session if chat_session_id is not provided
-    persona_id: int | None = None
-    # New message contents
-    message: str
-    # Defaults to using retrieval with no additional filters
-    retrieval_options: RetrievalDetails | None = None
-    # Allows the caller to specify the exact search query they want to use
-    # will disable Query Rewording if specified
-    query_override: str | None = None
-    # If search_doc_ids provided, then retrieval options are unused
-    search_doc_ids: list[int] | None = None
-    # only works if using an OpenAI model. See the following for more details:
-    # https://platform.openai.com/docs/guides/structured-outputs/introduction
-    structured_response_format: dict | None = None
-
-    @model_validator(mode="after")
-    def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
-        if self.chat_session_id is None and self.persona_id is None:
-            raise ValueError("Either chat_session_id or persona_id must be provided")
-        return self
+    include_content: bool = False
+    stream: bool = False


-class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
-    # Last element is the new query. All previous elements are historical context
-    messages: list[ThreadMessage]
-    persona_id: int
-    retrieval_options: RetrievalDetails | None = None
-    query_override: str | None = None
-    skip_rerank: bool | None = None
-    # If search_doc_ids provided, then retrieval options are unused
-    search_doc_ids: list[int] | None = None
-    # only works if using an OpenAI model. See the following for more details:
-    # https://platform.openai.com/docs/guides/structured-outputs/introduction
-    structured_response_format: dict | None = None
+class SearchDocWithContent(SearchDoc):
+    # Allows None because this is determined by a flag but the object used in code
+    # of the search path uses this type
+    content: str | None
+
+    @classmethod
+    def from_inference_sections(
+        cls,
+        sections: Sequence[InferenceSection],
+        include_content: bool = False,
+        is_internet: bool = False,
+    ) -> list["SearchDocWithContent"]:
+        """Convert InferenceSections to SearchDocWithContent objects.

-class SimpleDoc(BaseModel):
-    id: str
-    semantic_identifier: str
-    link: str | None
-    blurb: str
-    match_highlights: list[str]
-    source_type: DocumentSource
-    metadata: dict | None
+        Args:
+            sections: Sequence of InferenceSection objects
+            include_content: If True, populate content field with combined_content
+            is_internet: Whether these are internet search results

-class AgentSubQuestion(BaseModel):
-    sub_question: str
-    document_ids: list[str]
-
-
-class AgentAnswer(BaseModel):
-    answer: str
-    answer_type: Literal["agent_sub_answer", "agent_level_answer"]
-
-
-class AgentSubQuery(BaseModel):
-    sub_query: str
-    query_id: int
-
-    @staticmethod
-    def make_dict_by_level_and_question_index(
-        original_dict: dict[tuple[int, int, int], "AgentSubQuery"],
-    ) -> dict[int, dict[int, list["AgentSubQuery"]]]:
-        """Takes a dict of tuple(level, question num, query_id) to sub queries.
-
-        returns a dict of level to dict[question num to list of query_id's]
-        Ordering is asc for readability.
+        Returns:
+            List of SearchDocWithContent with optional content
         """
-        # In this function, when we sort int | None, we deliberately push None to the end
+        if not sections:
+            return []

-        # map entries to the level_question_dict
-        level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
-        for k1, obj in original_dict.items():
-            level = k1[0]
-            question = k1[1]
-
-            if level not in level_question_dict:
-                level_question_dict[level] = {}
-
-            if question not in level_question_dict[level]:
-                level_question_dict[level][question] = []
-
-            level_question_dict[level][question].append(obj)
-
-        # sort each query_id list and question_index
-        for key1, obj1 in level_question_dict.items():
-            for key2, value2 in obj1.items():
-                # sort the query_id list of each question_index
-                level_question_dict[key1][key2] = sorted(
-                    value2, key=lambda o: o.query_id
-                )
-            # sort the question_index dict of level
-            level_question_dict[key1] = OrderedDict(
-                sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
-            )
+        return [
+            cls(
+                document_id=(chunk := section.center_chunk).document_id,
+                chunk_ind=chunk.chunk_id,
+                semantic_identifier=chunk.semantic_identifier or "Unknown",
+                link=chunk.source_links[0] if chunk.source_links else None,
+                blurb=chunk.blurb,
+                source_type=chunk.source_type,
+                boost=chunk.boost,
+                hidden=chunk.hidden,
+                metadata=chunk.metadata,
+                score=chunk.score,
+                match_highlights=chunk.match_highlights,
+                updated_at=chunk.updated_at,
+                primary_owners=chunk.primary_owners,
+                secondary_owners=chunk.secondary_owners,
+                is_internet=is_internet,
+                content=section.combined_content if include_content else None,
+            )
+            for section in sections
+        ]

-        # sort the top dict of levels
-        sorted_dict = OrderedDict(
-            sorted(level_question_dict.items(), key=lambda x: (x is None, x))
-        )
-        return sorted_dict
+
+class SearchFullResponse(BaseModel):
+    all_executed_queries: list[str]
+    search_docs: list[SearchDocWithContent]
+    # Reasoning tokens output by the LLM for the document selection
+    doc_selection_reasoning: str | None = None
+    # This is a list of document ids that are in the search_docs list
+    llm_selected_doc_ids: list[str] | None = None
+    # Error message if the search failed partway through
+    error: str | None = None
+
+
+class SearchQueryResponse(BaseModel):
+    query: str
+    query_expansions: list[str] | None
+    created_at: datetime
+
+
+class SearchHistoryResponse(BaseModel):
+    search_queries: list[SearchQueryResponse]
170  backend/ee/onyx/server/query_and_chat/search_backend.py  Normal file

@@ -0,0 +1,170 @@
from collections.abc import Generator

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from ee.onyx.db.search import fetch_search_queries_for_user
from ee.onyx.search.process_search_query import gather_search_stream
from ee.onyx.search.process_search_query import stream_search_query
from ee.onyx.secondary_llm_flows.search_flow_classification import (
    classify_is_search_flow,
)
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationRequest
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationResponse
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.server.utils import get_json_line
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

router = APIRouter(prefix="/search")


@router.post("/search-flow-classification")
def search_flow_classification(
    request: SearchFlowClassificationRequest,
    # This is added just to ensure this endpoint isn't spammed by non-authorized users since there's an LLM call underneath it
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SearchFlowClassificationResponse:
    query = request.user_query
    # This is a heuristic: if the user is typing a lot of text, it's unlikely they're looking for some specific document.
    # Most likely something needs to be done with the text included, so we'll just classify it as a chat flow.
    if len(query) > 200:
        return SearchFlowClassificationResponse(is_search_flow=False)

    llm = get_default_llm()

    check_llm_cost_limit_for_provider(
        db_session=db_session,
        tenant_id=get_current_tenant_id(),
        llm_provider_api_key=llm.config.api_key,
    )

    try:
        is_search_flow = classify_is_search_flow(query=query, llm=llm)
    except Exception as e:
        logger.exception(
            "Search flow classification failed; defaulting to chat flow",
            exc_info=e,
        )
        is_search_flow = False

    return SearchFlowClassificationResponse(is_search_flow=is_search_flow)


@router.post("/send-search-message", response_model=None)
def handle_send_search_message(
    request: SendSearchQueryRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse | SearchFullResponse:
    """
    Execute a search query with optional streaming.

    When stream=True: Returns StreamingResponse with SSE
    When stream=False: Returns SearchFullResponse
    """
    logger.debug(f"Received search query: {request.search_query}")

    # Non-streaming path
    if not request.stream:
        try:
            packets = stream_search_query(request, user, db_session)
            return gather_search_stream(packets)
        except NotImplementedError as e:
            return SearchFullResponse(
                all_executed_queries=[],
                search_docs=[],
                error=str(e),
            )

    # Streaming path
    def stream_generator() -> Generator[str, None, None]:
        try:
            with get_session_with_current_tenant() as streaming_db_session:
                for packet in stream_search_query(request, user, streaming_db_session):
                    yield get_json_line(packet.model_dump())
        except NotImplementedError as e:
            yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Error in search streaming")
            yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())

    return StreamingResponse(stream_generator(), media_type="text/event-stream")


@router.get("/search-history")
def get_search_history(
    limit: int = 100,
    filter_days: int | None = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SearchHistoryResponse:
    """
    Fetch past search queries for the authenticated user.

    Args:
        limit: Maximum number of queries to return (default 100)
        filter_days: Only return queries from the last N days (optional)

    Returns:
        SearchHistoryResponse with list of search queries, ordered by most recent first.
    """
    # Validate limit
    if limit <= 0:
        raise HTTPException(
            status_code=400,
            detail="limit must be greater than 0",
        )
    if limit > 1000:
        raise HTTPException(
            status_code=400,
            detail="limit must be at most 1000",
        )

    # Validate filter_days
    if filter_days is not None and filter_days <= 0:
        raise HTTPException(
            status_code=400,
            detail="filter_days must be greater than 0",
        )

    # TODO(yuhong) remove this
    if user is None:
        # Return empty list for unauthenticated users
        return SearchHistoryResponse(search_queries=[])

    search_queries = fetch_search_queries_for_user(
        db_session=db_session,
        user_id=user.id,
        filter_days=filter_days,
        limit=limit,
    )

    return SearchHistoryResponse(
        search_queries=[
            SearchQueryResponse(
                query=sq.query,
                query_expansions=sq.query_expansions,
                created_at=sq.created_at,
            )
            for sq in search_queries
        ]
    )
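Exercising these endpoints from the command line would look roughly like this; the host, port, and auth cookie are illustrative, and the exact URL prefix depends on deployment configuration:

```bash
# Non-streaming search with query expansion enabled
curl -s http://localhost:8080/search/send-search-message \
  -H "Content-Type: application/json" \
  -H "Cookie: <auth cookie>" \
  -d '{"search_query": "onboarding runbook", "run_query_expansion": true, "stream": false}'

# Last 7 days of search history, newest first
curl -s "http://localhost:8080/search/search-history?limit=20&filter_days=7" \
  -H "Cookie: <auth cookie>"
```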
35  backend/ee/onyx/server/query_and_chat/streaming_models.py  Normal file

@@ -0,0 +1,35 @@
from typing import Literal

from pydantic import BaseModel
from pydantic import ConfigDict

from ee.onyx.server.query_and_chat.models import SearchDocWithContent


class SearchQueriesPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["search_queries"] = "search_queries"
    all_executed_queries: list[str]


class SearchDocsPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["search_docs"] = "search_docs"
    search_docs: list[SearchDocWithContent]


class SearchErrorPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["search_error"] = "search_error"
    error: str


class LLMSelectedDocsPacket(BaseModel):
    model_config = ConfigDict(frozen=True)

    type: Literal["llm_selected_docs"] = "llm_selected_docs"
    # None if LLM selection failed, empty list if no docs selected, list of IDs otherwise
    llm_selected_doc_ids: list[str] | None
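Since each packet model carries a `type` discriminator, a client can dispatch on it while reading the stream line by line. A minimal sketch; the `read_lines` source of raw JSON lines is assumed:

```python
import json

def handle_stream(read_lines):
    """Dispatch streamed search packets by their 'type' field."""
    for raw_line in read_lines:  # each line is assumed to be one JSON-encoded packet
        packet = json.loads(raw_line)
        if packet["type"] == "search_queries":
            print("ran queries:", packet["all_executed_queries"])
        elif packet["type"] == "search_docs":
            print("got", len(packet["search_docs"]), "docs")
        elif packet["type"] == "llm_selected_docs":
            print("LLM picked:", packet["llm_selected_doc_ids"])
        elif packet["type"] == "search_error":
            raise RuntimeError(packet["error"])
```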
@@ -32,6 +32,7 @@ from onyx.configs.constants import MessageType
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
+from onyx.configs.constants import PUBLIC_API_TAGS
 from onyx.configs.constants import QAFeedbackType
 from onyx.configs.constants import QueryHistoryType
 from onyx.configs.constants import SessionType

@@ -48,7 +49,6 @@ from onyx.file_store.file_store import get_default_file_store
 from onyx.server.documents.models import PaginatedReturn
 from onyx.server.query_and_chat.models import ChatSessionDetails
 from onyx.server.query_and_chat.models import ChatSessionsResponse
-from onyx.server.utils import PUBLIC_API_TAGS
 from onyx.utils.threadpool_concurrency import parallel_yield
 from shared_configs.contextvars import get_current_tenant_id
@@ -1,10 +1,14 @@
 """Tenant-specific usage limit overrides from the control plane (EE version)."""

+import time
+
 import requests

 from ee.onyx.server.tenants.access import generate_data_plane_token
 from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
+from onyx.configs.app_configs import DEV_MODE
 from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
+from onyx.server.usage_limits import NO_LIMIT
 from onyx.utils.logger import setup_logger

 logger = setup_logger()

@@ -12,9 +16,12 @@ logger = setup_logger()

 # In-memory storage for tenant overrides (populated at startup)
 _tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
+_last_fetch_time: float = 0.0
+_FETCH_INTERVAL = 60 * 60 * 24  # 24 hours
+_ERROR_FETCH_INTERVAL = 30 * 60  # 30 minutes (if the last fetch failed)


-def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
+def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None:
     """
     Fetch tenant-specific usage limit overrides from the control plane.

@@ -45,33 +52,52 @@
                 f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
             )

-        return result
+        return (
+            result or None
+        )  # if empty dictionary, something went wrong and we shouldn't enforce limits

     except requests.exceptions.RequestException as e:
         logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
-        return {}
+        return None
     except Exception as e:
         logger.error(f"Error parsing usage limit overrides: {e}")
-        return {}
+        return None


-def load_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides]:
+def load_usage_limit_overrides() -> None:
     """
     Load tenant usage limit overrides from the control plane.

     Called at server startup to populate the in-memory cache.
     """
     global _tenant_usage_limit_overrides
+    global _last_fetch_time

     logger.info("Loading tenant usage limit overrides from control plane...")
     overrides = fetch_usage_limit_overrides()
-    _tenant_usage_limit_overrides = overrides
+
+    _last_fetch_time = time.time()
+
+    # use the new result if it exists, otherwise use the old result
+    # (prevents us from updating to a failed fetch result)
+    _tenant_usage_limit_overrides = overrides or _tenant_usage_limit_overrides

     if overrides:
         logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
     else:
         logger.info("No tenant-specific usage limit overrides found")
-    return overrides
+
+
+def unlimited(tenant_id: str) -> TenantUsageLimitOverrides:
+    return TenantUsageLimitOverrides(
+        tenant_id=tenant_id,
+        llm_cost_cents_trial=NO_LIMIT,
+        llm_cost_cents_paid=NO_LIMIT,
+        chunks_indexed_trial=NO_LIMIT,
+        chunks_indexed_paid=NO_LIMIT,
+        api_calls_trial=NO_LIMIT,
+        api_calls_paid=NO_LIMIT,
+        non_streaming_calls_trial=NO_LIMIT,
+        non_streaming_calls_paid=NO_LIMIT,
+    )


 def get_tenant_usage_limit_overrides(

@@ -86,7 +112,22 @@
     Returns:
         TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
     """
+
+    if DEV_MODE:  # in dev mode, we return unlimited limits for all tenants
+        return unlimited(tenant_id)
+
     global _tenant_usage_limit_overrides
-    if _tenant_usage_limit_overrides is None:
-        _tenant_usage_limit_overrides = load_usage_limit_overrides()
+    time_since = time.time() - _last_fetch_time
+    if (
+        _tenant_usage_limit_overrides is None and time_since > _ERROR_FETCH_INTERVAL
+    ) or (time_since > _FETCH_INTERVAL):
+        logger.debug(
+            f"Last fetch time: {_last_fetch_time}, time since last fetch: {time_since}"
+        )
+
+        load_usage_limit_overrides()
+
+    # If we have failed to fetch from the control plane or we're in dev mode, don't usage limit anyone.
+    if _tenant_usage_limit_overrides is None or DEV_MODE:
+        return unlimited(tenant_id)
     return _tenant_usage_limit_overrides.get(tenant_id)
@@ -1,9 +1,9 @@
from typing import cast
from typing import Literal

import requests
import stripe

from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
@@ -16,21 +16,15 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()


def fetch_stripe_checkout_session(
    tenant_id: str,
    billing_period: Literal["monthly", "annual"] = "monthly",
) -> str:
def fetch_stripe_checkout_session(tenant_id: str) -> str:
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
    payload = {
        "tenant_id": tenant_id,
        "billing_period": billing_period,
    }
    response = requests.post(url, headers=headers, json=payload)
    params = {"tenant_id": tenant_id}
    response = requests.post(url, headers=headers, params=params)
    response.raise_for_status()
    return response.json()["sessionId"]

@@ -78,24 +72,22 @@ def fetch_billing_information(


def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
    """
    Update the number of seats for a tenant's subscription.
    Preserves the existing price (monthly, annual, or grandfathered).
    Send a request to the control service to register the number of users for a tenant.
    """

    if not STRIPE_PRICE_ID:
        raise Exception("STRIPE_PRICE_ID is not set")

    response = fetch_tenant_stripe_information(tenant_id)
    stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))

    subscription = stripe.Subscription.retrieve(stripe_subscription_id)
    subscription_item = subscription["items"]["data"][0]

    # Use existing price to preserve the customer's current plan
    current_price_id = subscription_item.price.id

    updated_subscription = stripe.Subscription.modify(
        stripe_subscription_id,
        items=[
            {
                "id": subscription_item.id,
                "price": current_price_id,
                "id": subscription["items"]["data"][0].id,
                "price": STRIPE_PRICE_ID,
                "quantity": number_of_users,
            }
        ],
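The register_tenant_users hunk switches the seat update from a globally configured STRIPE_PRICE_ID to the price already attached to the subscription item, so monthly, annual, and grandfathered customers stay on their plan. A compact sketch of that call pattern using the official stripe-python client (the key and helper name here are hypothetical):

import stripe

stripe.api_key = "sk_test_..."  # hypothetical test key


def set_seat_count(subscription_id: str, seats: int) -> stripe.Subscription:
    """Change only the quantity on the first subscription item.

    Reusing the item's current price id, rather than a globally configured
    price, keeps existing customers on whatever plan they already have.
    """
    subscription = stripe.Subscription.retrieve(subscription_id)
    item = subscription["items"]["data"][0]
    return stripe.Subscription.modify(
        subscription_id,
        items=[{"id": item.id, "price": item.price.id, "quantity": seats}],
    )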
@@ -10,7 +10,6 @@ from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
@@ -105,18 +104,15 @@ async def create_customer_portal_session(

@router.post("/create-subscription-session")
async def create_subscription_session(
    request: CreateSubscriptionSessionRequest | None = None,
    _: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
    try:
        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
        if not tenant_id:
            raise HTTPException(status_code=400, detail="Tenant ID not found")

        billing_period = request.billing_period if request else "monthly"
        session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
        session_id = fetch_stripe_checkout_session(tenant_id)
        return SubscriptionSessionResponse(sessionId=session_id)

    except Exception as e:
        logger.exception("Failed to create subscription session")
        logger.exception("Failed to create resubscription session")
        raise HTTPException(status_code=500, detail=str(e))
@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Literal

from pydantic import BaseModel

@@ -74,12 +73,6 @@ class SubscriptionSessionResponse(BaseModel):
    sessionId: str


class CreateSubscriptionSessionRequest(BaseModel):
    """Request to create a subscription checkout session."""

    billing_period: Literal["monthly", "annual"] = "monthly"


class TenantByDomainResponse(BaseModel):
    tenant_id: str
    number_of_users: int
@@ -9,6 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.token_limit import fetch_all_user_token_rate_limits
@@ -16,7 +17,6 @@ from onyx.db.token_limit import insert_user_token_rate_limit
from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
from onyx.server.utils import PUBLIC_API_TAGS

router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)

@@ -18,10 +18,10 @@ from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.server.utils import PUBLIC_API_TAGS
from onyx.utils.logger import setup_logger

logger = setup_logger()

@@ -517,6 +517,7 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
    Raises WorkerShutdown if the timeout is reached."""

    if ENABLE_OPENSEARCH_FOR_ONYX:
        # TODO(andrei): Do some similar liveness checking for OpenSearch.
        return

    if not wait_for_vespa_with_timeout():
@@ -12,7 +12,6 @@ from retry import retry
from sqlalchemy import select

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -20,14 +19,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -56,17 +53,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
    return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"


def _user_file_queued_key(user_file_id: str | UUID) -> str:
    """Key that exists while a process_single_user_file task is sitting in the queue.

    The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
    before enqueuing and the worker deletes it as its first action. This prevents
    the beat from adding duplicate tasks for files that already have a live task
    in flight.
    """
    return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"


def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
    return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"

@@ -130,24 +116,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
    """Scan for user files with PROCESSING status and enqueue per-file tasks.

    Three mechanisms prevent queue runaway:

    1. **Queue depth backpressure** – if the broker queue already has more than
       USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
       entirely. Workers are clearly behind; adding more tasks would only make
       the backlog worse.

    2. **Per-file queued guard** – before enqueuing a task we set a short-lived
       Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
       already exists the file already has a live task in the queue, so we skip
       it. The worker deletes the key the moment it picks up the task so the
       next beat cycle can re-enqueue if the file is still PROCESSING.

    3. **Task expiry** – every enqueued task carries an `expires` value equal to
       CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
       the queue after that deadline, Celery discards it without touching the DB.
       This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
       Redis restart), stale tasks evict themselves rather than piling up forever.
    Uses direct Redis locks to avoid overlapping runs.
    """
    task_logger.info("check_user_file_processing - Starting")
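Protection 2 in the docstring above hinges on an atomic SET NX EX guard key shared between the beat producer and the worker. A self-contained sketch of that handshake, assuming a local Redis reachable via the redis-py client (the key prefix and TTL value are illustrative):

from collections.abc import Callable

import redis

r = redis.Redis()  # connection details are illustrative

GUARD_TTL = 60  # seconds, mirrors CELERY_USER_FILE_PROCESSING_TASK_EXPIRES


def try_enqueue(file_id: str, enqueue: Callable[[str], None]) -> bool:
    """Producer side: enqueue only if no live task exists for this file.

    SET NX EX is atomic, so two concurrent beat runs cannot both win.
    """
    guard_key = f"userfile_queued:{file_id}"
    if not r.set(guard_key, 1, ex=GUARD_TTL, nx=True):
        return False  # a task for this file is already queued
    try:
        enqueue(file_id)
    except Exception:
        r.delete(guard_key)  # let the next beat cycle retry immediately
        raise
    return True


def on_task_start(file_id: str) -> None:
    """Worker side: clear the guard first, so a file that is still
    PROCESSING after this task can be re-enqueued by the next beat."""
    r.delete(f"userfile_queued:{file_id}")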
@@ -162,21 +131,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
        return None

    enqueued = 0
    skipped_guard = 0
    try:
        # --- Protection 1: queue depth backpressure ---
        r_celery = self.app.broker_connection().channel().client  # type: ignore
        queue_len = celery_get_queue_length(
            OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
        )
        if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
            task_logger.warning(
                f"check_user_file_processing - Queue depth {queue_len} exceeds "
                f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
                f"tenant={tenant_id}"
            )
            return None

        with get_session_with_current_tenant() as db_session:
            user_file_ids = (
                db_session.execute(
@@ -189,35 +144,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
            )

            for user_file_id in user_file_ids:
                # --- Protection 2: per-file queued guard ---
                queued_key = _user_file_queued_key(user_file_id)
                guard_set = redis_client.set(
                    queued_key,
                    1,
                    ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
                    nx=True,
                self.app.send_task(
                    OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
                    kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
                    queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
                    priority=OnyxCeleryPriority.HIGH,
                )
                if not guard_set:
                    skipped_guard += 1
                    continue

                # --- Protection 3: task expiry ---
                # If task submission fails, clear the guard immediately so the
                # next beat cycle can retry enqueuing this file.
                try:
                    self.app.send_task(
                        OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
                        kwargs={
                            "user_file_id": str(user_file_id),
                            "tenant_id": tenant_id,
                        },
                        queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
                        priority=OnyxCeleryPriority.HIGH,
                        expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
                    )
                except Exception:
                    redis_client.delete(queued_key)
                    raise
                enqueued += 1

    finally:
@@ -225,8 +157,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
        lock.release()

    task_logger.info(
        f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
        f"tasks for tenant={tenant_id}"
        f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
    )
    return None

@@ -241,12 +172,6 @@ def process_single_user_file(self: Task, *, user_file_id: str, tenant_id: str) -
    start = time.monotonic()

    redis_client = get_redis_client(tenant_id=tenant_id)

    # Clear the "queued" guard set by the beat generator so that the next beat
    # cycle can re-enqueue this file if it is still in PROCESSING state after
    # this task completes or fails.
    redis_client.delete(_user_file_queued_key(user_file_id))

    file_lock: RedisLock = redis_client.lock(
        _user_file_lock_key(user_file_id),
        timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
@@ -18,12 +18,10 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_or_create_root_message
@@ -48,14 +46,10 @@ from onyx.kg.models import KGException
from onyx.kg.setup.kg_default_entity_definitions import (
    populate_missing_default_entity_types__commit,
)
from onyx.llm.override_models import LLMOverride
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
@@ -104,91 +98,6 @@ def create_chat_session_from_request(
    )


def prepare_chat_message_request(
    message_text: str,
    user: User | None,
    persona_id: int | None,
    # Does the question need to have a persona override
    persona_override_config: PersonaOverrideConfig | None,
    message_ts_to_respond_to: str | None,
    retrieval_details: RetrievalDetails | None,
    rerank_settings: RerankingDetails | None,
    db_session: Session,
    skip_gen_ai_answer_generation: bool = False,
    llm_override: LLMOverride | None = None,
    allowed_tool_ids: list[int] | None = None,
    forced_tool_ids: list[int] | None = None,
    origin: MessageOrigin | None = None,
) -> CreateChatMessageRequest:
    # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
    new_chat_session = create_chat_session(
        db_session=db_session,
        description=None,
        user_id=user.id if user else None,
        # If using an override, this id will be ignored later on
        persona_id=persona_id or DEFAULT_PERSONA_ID,
        onyxbot_flow=True,
        slack_thread_id=message_ts_to_respond_to,
    )

    return CreateChatMessageRequest(
        chat_session_id=new_chat_session.id,
        parent_message_id=None,  # It's a standalone chat session each time
        message=message_text,
        file_descriptors=[],  # Currently SlackBot/answer api do not support files in the context
        # Can always override the persona for the single query, if it's a normal persona
        # then it will be treated the same
        persona_override_config=persona_override_config,
        search_doc_ids=None,
        retrieval_options=retrieval_details,
        rerank_settings=rerank_settings,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
        llm_override=llm_override,
        allowed_tool_ids=allowed_tool_ids,
        forced_tool_ids=forced_tool_ids,
        origin=origin or MessageOrigin.UNKNOWN,
    )


def combine_message_thread(
    messages: list[ThreadMessage],
    max_tokens: int | None,
    llm_tokenizer: BaseTokenizer,
) -> str:
    """Used to create a single combined message context from threads"""
    if not messages:
        return ""

    message_strs: list[str] = []
    total_token_count = 0

    for message in reversed(messages):
        if message.role == MessageType.USER:
            role_str = message.role.value.upper()
            if message.sender:
                role_str += " " + message.sender
            else:
                # Since other messages might have the user identifying information
                # better to use Unknown for symmetry
                role_str += " Unknown"
        else:
            role_str = message.role.value.upper()

        msg_str = f"{role_str}:\n{message.message}"
        message_token_count = len(llm_tokenizer.encode(msg_str))

        if (
            max_tokens is not None
            and total_token_count + message_token_count > max_tokens
        ):
            break

        message_strs.insert(0, msg_str)
        total_token_count += message_token_count

    return "\n\n".join(message_strs)


def create_chat_history_chain(
    chat_session_id: UUID,
    db_session: Session,
@@ -250,31 +159,6 @@ def create_chat_history_chain(
    return mainline_messages


def combine_message_chain(
    messages: list[ChatMessage],
    token_limit: int,
    msg_limit: int | None = None,
) -> str:
    """Used for secondary LLM flows that require the chat history,"""
    message_strs: list[str] = []
    total_token_count = 0

    if msg_limit is not None:
        messages = messages[-msg_limit:]

    for message in cast(list[ChatMessage], reversed(messages)):
        message_token_count = message.token_count

        if total_token_count + message_token_count > token_limit:
            break

        role = message.message_type.value.upper()
        message_strs.insert(0, f"{role}:\n{message.message}")
        total_token_count += message_token_count

    return "\n\n".join(message_strs)


def reorganize_citations(
    answer: str, citations: list[CitationInfo]
) -> tuple[str, list[CitationInfo]]:
@@ -415,7 +299,7 @@ def create_temporary_persona(
        num_chunks=persona_config.num_chunks,
        llm_relevance_filter=persona_config.llm_relevance_filter,
        llm_filter_extraction=persona_config.llm_filter_extraction,
        recency_bias=persona_config.recency_bias,
        recency_bias=RecencyBiasSetting.BASE_DECAY,
        llm_model_provider_override=persona_config.llm_model_provider_override,
        llm_model_version_override=persona_config.llm_model_version_override,
    )
@@ -585,6 +469,71 @@ def load_all_chat_files(
    return files


def convert_chat_history_basic(
    chat_history: list[ChatMessage],
    token_counter: Callable[[str], int],
    max_individual_message_tokens: int | None = None,
    max_total_tokens: int | None = None,
) -> list[ChatMessageSimple]:
    """Convert ChatMessage history to ChatMessageSimple format with no tool calls or files included.

    Args:
        chat_history: List of ChatMessage objects to convert
        token_counter: Function to count tokens in a message string
        max_individual_message_tokens: If set, messages exceeding this number of tokens are dropped.
            If None, no messages are dropped based on individual token count.
        max_total_tokens: If set, maximum number of tokens allowed for the entire history.
            If None, the history is not trimmed based on total token count.

    Returns:
        List of ChatMessageSimple objects
    """
    # Defensive: treat a non-positive total budget as "no history".
    if max_total_tokens is not None and max_total_tokens <= 0:
        return []

    # Convert only the core USER/ASSISTANT messages; omit files and tool calls.
    converted: list[ChatMessageSimple] = []
    for chat_message in chat_history:
        if chat_message.message_type not in (MessageType.USER, MessageType.ASSISTANT):
            continue

        message = chat_message.message or ""
        token_count = getattr(chat_message, "token_count", None)
        if token_count is None:
            token_count = token_counter(message)

        # Drop any single message that would dominate the context window.
        if (
            max_individual_message_tokens is not None
            and token_count > max_individual_message_tokens
        ):
            continue

        converted.append(
            ChatMessageSimple(
                message=message,
                token_count=token_count,
                message_type=chat_message.message_type,
                image_files=None,
            )
        )

    if max_total_tokens is None:
        return converted

    # Enforce a max total budget by keeping a contiguous suffix of the conversation.
    trimmed_reversed: list[ChatMessageSimple] = []
    total_tokens = 0
    for msg in reversed(converted):
        if total_tokens + msg.token_count > max_total_tokens:
            break
        trimmed_reversed.append(msg)
        total_tokens += msg.token_count

    return list(reversed(trimmed_reversed))


def convert_chat_history(
    chat_history: list[ChatMessage],
    files: list[ChatLoadedFile],
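convert_chat_history_basic trims to a contiguous suffix of the newest messages rather than dropping arbitrary messages from the middle. A standalone sketch of just that trimming step, with a whitespace word count standing in for a real tokenizer:

from collections.abc import Callable


def trim_to_budget(
    messages: list[str],
    token_counter: Callable[[str], int],
    max_total_tokens: int,
) -> list[str]:
    """Keep the most recent contiguous suffix that fits the budget.

    Walking newest-to-oldest and stopping at the first overflow guarantees
    the kept messages are contiguous, so the model never sees a
    conversation with holes in the middle.
    """
    kept_reversed: list[str] = []
    total = 0
    for msg in reversed(messages):
        cost = token_counter(msg)
        if total + cost > max_total_tokens:
            break
        kept_reversed.append(msg)
        total += cost
    return list(reversed(kept_reversed))


# Whitespace word count stands in for a real tokenizer here.
history = ["hi", "a long old reply " * 50, "recent question", "recent answer"]
print(trim_to_budget(history, lambda s: len(s.split()), max_total_tokens=10))
# -> ['recent question', 'recent answer']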
@@ -4,14 +4,15 @@ Dynamic Citation Processor for LLM Responses
This module provides a citation processor that can:
- Accept citation number to SearchDoc mappings dynamically
- Process token streams from LLMs to extract citations
- Optionally replace citation markers with formatted markdown links
- Emit CitationInfo objects for detected citations (when replacing)
- Track all seen citations regardless of replacement mode
- Handle citations in three modes: REMOVE, KEEP_MARKERS, or HYPERLINK
- Emit CitationInfo objects for detected citations (in HYPERLINK mode)
- Track all seen citations regardless of mode
- Maintain a list of cited documents in order of first citation
"""

import re
from collections.abc import Generator
from enum import Enum
from typing import TypeAlias

from onyx.configs.chat_configs import STOP_STREAM_PAT
@@ -23,6 +24,29 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


class CitationMode(Enum):
    """Defines how citations should be handled in the output.

    REMOVE: Citations are completely removed from output text.
        No CitationInfo objects are emitted.
        Use case: When you need to remove citations from the output if they are not shared with the user
        (e.g. in discord bot, public slack bot).

    KEEP_MARKERS: Original citation markers like [1], [2] are preserved unchanged.
        No CitationInfo objects are emitted.
        Use case: When you need to track citations in research agent and later process
        them with collapse_citations() to renumber.

    HYPERLINK: Citations are replaced with markdown links like [[1]](url).
        CitationInfo objects are emitted for UI tracking.
        Use case: Final reports shown to users with clickable links.
    """

    REMOVE = "remove"
    KEEP_MARKERS = "keep_markers"
    HYPERLINK = "hyperlink"


CitationMapping: TypeAlias = dict[int, SearchDoc]

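To make the three modes concrete, here is a toy renderer over already-complete text. The real processor works on a token stream and buffers partial citations; this stand-in is illustrative only, not the Onyx class:

import re
from enum import Enum


class Mode(Enum):
    REMOVE = "remove"
    KEEP = "keep_markers"
    HYPERLINK = "hyperlink"


CITATION = re.compile(r"\[(\d+)\]")


def render(text: str, links: dict[int, str], mode: Mode) -> str:
    """Apply one of the three citation-handling strategies to final text."""
    if mode is Mode.KEEP:
        return text
    if mode is Mode.REMOVE:
        # Drop markers, then repair the spacing artifacts removal leaves.
        return re.sub(r"  +", " ", CITATION.sub("", text)).replace(" .", ".")
    return CITATION.sub(lambda m: f"[[{m.group(1)}]]({links[int(m.group(1))]})", text)


links = {1: "https://example.com/doc1"}
print(render("See the report [1].", links, Mode.HYPERLINK))
# -> See the report [[1]](https://example.com/doc1).
print(render("See the report [1].", links, Mode.REMOVE))
# -> See the report.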
@@ -48,29 +72,37 @@ class DynamicCitationProcessor:

    This processor is designed for multi-turn conversations where the citation
    number to document mapping is provided externally. It processes streaming
    tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and based
    on the `replace_citation_tokens` setting:
    tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and handles
    them according to the configured CitationMode:

    When replace_citation_tokens=True (default):
    CitationMode.HYPERLINK (default):
    1. Replaces citation markers with formatted markdown links (e.g., [[1]](url))
    2. Emits CitationInfo objects for tracking
    3. Maintains the order in which documents were first cited
    Use case: Final reports shown to users with clickable links.

    When replace_citation_tokens=False:
    1. Preserves original citation markers in the output text
    CitationMode.KEEP_MARKERS:
    1. Preserves original citation markers like [1], [2] unchanged
    2. Does NOT emit CitationInfo objects
    3. Still tracks all seen citations via get_seen_citations()
    Use case: When citations need later processing (e.g., renumbering).

    CitationMode.REMOVE:
    1. Removes citation markers entirely from the output text
    2. Does NOT emit CitationInfo objects
    3. Still tracks all seen citations via get_seen_citations()
    Use case: Research agent intermediate reports.

    Features:
    - Accepts citation number → SearchDoc mapping via update_citation_mapping()
    - Configurable citation replacement behavior at initialization
    - Always tracks seen citations regardless of replacement mode
    - Configurable citation mode at initialization
    - Always tracks seen citations regardless of mode
    - Holds back tokens that might be partial citations
    - Maintains list of cited SearchDocs in order of first citation
    - Handles unicode bracket variants (【】, [])
    - Skips citation processing inside code blocks

    Example (with citation replacement - default):
    Example (HYPERLINK mode - default):
        processor = DynamicCitationProcessor()

        # Set up citation mapping
@@ -87,8 +119,8 @@ class DynamicCitationProcessor:
        # Get cited documents at the end
        cited_docs = processor.get_cited_documents()

    Example (without citation replacement):
        processor = DynamicCitationProcessor(replace_citation_tokens=False)
    Example (KEEP_MARKERS mode):
        processor = DynamicCitationProcessor(citation_mode=CitationMode.KEEP_MARKERS)
        processor.update_citation_mapping({1: search_doc1, 2: search_doc2})

        # Process tokens from LLM
@@ -99,26 +131,42 @@ class DynamicCitationProcessor:

        # Get all seen citations after processing
        seen_citations = processor.get_seen_citations()  # {1: search_doc1, ...}

    Example (REMOVE mode):
        processor = DynamicCitationProcessor(citation_mode=CitationMode.REMOVE)
        processor.update_citation_mapping({1: search_doc1, 2: search_doc2})

        # Process tokens - citations are removed but tracked
        for token in llm_stream:
            for result in processor.process_token(token):
                print(result)  # Text without any citation markers

        # Citations are still tracked
        seen_citations = processor.get_seen_citations()
    """

    def __init__(
        self,
        replace_citation_tokens: bool = True,
        citation_mode: CitationMode = CitationMode.HYPERLINK,
        stop_stream: str | None = STOP_STREAM_PAT,
    ):
        """
        Initialize the citation processor.

        Args:
            replace_citation_tokens: If True (default), citations like [1] are replaced
                with formatted markdown links like [[1]](url) and CitationInfo objects
                are emitted. If False, original citation text is preserved in output
                and no CitationInfo objects are emitted. Regardless of this setting,
                all seen citations are tracked and available via get_seen_citations().
            citation_mode: How to handle citations in the output. One of:
                - CitationMode.HYPERLINK (default): Replace [1] with [[1]](url)
                  and emit CitationInfo objects.
                - CitationMode.KEEP_MARKERS: Keep original [1] markers unchanged,
                  no CitationInfo objects emitted.
                - CitationMode.REMOVE: Remove citations entirely from output,
                  no CitationInfo objects emitted.
                All modes track seen citations via get_seen_citations().
            stop_stream: Optional stop token pattern to halt processing early.
                When this pattern is detected in the token stream, processing stops.
                Defaults to STOP_STREAM_PAT from chat configs.
        """

        # Citation mapping from citation number to SearchDoc
        self.citation_to_doc: CitationMapping = {}
        self.seen_citations: CitationMapping = {}  # citation num -> SearchDoc
@@ -128,7 +176,7 @@ class DynamicCitationProcessor:
        self.curr_segment = ""  # tokens held for citation processing
        self.hold = ""  # tokens held for stop token processing
        self.stop_stream = stop_stream
        self.replace_citation_tokens = replace_citation_tokens
        self.citation_mode = citation_mode

        # Citation tracking
        self.cited_documents_in_order: list[SearchDoc] = (
@@ -199,19 +247,21 @@ class DynamicCitationProcessor:
        5. Handles stop tokens
        6. Always tracks seen citations in self.seen_citations

        Behavior depends on the `replace_citation_tokens` setting from __init__:
        - If True: Citations are replaced with [[n]](url) format and CitationInfo
        Behavior depends on the `citation_mode` setting from __init__:
        - HYPERLINK: Citations are replaced with [[n]](url) format and CitationInfo
          objects are yielded before each formatted citation
        - If False: Original citation text (e.g., [1]) is preserved in output
          and no CitationInfo objects are yielded
        - KEEP_MARKERS: Original citation markers like [1] are preserved unchanged,
          no CitationInfo objects are yielded
        - REMOVE: Citations are removed entirely from output,
          no CitationInfo objects are yielded

        Args:
            token: The next token from the LLM stream, or None to signal end of stream.
                Pass None to flush any remaining buffered text at end of stream.

        Yields:
            str: Text chunks to display. Citation format depends on replace_citation_tokens.
            CitationInfo: Citation metadata (only when replace_citation_tokens=True)
            str: Text chunks to display. Citation format depends on citation_mode.
            CitationInfo: Citation metadata (only when citation_mode=HYPERLINK)
        """
        # None -> end of stream, flush remaining segment
        if token is None:
@@ -299,17 +349,17 @@ class DynamicCitationProcessor:
            if self.non_citation_count > 5:
                self.recent_cited_documents.clear()

            # Yield text before citation FIRST (preserve order)
            if intermatch_str:
                yield intermatch_str

            # Process the citation (returns formatted citation text and CitationInfo objects)
            # Always tracks seen citations regardless of strip_citations flag
            # Always tracks seen citations regardless of citation_mode
            citation_text, citation_info_list = self._process_citation(
                match, has_leading_space, self.replace_citation_tokens
                match, has_leading_space
            )

            if self.replace_citation_tokens:
            if self.citation_mode == CitationMode.HYPERLINK:
                # HYPERLINK mode: Replace citations with markdown links [[n]](url)
                # Yield text before citation FIRST (preserve order)
                if intermatch_str:
                    yield intermatch_str
                # Yield CitationInfo objects BEFORE the citation text
                # This allows the frontend to receive citation metadata before the token
                # that contains [[n]](link), enabling immediate rendering
@@ -318,10 +368,34 @@ class DynamicCitationProcessor:
                # Then yield the formatted citation text
                if citation_text:
                    yield citation_text
            else:
                # When not stripping, yield the original citation text unchanged

            elif self.citation_mode == CitationMode.KEEP_MARKERS:
                # KEEP_MARKERS mode: Preserve original citation markers unchanged
                # Yield text before citation
                if intermatch_str:
                    yield intermatch_str
                # Yield the original citation marker as-is
                yield match.group()

            else:  # CitationMode.REMOVE
                # REMOVE mode: Remove citations entirely from output
                # This strips citation markers like [1], [2], 【1】 from the output text
                # When removing citations, we need to handle spacing to avoid issues like:
                # - "text [1] more" -> "text  more" (double space)
                # - "text [1]." -> "text ." (space before punctuation)
                if intermatch_str:
                    remaining_text = self.curr_segment[match_span[1] :]
                    # Strip trailing space from intermatch if:
                    # 1. Remaining text starts with space (avoids double space)
                    # 2. Remaining text starts with punctuation (avoids space before punctuation)
                    if intermatch_str[-1].isspace() and remaining_text:
                        first_char = remaining_text[0]
                        # Check if next char is space or common punctuation
                        if first_char.isspace() or first_char in ".,;:!?)]}":
                            intermatch_str = intermatch_str.rstrip()
                    if intermatch_str:
                        yield intermatch_str

            self.non_citation_count = 0

            # Leftover text could be part of next citation
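The spacing rules in the REMOVE branch above can be isolated into a small join helper. A sketch, with the punctuation set copied from the diff:

def strip_citation(text_before: str, text_after: str) -> str:
    """Join the text around a removed citation without spacing artifacts.

    Mirrors the REMOVE-mode logic above: if dropping the marker would leave
    a double space or a space before punctuation, trim the trailing space
    from the preceding text first.
    """
    if text_before and text_before[-1].isspace() and text_after:
        first_char = text_after[0]
        if first_char.isspace() or first_char in ".,;:!?)]}":
            text_before = text_before.rstrip()
    return text_before + text_after


assert strip_citation("text ", " more") == "text more"   # no double space
assert strip_citation("text ", ".") == "text."           # no space before "."
assert strip_citation("text ", "more") == "text more"    # normal case untouched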
@@ -338,7 +412,7 @@ class DynamicCitationProcessor:
            yield result

    def _process_citation(
        self, match: re.Match, has_leading_space: bool, replace_tokens: bool = True
        self, match: re.Match, has_leading_space: bool
    ) -> tuple[str, list[CitationInfo]]:
        """
        Process a single citation match and return formatted citation text and citation info objects.
@@ -349,31 +423,28 @@ class DynamicCitationProcessor:
        This method always:
        1. Extracts citation numbers from the match
        2. Looks up the corresponding SearchDoc from the mapping
        3. Tracks seen citations in self.seen_citations (regardless of replace_tokens)
        3. Tracks seen citations in self.seen_citations (regardless of citation_mode)

        When replace_tokens=True (controlled by self.replace_citation_tokens):
        When citation_mode is HYPERLINK:
        4. Creates formatted citation text as [[n]](url)
        5. Creates CitationInfo objects for new citations
        6. Handles deduplication of recently cited documents

        When replace_tokens=False:
        4. Returns empty string and empty list (caller yields original match text)
        When citation_mode is REMOVE or KEEP_MARKERS:
        4. Returns empty string and empty list (caller handles output based on mode)

        Args:
            match: Regex match object containing the citation pattern
            has_leading_space: Whether the text immediately before this citation
                ends with whitespace. Used to determine if a leading space should
                be added to the formatted output.
            replace_tokens: If True, return formatted text and CitationInfo objects.
                If False, only track seen citations and return empty results.
                This is passed from self.replace_citation_tokens by the caller.

        Returns:
            Tuple of (formatted_citation_text, citation_info_list):
            - formatted_citation_text: Markdown-formatted citation text like
              "[[1]](https://example.com)" or empty string if replace_tokens=False
              "[[1]](https://example.com)" or empty string if not in HYPERLINK mode
            - citation_info_list: List of CitationInfo objects for newly cited
              documents, or empty list if replace_tokens=False
              documents, or empty list if not in HYPERLINK mode
        """
        citation_str: str = match.group()  # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】'
        formatted = (
@@ -411,11 +482,11 @@ class DynamicCitationProcessor:
            doc_id = search_doc.document_id
            link = search_doc.link or ""

            # Always track seen citations regardless of replace_tokens setting
            # Always track seen citations regardless of citation_mode setting
            self.seen_citations[num] = search_doc

            # When not replacing citation tokens, skip the rest of the processing
            if not replace_tokens:
            # Only generate formatted citations and CitationInfo in HYPERLINK mode
            if self.citation_mode != CitationMode.HYPERLINK:
                continue

            # Format the citation text as [[n]](link)
@@ -450,14 +521,14 @@ class DynamicCitationProcessor:
        """
        Get the list of cited SearchDoc objects in the order they were first cited.

        Note: This list is only populated when `replace_citation_tokens=True`.
        When `replace_citation_tokens=False`, this will return an empty list.
        Note: This list is only populated when `citation_mode=HYPERLINK`.
        When using REMOVE or KEEP_MARKERS mode, this will return an empty list.
        Use get_seen_citations() instead if you need to track citations without
        replacing them.
        emitting CitationInfo objects.

        Returns:
            List of SearchDoc objects in the order they were first cited.
            Empty list if replace_citation_tokens=False.
            Empty list if citation_mode is not HYPERLINK.
        """
        return self.cited_documents_in_order

@@ -465,14 +536,14 @@ class DynamicCitationProcessor:
        """
        Get the list of cited document IDs in the order they were first cited.

        Note: This list is only populated when `replace_citation_tokens=True`.
        When `replace_citation_tokens=False`, this will return an empty list.
        Note: This list is only populated when `citation_mode=HYPERLINK`.
        When using REMOVE or KEEP_MARKERS mode, this will return an empty list.
        Use get_seen_citations() instead if you need to track citations without
        replacing them.
        emitting CitationInfo objects.

        Returns:
            List of document IDs (strings) in the order they were first cited.
            Empty list if replace_citation_tokens=False.
            Empty list if citation_mode is not HYPERLINK.
        """
        return [doc.document_id for doc in self.cited_documents_in_order]

@@ -481,12 +552,12 @@ class DynamicCitationProcessor:
        Get all seen citations as a mapping from citation number to SearchDoc.

        This returns all citations that have been encountered during processing,
        regardless of the `replace_citation_tokens` setting. Citations are tracked
        regardless of the `citation_mode` setting. Citations are tracked
        whenever they are parsed, making this useful for cases where you need to
        know which citations appeared in the text without replacing them.
        know which citations appeared in the text without emitting CitationInfo objects.

        This is particularly useful when `replace_citation_tokens=False`, as
        get_cited_documents() will be empty in that case, but get_seen_citations()
        This is particularly useful when using REMOVE or KEEP_MARKERS mode, as
        get_cited_documents() will be empty in those cases, but get_seen_citations()
        will still contain all the citations that were found.

        Returns:
@@ -501,13 +572,13 @@ class DynamicCitationProcessor:
        """
        Get the number of unique documents that have been cited.

        Note: This count is only updated when `replace_citation_tokens=True`.
        When `replace_citation_tokens=False`, this will always return 0.
        Note: This count is only updated when `citation_mode=HYPERLINK`.
        When using REMOVE or KEEP_MARKERS mode, this will always return 0.
        Use len(get_seen_citations()) instead if you need to count citations
        without replacing them.
        without emitting CitationInfo objects.

        Returns:
            Number of unique documents cited. 0 if replace_citation_tokens=False.
            Number of unique documents cited. 0 if citation_mode is not HYPERLINK.
        """
        return len(self.cited_document_ids)

@@ -519,9 +590,9 @@ class DynamicCitationProcessor:
        CitationInfo objects for the same document when it's cited multiple times
        in close succession. This method clears that tracker.

        This is primarily useful when `replace_citation_tokens=True` to allow
        This is primarily useful when `citation_mode=HYPERLINK` to allow
        previously cited documents to emit CitationInfo objects again. Has no
        effect when `replace_citation_tokens=False`.
        effect when using REMOVE or KEEP_MARKERS mode.

        The recent citation tracker is also automatically cleared when more than
        5 non-citation characters are processed between citations.

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import CitationMode
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
from onyx.chat.emitter import Emitter
@@ -297,6 +298,7 @@ def run_llm_loop(
    forced_tool_id: int | None = None,
    user_identity: LLMUserIdentity | None = None,
    chat_session_id: str | None = None,
    include_citations: bool = True,
) -> None:
    with trace(
        "run_llm_loop",
@@ -314,7 +316,13 @@ def run_llm_loop(
    initialize_litellm()

    # Initialize citation processor for handling citations dynamically
    citation_processor = DynamicCitationProcessor()
    # When include_citations is True, use HYPERLINK mode to format citations as [[1]](url)
    # When include_citations is False, use REMOVE mode to strip citations from output
    citation_processor = DynamicCitationProcessor(
        citation_mode=(
            CitationMode.HYPERLINK if include_citations else CitationMode.REMOVE
        )
    )

    # Add project file citation mappings if project files are present
    project_citation_mapping: CitationMapping = {}
@@ -1,5 +1,6 @@
import json
import time
import uuid
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
@@ -136,12 +137,11 @@ def _format_message_history_for_logging(

    separator = "================================================"

    # Handle string input
    if isinstance(message_history, str):
        formatted_lines.append("Message [string]:")
        formatted_lines.append(separator)
        formatted_lines.append(f"{message_history}")
        return "\n".join(formatted_lines)
    # Handle single ChatCompletionMessage - wrap in list for uniform processing
    if isinstance(
        message_history, (SystemMessage, UserMessage, AssistantMessage, ToolMessage)
    ):
        message_history = [message_history]

    # Handle sequence of messages
    for i, msg in enumerate(message_history):
@@ -211,7 +211,8 @@ def _update_tool_call_with_delta(

    if index not in tool_calls_in_progress:
        tool_calls_in_progress[index] = {
            "id": None,
            # Fallback ID in case the provider never sends one via deltas.
            "id": f"fallback_{uuid.uuid4().hex}",
            "name": None,
            "arguments": "",
        }
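The fallback ID above matters because tool-call arguments arrive as indexed fragments and some providers never send an id via deltas. A minimal sketch of the accumulation pattern over plain dicts (the delta shape here is a simplification of the actual provider payloads):

import uuid


def accumulate_tool_call_deltas(deltas: list[dict]) -> dict[int, dict]:
    """Merge streamed tool-call fragments keyed by their stream index.

    Providers send the id/name once and the JSON arguments in pieces;
    some never send an id at all, hence the generated fallback.
    """
    in_progress: dict[int, dict] = {}
    for delta in deltas:
        index = delta["index"]
        if index not in in_progress:
            in_progress[index] = {
                "id": f"fallback_{uuid.uuid4().hex}",  # survives id-less providers
                "name": None,
                "arguments": "",
            }
        entry = in_progress[index]
        if delta.get("id"):
            entry["id"] = delta["id"]
        if delta.get("name"):
            entry["name"] = delta["name"]
        entry["arguments"] += delta.get("arguments", "")
    return in_progress


calls = accumulate_tool_call_deltas(
    [
        {"index": 0, "name": "search", "arguments": '{"que'},
        {"index": 0, "arguments": 'ry": "onyx"}'},
    ]
)
print(calls[0]["name"], calls[0]["arguments"])  # search {"query": "onyx"}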
@@ -581,6 +582,18 @@ def run_llm_step_pkt_generator(
            }
            # Note: LLM cost tracking is now handled in multi_llm.py
            delta = packet.choice.delta

            # Weird behavior from some model providers, just log and ignore for now
            if (
                delta.content is None
                and delta.reasoning_content is None
                and delta.tool_calls is None
            ):
                logger.warning(
                    f"LLM packet is empty (no contents, reasoning or tool calls). Skipping: {packet}"
                )
                continue

            if not first_action_recorded and _delta_has_action(delta):
                span_generation.span_data.time_to_first_action_seconds = (
                    time.monotonic() - stream_start_time
@@ -840,14 +853,14 @@ def run_llm_step_pkt_generator(
    logger.debug(f"Accumulated reasoning: {accumulated_reasoning}")
    logger.debug(f"Accumulated answer: {accumulated_answer}")

    if tool_calls:
        tool_calls_str = "\n".join(
            f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
            for tc in tool_calls
        )
        logger.debug(f"Tool calls:\n{tool_calls_str}")
    else:
        logger.debug("Tool calls: []")
    if tool_calls:
        tool_calls_str = "\n".join(
            f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
            for tc in tool_calls
        )
        logger.debug(f"Tool calls:\n{tool_calls_str}")
    else:
        logger.debug("Tool calls: []")

    return (
        LlmStepResult(
@@ -1,6 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
@@ -8,10 +7,7 @@ from uuid import UUID
from pydantic import BaseModel
from pydantic import Field

from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.enums import SearchType
from onyx.context.search.models import SearchDoc
from onyx.file_store.models import FileDescriptor
@@ -24,25 +20,6 @@ from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType


# First chunk of info for streaming QA
class QADocsResponse(BaseModel):
    top_documents: list[SearchDoc]
    rephrased_query: str | None = None
    predicted_flow: QueryFlow | None
    predicted_search: SearchType | None
    applied_source_filters: list[DocumentSource] | None
    applied_time_cutoff: datetime | None
    recency_bias_multiplier: float

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().model_dump(mode="json", *args, **kwargs)  # type: ignore
        initial_dict["applied_time_cutoff"] = (
            self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
        )

        return initial_dict


class StreamStopReason(Enum):
    CONTEXT_LENGTH = "context_length"
    CANCELLED = "cancelled"
@@ -70,22 +47,11 @@ class UserKnowledgeFilePacket(BaseModel):
    user_files: list[FileDescriptor]


class LLMRelevanceFilterResponse(BaseModel):
    llm_selected_doc_indices: list[int]


class RelevanceAnalysis(BaseModel):
    relevant: bool
    content: str | None = None


class SectionRelevancePiece(RelevanceAnalysis):
    """LLM analysis mapped to an Inference Section"""

    document_id: str
    chunk_id: int  # ID of the center chunk for a given inference section


class DocumentRelevance(BaseModel):
    """Contains all relevance information for a given search"""

@@ -116,12 +82,6 @@ class OnyxAnswer(BaseModel):
    answer: str | None


class ThreadMessage(BaseModel):
    message: str
    sender: str | None = None
    role: MessageType = MessageType.USER


class FileChatDisplay(BaseModel):
    file_ids: list[str]

@@ -158,7 +118,6 @@ class PersonaOverrideConfig(BaseModel):
    num_chunks: float | None = None
    llm_relevance_filter: bool = False
    llm_filter_extraction: bool = False
    recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
    llm_model_provider_override: str | None = None
    llm_model_version_override: str | None = None

@@ -38,9 +38,10 @@ from onyx.chat.save_chat import save_chat_turn
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
@@ -50,6 +51,7 @@ from onyx.db.chat import reserve_message_id
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project
@@ -67,6 +69,7 @@ from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import OptionalSearchSetting
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
@@ -93,6 +96,22 @@ logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"

def _should_enable_slack_search(
    persona: Persona,
    filters: BaseFilters | None,
) -> bool:
    """Determine if Slack search should be enabled.

    Returns True if:
    - Source type filter exists and includes Slack, OR
    - Default persona with no source type filter
    """
    source_types = filters.source_type if filters else None
    return (source_types is not None and DocumentSource.SLACK in source_types) or (
        persona.id == DEFAULT_PERSONA_ID and source_types is None
    )


def _extract_project_file_texts_and_images(
    project_id: int | None,
    user_id: UUID | None,
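A standalone restatement of the predicate above makes its truth table easy to check (booleans and plain strings replace the Persona and BaseFilters types):

def should_enable_slack_search(
    is_default_persona: bool, source_types: list[str] | None
) -> bool:
    """Simplified version of _should_enable_slack_search for inspection."""
    return (source_types is not None and "slack" in source_types) or (
        is_default_persona and source_types is None
    )


# An explicit filter wins regardless of persona; no filter only counts
# for the default persona.
assert should_enable_slack_search(False, ["slack", "web"]) is True
assert should_enable_slack_search(True, None) is True
assert should_enable_slack_search(False, None) is False
assert should_enable_slack_search(True, ["web"]) is False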
@@ -281,6 +300,7 @@ def handle_stream_message_objects(
    # on the `new_msg_req.message`. Currently, requires a state where the last message is a
    litellm_additional_headers: dict[str, str] | None = None,
    custom_tool_additional_headers: dict[str, str] | None = None,
    mcp_headers: dict[str, str] | None = None,
    bypass_acl: bool = False,
    # Additional context that should be included in the chat history, for example:
    # Slack threads where the conversation cannot be represented by a chain of User/Assistant
@@ -504,11 +524,15 @@ def handle_stream_message_objects(
            ),
            bypass_acl=bypass_acl,
            slack_context=slack_context,
            enable_slack_search=_should_enable_slack_search(
                persona, new_msg_req.internal_search_filters
            ),
        ),
        custom_tool_config=CustomToolConfig(
            chat_session_id=chat_session.id,
            message_id=user_message.id if user_message else None,
            additional_headers=custom_tool_additional_headers,
            mcp_headers=mcp_headers,
        ),
        allowed_tool_ids=new_msg_req.allowed_tool_ids,
        search_usage_forcing_setting=project_search_config.search_usage,
@@ -629,6 +653,7 @@ def handle_stream_message_objects(
            forced_tool_id=forced_tool_id,
            user_identity=user_identity,
            chat_session_id=str(chat_session.id),
            include_citations=new_msg_req.include_citations,
        )

    except ValueError as e:
@@ -790,6 +815,7 @@ def stream_chat_message_objects(
        parent_message_id=new_msg_req.parent_message_id,
        chat_session_id=new_msg_req.chat_session_id,
        origin=new_msg_req.origin,
        include_citations=new_msg_req.include_citations,
    )
    return handle_stream_message_objects(
        new_msg_req=translated_new_msg_req,
@@ -22,6 +22,14 @@ APP_PORT = 8080
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")

# Certain services need to make HTTP requests to the API server, such as the MCP server and Discord bot
API_SERVER_PROTOCOL = os.environ.get("API_SERVER_PROTOCOL", "http")
API_SERVER_HOST = os.environ.get("API_SERVER_HOST", "127.0.0.1")
# This override allows self-hosting the MCP server with Onyx Cloud backend.
API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS = os.environ.get(
    "API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS"
)

# Whether to send user metadata (user_id/email and session_id) to the LLM provider.
# Disabled by default.
SEND_USER_METADATA_TO_LLM_PROVIDER = (
@@ -850,6 +858,7 @@ AZURE_IMAGE_DEPLOYMENT_NAME = os.environ.get(

# configurable image model
IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")
IMAGE_MODEL_PROVIDER = os.environ.get("IMAGE_MODEL_PROVIDER", "openai")

# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
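Presumably these settings compose into a base URL for internal HTTP calls; a sketch of one plausible resolution order, assuming the override wins outright and the APP_PORT of 8080 from the hunk above (this helper is hypothetical, not an Onyx function):

import os


def api_server_base_url() -> str:
    """Resolve the in-cluster API base URL, honoring the cloud override."""
    override = os.environ.get("API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS")
    if override:
        return override
    protocol = os.environ.get("API_SERVER_PROTOCOL", "http")
    host = os.environ.get("API_SERVER_HOST", "127.0.0.1")
    return f"{protocol}://{host}:8080"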
@@ -12,9 +12,6 @@ NUM_POSTPROCESSED_RESULTS = 20
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = int(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 25)
|
||||
|
||||
# Maximum percentage of the context window to fill with selected sections
|
||||
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
|
||||
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
@@ -27,11 +24,6 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently only applies to search flow not chat
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
DISABLE_LLM_QUERY_REPHRASE = (
|
||||
os.environ.get("DISABLE_LLM_QUERY_REPHRASE", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
@@ -46,34 +38,6 @@ TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
|
||||
)
|
||||
|
||||
# A list of languages passed to the LLM to rephrase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
# TODO these are not used, should probably reintroduce these
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
LANGUAGE_HINT = "\n" + (
|
||||
os.environ.get("LANGUAGE_HINT")
|
||||
or "IMPORTANT: Respond in the same language as my query!"
|
||||
)
|
||||
LANGUAGE_CHAT_NAMING_HINT = (
|
||||
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_DOC_RELEVANCE = (
|
||||
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
@@ -86,9 +50,6 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
# Whether or not to use the semantic & keyword search expansions for Basic Search
|
||||
@@ -96,5 +57,3 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"
|
||||
|
||||
@@ -23,6 +23,9 @@ PUBLIC_DOC_PAT = "PUBLIC"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
|
||||
# Tag for endpoints that should be included in the public API documentation
|
||||
PUBLIC_API_TAGS: list[str | Enum] = ["public"]
|
||||
|
||||
# Cookies
|
||||
FASTAPI_USERS_AUTH_COOKIE_NAME = (
|
||||
"fastapiusersauth" # Currently a constant, but logic allows for configuration
|
||||
@@ -149,17 +152,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -430,9 +422,6 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
|
||||
@@ -4,8 +4,6 @@ import os
|
||||
# Onyx Slack Bot Configs
|
||||
#####
|
||||
ONYX_BOT_NUM_RETRIES = int(os.environ.get("ONYX_BOT_NUM_RETRIES", "5"))
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
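# (descriptive note, not from the diff) 512 * 2 / 3072 = 1024 / 3072, i.e. roughly one third of the available input context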
# Number of docs to display in "Reference Documents"
|
||||
ONYX_BOT_NUM_DOCS_TO_DISPLAY = int(os.environ.get("ONYX_BOT_NUM_DOCS_TO_DISPLAY", "5"))
|
||||
# If the LLM fails to answer, Onyx can still show the "Reference Documents"
|
||||
@@ -47,10 +45,6 @@ ONYX_BOT_MAX_WAIT_TIME = int(os.environ.get("ONYX_BOT_MAX_WAIT_TIME") or 180)
|
||||
# Time (in minutes) after which a Slack message is sent to the user to remind them to give feedback.
|
||||
# Set to 0 to disable it (default)
|
||||
ONYX_BOT_FEEDBACK_REMINDER = int(os.environ.get("ONYX_BOT_FEEDBACK_REMINDER") or 0)
|
||||
# Set to True to rephrase the Slack users' messages
|
||||
ONYX_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("ONYX_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# responses OnyxBot can send in a given time period.
|
||||
|
||||
@@ -161,6 +161,8 @@ class DocumentBase(BaseModel):
|
||||
sections: list[TextSection | ImageSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
# UTC time
|
||||
@@ -202,13 +204,7 @@ class DocumentBase(BaseModel):
|
||||
if not self.metadata:
|
||||
return None
|
||||
# Combined string for the key/value for easy filtering
|
||||
attributes: list[str] = []
|
||||
for k, v in self.metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
return convert_metadata_dict_to_list_of_strings(self.metadata)
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
size = sys.getsizeof(self.id)
|
||||
@@ -240,6 +236,66 @@ class DocumentBase(BaseModel):
|
||||
return " ".join([section.text for section in self.sections if section.text])
|
||||
|
||||
|
||||
def convert_metadata_dict_to_list_of_strings(
|
||||
metadata: dict[str, str | list[str]],
|
||||
) -> list[str]:
|
||||
"""Converts a metadata dict to a list of strings.
|
||||
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
Returns:
|
||||
A list of strings where each string is a key-value pair separated by the
|
||||
INDEX_SEPARATOR.
|
||||
"""
|
||||
attributes: list[str] = []
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
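
A usage sketch with hypothetical values; <SEP> stands in for the repo's INDEX_SEPARATOR constant:

metadata = {"team": "search", "tags": ["vespa", "ranking"]}
convert_metadata_dict_to_list_of_strings(metadata)
# -> ["team<SEP>search", "tags<SEP>vespa", "tags<SEP>ranking"]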
|
||||
|
||||
def convert_metadata_list_of_strings_to_dict(
|
||||
metadata_list: list[str],
|
||||
) -> dict[str, str | list[str]]:
|
||||
"""
|
||||
Converts a list of strings to a metadata dict. The inverse of
|
||||
convert_metadata_dict_to_list_of_strings.
|
||||
|
||||
Assumes the input strings are formatted as in the output of
|
||||
convert_metadata_dict_to_list_of_strings.
|
||||
|
||||
The schema of the output metadata dict is suboptimal yet bound to legacy
|
||||
code. Ideally each key would just point to a list of strings, where each
|
||||
list might contain just one element.
|
||||
|
||||
Args:
|
||||
metadata_list: The list of strings to convert to a metadata dict.
|
||||
|
||||
Returns:
|
||||
A metadata dict where values can be either a string or a list of
|
||||
strings.
|
||||
"""
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
for item in metadata_list:
|
||||
key, value = item.split(INDEX_SEPARATOR, 1)
|
||||
if key in metadata:
|
||||
# We have already seen this key therefore it must point to a list.
|
||||
if isinstance(metadata[key], list):
|
||||
cast(list[str], metadata[key]).append(value)
|
||||
else:
|
||||
metadata[key] = [cast(str, metadata[key]), value]
|
||||
else:
|
||||
metadata[key] = value
|
||||
return metadata
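
Continuing the hypothetical example above, the two helpers round-trip, with repeated keys collapsing back into a list:

as_list = ["team<SEP>search", "tags<SEP>vespa", "tags<SEP>ranking"]  # <SEP> = INDEX_SEPARATOR
convert_metadata_list_of_strings_to_dict(as_list)
# -> {"team": "search", "tags": ["vespa", "ranking"]}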
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
|
||||
|
||||
@@ -13,13 +13,6 @@ class RecencyBiasSetting(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class OptionalSearchSetting(str, Enum):
|
||||
ALWAYS = "always"
|
||||
NEVER = "never"
|
||||
# Determine whether to run search based on history and latest query
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""
|
||||
The type of first-pass query to use for hybrid search.
|
||||
@@ -36,15 +29,3 @@ class SearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
INTERNET = "internet"
|
||||
|
||||
|
||||
class LLMEvaluationType(str, Enum):
|
||||
AGENTIC = "agentic" # applies agentic evaluation
|
||||
BASIC = "basic" # applies boolean evaluation
|
||||
SKIP = "skip" # skips evaluation
|
||||
UNSPECIFIED = "unspecified" # reverts to default
|
||||
|
||||
|
||||
class QueryFlow(str, Enum):
|
||||
SEARCH = "search"
|
||||
QUESTION_ANSWER = "question-answer"
|
||||
|
||||
@@ -31,7 +31,6 @@ from onyx.context.search.federated.slack_search_utils import is_recency_query
|
||||
from onyx.context.search.federated.slack_search_utils import should_include_message
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -425,7 +424,6 @@ class SlackQueryResult(BaseModel):
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
original_query: SearchQuery,
|
||||
access_token: str,
|
||||
limit: int | None = None,
|
||||
allowed_private_channel: str | None = None,
|
||||
@@ -456,7 +454,7 @@ def query_slack(
|
||||
logger.info(f"Final query to slack: {final_query}")
|
||||
|
||||
# Detect if query asks for most recent results
|
||||
sort_by_time = is_recency_query(original_query.query)
|
||||
sort_by_time = is_recency_query(query_string)
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
@@ -536,8 +534,7 @@ def query_slack(
|
||||
)
|
||||
document_id = f"{channel_id}_{message_id}"
|
||||
|
||||
# compute recency bias (parallels vespa calculation) and metadata
|
||||
decay_factor = DOC_TIME_DECAY * original_query.recency_bias_multiplier
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_time = datetime.fromtimestamp(float(message_id))
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (
|
||||
365 * 24 * 60 * 60
|
||||
@@ -1002,7 +999,6 @@ def slack_retrieval(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
query_limit,
|
||||
allowed_private_channel,
|
||||
@@ -1045,7 +1041,6 @@ def slack_retrieval(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
query_limit,
|
||||
allowed_private_channel,
|
||||
@@ -1225,7 +1220,6 @@ def slack_retrieval(
|
||||
source_type=DocumentSource.SLACK,
|
||||
title=chunk.title_prefix,
|
||||
boost=0,
|
||||
recency_bias=docid_to_message[document_id].recency_bias,
|
||||
score=convert_slack_score(docid_to_message[document_id].slack_score),
|
||||
hidden=False,
|
||||
is_relevant=None,
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
@@ -190,7 +191,7 @@ def extract_date_range_from_query(
|
||||
|
||||
try:
|
||||
prompt = SLACK_DATE_EXTRACTION_PROMPT.format(query=query)
|
||||
response = llm_response_to_string(llm.invoke(prompt))
|
||||
response = llm_response_to_string(llm.invoke(UserMessage(content=prompt)))
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -566,23 +567,6 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
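
Illustrative calls with made-up inputs:

_is_valid_keyword_query("deployment rollback")                  # True: short keyword phrase
_is_valid_keyword_query("(note: these are just suggestions)")   # False: starts with "("
_is_valid_keyword_query("This is a long explanatory sentence rather than a keyword query")  # False: over 50 chars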
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -593,8 +577,10 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
Returns:
|
||||
List of rephrased query strings (up to MAX_SLACK_QUERY_EXPANSIONS)
|
||||
"""
|
||||
prompt = SLACK_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
|
||||
prompt = UserMessage(
|
||||
content=SLACK_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -603,18 +589,10 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
raw_queries = [
|
||||
rephrased_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
@@ -5,27 +5,15 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.models import BaseChunk
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from onyx.tools.tool_implementations.web_search.models import WEB_SEARCH_PREFIX
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
)
|
||||
|
||||
|
||||
class QueryExpansions(BaseModel):
|
||||
@@ -38,6 +26,7 @@ class QueryExpansionType(Enum):
|
||||
SEMANTIC = "semantic"
|
||||
|
||||
|
||||
# TODO clean up this stuff, reranking is no longer used
|
||||
class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
@@ -131,13 +120,6 @@ class IndexFilters(BaseFilters, UserFileFilters):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class ChunkMetric(BaseModel):
|
||||
document_id: str
|
||||
chunk_content_start: str
|
||||
first_link: str | None
|
||||
score: float
|
||||
|
||||
|
||||
class ChunkContext(BaseModel):
|
||||
# If not specified (None), picked up from Persona settings if there is space
|
||||
# if specified (even if 0), it always uses the specified number of chunks above and below
|
||||
@@ -192,94 +174,18 @@ class ContextExpansionType(str, Enum):
|
||||
FULL_DOCUMENT = "full_document"
|
||||
|
||||
|
||||
class SearchRequest(ChunkContext):
|
||||
query: str
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
original_query: str | None = None
|
||||
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
|
||||
human_selected_filters: BaseFilters | None = None
|
||||
user_file_filters: UserFileFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
persona: Persona | None = None
|
||||
|
||||
# if None, no offset / limit
|
||||
offset: int | None = None
|
||||
limit: int | None = None
|
||||
|
||||
multilingual_expansion: list[str] | None = None
|
||||
recency_bias_multiplier: float = 1.0
|
||||
hybrid_alpha: float | None = None
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
precomputed_is_keyword: bool | None = None
|
||||
precomputed_keywords: list[str] | None = None
|
||||
|
||||
|
||||
class SearchQuery(ChunkContext):
|
||||
query: str
|
||||
processed_keywords: list[str]
|
||||
search_type: SearchType
|
||||
evaluation_type: LLMEvaluationType
|
||||
filters: IndexFilters
|
||||
|
||||
# by this point, the chunks_above and chunks_below must be set
|
||||
chunks_above: int
|
||||
chunks_below: int
|
||||
|
||||
rerank_settings: RerankingDetails | None
|
||||
hybrid_alpha: float
|
||||
recency_bias_multiplier: float
|
||||
|
||||
# Only used if LLM evaluation type is not skip, None to use default settings
|
||||
max_llm_filter_sections: int
|
||||
|
||||
num_hits: int = NUM_RETURNED_HITS
|
||||
offset: int = 0
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
original_query: str | None
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
# If the Persona is configured to not run search (0 chunks), this is bypassed
|
||||
# If no Prompt is configured, the only search results are shown, this is bypassed
|
||||
run_search: OptionalSearchSetting = OptionalSearchSetting.AUTO
|
||||
# Is this a real-time/streaming call or a question where Onyx can take more time?
|
||||
# Used to determine reranking flow
|
||||
real_time: bool = True
|
||||
# The following have defaults in the Persona settings which can be overridden via
|
||||
# the query, if None, then use Persona settings
|
||||
filters: BaseFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
# if None, no offset / limit
|
||||
offset: int | None = None
|
||||
limit: int | None = None
|
||||
|
||||
# If this is set, only the highest matching chunk (or merged chunks) is returned
|
||||
dedupe_docs: bool = False
|
||||
|
||||
|
||||
class InferenceChunk(BaseChunk):
|
||||
document_id: str
|
||||
source_type: DocumentSource
|
||||
semantic_identifier: str
|
||||
title: str | None # Separate from Semantic Identifier though often same
|
||||
boost: int
|
||||
recency_bias: float
|
||||
score: float | None
|
||||
hidden: bool
|
||||
is_relevant: bool | None = None
|
||||
relevance_explanation: str | None = None
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
# Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
|
||||
# to specify that a set of words should be highlighted. For example:
|
||||
@@ -534,15 +440,3 @@ class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class RetrievalMetricsContainer(BaseModel):
|
||||
search_type: SearchType
|
||||
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
|
||||
|
||||
|
||||
class RerankMetricsContainer(BaseModel):
|
||||
"""The score held by this is the un-boosted, averaged score of the ensemble cross-encoders"""
|
||||
|
||||
metrics: list[ChunkMetric]
|
||||
raw_similarity_scores: list[float]
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import BASE_RECENCY_DECAY
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from onyx.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.utils import (
|
||||
remove_stop_words_and_punctuation,
|
||||
)
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.search_nlp_models import QueryAnalysisModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def query_analysis(query: str) -> tuple[bool, list[str]]:
|
||||
analysis_model = QueryAnalysisModel()
|
||||
return analysis_model.predict(query)
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
bypass_acl: bool = False,
|
||||
) -> SearchQuery:
|
||||
"""Logic is as follows:
|
||||
Any global disables apply first
|
||||
Then any filters or settings as part of the query are used
|
||||
Then defaults to Persona settings if not specified by the query
|
||||
"""
|
||||
query = search_request.query
|
||||
limit = search_request.limit
|
||||
offset = search_request.offset
|
||||
persona = search_request.persona
|
||||
|
||||
preset_filters = search_request.human_selected_filters or BaseFilters()
|
||||
if persona and persona.document_sets and preset_filters.document_set is None:
|
||||
preset_filters.document_set = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
|
||||
time_filter = preset_filters.time_cutoff
|
||||
if time_filter is None and persona:
|
||||
time_filter = persona.search_start_date
|
||||
|
||||
source_filter = preset_filters.source_type
|
||||
|
||||
auto_detect_time_filter = True
|
||||
auto_detect_source_filter = True
|
||||
if not search_request.enable_auto_detect_filters:
|
||||
logger.debug("Retrieval details disables auto detect filters")
|
||||
auto_detect_time_filter = False
|
||||
auto_detect_source_filter = False
|
||||
elif persona and persona.llm_filter_extraction is False:
|
||||
logger.debug("Persona disables auto detect filters")
|
||||
auto_detect_time_filter = False
|
||||
auto_detect_source_filter = False
|
||||
else:
|
||||
logger.debug("Auto detect filters enabled")
|
||||
|
||||
if (
|
||||
time_filter is not None
|
||||
and persona
|
||||
and persona.recency_bias != RecencyBiasSetting.AUTO
|
||||
):
|
||||
auto_detect_time_filter = False
|
||||
logger.debug("Not extract time filter - already provided")
|
||||
if source_filter is not None:
|
||||
logger.debug("Not extract source filter - already provided")
|
||||
auto_detect_source_filter = False
|
||||
|
||||
# Based on the query figure out if we should apply any hard time filters /
|
||||
# if we should bias more recent docs even more strongly
|
||||
run_time_filters = (
|
||||
FunctionCall(extract_time_filter, (query, llm), {})
|
||||
if auto_detect_time_filter
|
||||
else None
|
||||
)
|
||||
|
||||
# Based on the query, figure out if we should apply any source filters
|
||||
run_source_filters = (
|
||||
FunctionCall(extract_source_filter, (query, llm, db_session), {})
|
||||
if auto_detect_source_filter
|
||||
else None
|
||||
)
|
||||
|
||||
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
|
||||
# latency, and in that case we don't need to run the model again
|
||||
run_query_analysis = (
|
||||
None
|
||||
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
|
||||
else FunctionCall(query_analysis, (query,), {})
|
||||
)
|
||||
|
||||
functions_to_run = [
|
||||
filter_fn
|
||||
for filter_fn in [
|
||||
run_time_filters,
|
||||
run_source_filters,
|
||||
run_query_analysis,
|
||||
]
|
||||
if filter_fn
|
||||
]
|
||||
parallel_results = run_functions_in_parallel(functions_to_run)
|
||||
|
||||
predicted_time_cutoff, predicted_favor_recent = (
|
||||
parallel_results[run_time_filters.result_id]
|
||||
if run_time_filters
|
||||
else (None, None)
|
||||
)
|
||||
predicted_source_filters = (
|
||||
parallel_results[run_source_filters.result_id] if run_source_filters else None
|
||||
)
|
||||
|
||||
# The extracted keywords right now are not very reliable, not using for now
|
||||
# Can maybe use for highlighting
|
||||
is_keyword, _extracted_keywords = False, None
|
||||
if search_request.precomputed_is_keyword is not None:
|
||||
is_keyword = search_request.precomputed_is_keyword
|
||||
_extracted_keywords = search_request.precomputed_keywords
|
||||
elif run_query_analysis:
|
||||
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
|
||||
|
||||
all_query_terms = query.split()
|
||||
processed_keywords = (
|
||||
remove_stop_words_and_punctuation(all_query_terms)
|
||||
# If the user is using a different language, don't edit the query or remove english stopwords
|
||||
if not search_request.multilingual_expansion
|
||||
else all_query_terms
|
||||
)
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
user_file_filters = search_request.user_file_filters
|
||||
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
|
||||
if persona and persona.user_files:
|
||||
user_file_ids = list(
|
||||
set(user_file_ids) | set([file.id for file in persona.user_files])
|
||||
)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=user_file_filters.project_id if user_file_filters else None,
|
||||
source_type=preset_filters.source_type or predicted_source_filters,
|
||||
document_set=preset_filters.document_set,
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
# kg_entities=preset_filters.kg_entities,
|
||||
# kg_relationships=preset_filters.kg_relationships,
|
||||
# kg_terms=preset_filters.kg_terms,
|
||||
# kg_sources=preset_filters.kg_sources,
|
||||
# kg_chunk_id_zero_only=preset_filters.kg_chunk_id_zero_only,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
if search_request.evaluation_type is not LLMEvaluationType.UNSPECIFIED:
|
||||
llm_evaluation_type = search_request.evaluation_type
|
||||
|
||||
elif persona:
|
||||
llm_evaluation_type = (
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
)
|
||||
|
||||
if DISABLE_LLM_DOC_RELEVANCE:
|
||||
if llm_evaluation_type:
|
||||
logger.info(
|
||||
"LLM chunk filtering would have run but has been globally disabled"
|
||||
)
|
||||
llm_evaluation_type = LLMEvaluationType.SKIP
|
||||
|
||||
rerank_settings = search_request.rerank_settings
|
||||
# If not explicitly specified by the query, use the current settings
|
||||
if rerank_settings is None:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# For non-streaming flows, the rerank settings are applied at the search_request level
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Decays at 1 / (1 + (multiplier * num years))
|
||||
if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
|
||||
recency_bias_multiplier = 0.0
|
||||
elif persona and persona.recency_bias == RecencyBiasSetting.BASE_DECAY:
|
||||
recency_bias_multiplier = base_recency_decay
|
||||
elif persona and persona.recency_bias == RecencyBiasSetting.FAVOR_RECENT:
|
||||
recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
|
||||
else:
|
||||
if predicted_favor_recent:
|
||||
recency_bias_multiplier = base_recency_decay * favor_recent_decay_multiplier
|
||||
else:
|
||||
recency_bias_multiplier = base_recency_decay
|
||||
|
||||
hybrid_alpha = HYBRID_ALPHA_KEYWORD if is_keyword else HYBRID_ALPHA
|
||||
if search_request.hybrid_alpha:
|
||||
hybrid_alpha = search_request.hybrid_alpha
|
||||
|
||||
# Search request overrides anything else as it's explicitly set by the request
|
||||
# If not explicitly specified, use the persona settings if they exist
|
||||
# Otherwise, use the global defaults
|
||||
chunks_above = (
|
||||
search_request.chunks_above
|
||||
if search_request.chunks_above is not None
|
||||
else (persona.chunks_above if persona else CONTEXT_CHUNKS_ABOVE)
|
||||
)
|
||||
chunks_below = (
|
||||
search_request.chunks_below
|
||||
if search_request.chunks_below is not None
|
||||
else (persona.chunks_below if persona else CONTEXT_CHUNKS_BELOW)
|
||||
)
|
||||
|
||||
return SearchQuery(
|
||||
query=query,
|
||||
original_query=search_request.original_query,
|
||||
processed_keywords=processed_keywords,
|
||||
search_type=SearchType.KEYWORD if is_keyword else SearchType.SEMANTIC,
|
||||
evaluation_type=llm_evaluation_type,
|
||||
filters=final_filters,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
recency_bias_multiplier=recency_bias_multiplier,
|
||||
num_hits=limit if limit is not None else NUM_RETURNED_HITS,
|
||||
offset=offset or 0,
|
||||
rerank_settings=rerank_settings,
|
||||
# Should match the LLM filtering to the same as the reranked, it's understood as this is the number of results
|
||||
# the user wants to do heavier processing on, so do the same for the LLM if reranking is on
|
||||
# if no reranking settings are set, then use the global default
|
||||
max_llm_filter_sections=(
|
||||
rerank_settings.num_rerank if rerank_settings else NUM_POSTPROCESSED_RESULTS
|
||||
),
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
full_doc=search_request.full_doc,
|
||||
precomputed_query_embedding=search_request.precomputed_query_embedding,
|
||||
expanded_queries=search_request.expanded_queries,
|
||||
)
|
||||
@@ -1,42 +1,24 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import ChunkMetric
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import QueryExpansionType
|
||||
from onyx.context.search.models import RetrievalMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
|
||||
from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import get_query_embeddings
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -80,19 +62,6 @@ def download_nltk_data() -> None:
|
||||
logger.error(f"Failed to download {resource_name}. Error: {e}")
|
||||
|
||||
|
||||
def lemmatize_text(keywords: list[str]) -> list[str]:
|
||||
raise NotImplementedError("Lemmatization should not be used currently")
|
||||
# try:
|
||||
# query = " ".join(keywords)
|
||||
# lemmatizer = WordNetLemmatizer()
|
||||
# word_tokens = word_tokenize(query)
|
||||
# lemmatized_words = [lemmatizer.lemmatize(word) for word in word_tokens]
|
||||
# combined_keywords = list(set(keywords + lemmatized_words))
|
||||
# return combined_keywords
|
||||
# except Exception:
|
||||
# return keywords
|
||||
|
||||
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
@@ -117,313 +86,6 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
# TODO: This is unused code.
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
This function performs the search to retrieve the chunks,
|
||||
extracts chunks from the large chunks, persists the scores
|
||||
from the large chunks to the referenced chunks,
|
||||
dedupes the chunks, and cleans the chunks.
|
||||
"""
|
||||
query_embedding = query.precomputed_query_embedding or get_query_embedding(
|
||||
query.query, db_session
|
||||
)
|
||||
|
||||
keyword_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
|
||||
semantic_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
|
||||
top_base_chunks_standard_ranking_thread: (
|
||||
TimeoutThread[list[InferenceChunk]] | None
|
||||
) = None
|
||||
|
||||
top_semantic_chunks_thread: TimeoutThread[list[InferenceChunk]] | None = None
|
||||
|
||||
keyword_embeddings: list[Embedding] | None = None
|
||||
semantic_embeddings: list[Embedding] | None = None
|
||||
|
||||
top_semantic_chunks: list[InferenceChunk] | None = None
|
||||
|
||||
# original retrieval method
|
||||
top_base_chunks_standard_ranking_thread = run_in_background(
|
||||
document_index.hybrid_retrieval,
|
||||
query.query,
|
||||
query_embedding,
|
||||
query.processed_keywords,
|
||||
query.filters,
|
||||
query.hybrid_alpha,
|
||||
query.recency_bias_multiplier,
|
||||
query.num_hits,
|
||||
QueryExpansionType.SEMANTIC,
|
||||
query.offset,
|
||||
)
|
||||
|
||||
if (
|
||||
query.expanded_queries
|
||||
and query.expanded_queries.keywords_expansions
|
||||
and query.expanded_queries.semantic_expansions
|
||||
):
|
||||
|
||||
keyword_embeddings_thread = run_in_background(
|
||||
get_query_embeddings,
|
||||
query.expanded_queries.keywords_expansions,
|
||||
db_session,
|
||||
)
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
semantic_embeddings_thread = run_in_background(
|
||||
get_query_embeddings,
|
||||
query.expanded_queries.semantic_expansions,
|
||||
db_session,
|
||||
)
|
||||
|
||||
keyword_embeddings = wait_on_background(keyword_embeddings_thread)
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
assert semantic_embeddings_thread is not None
|
||||
semantic_embeddings = wait_on_background(semantic_embeddings_thread)
|
||||
|
||||
# Use original query embedding for keyword retrieval embedding
|
||||
keyword_embeddings = [query_embedding]
|
||||
|
||||
# Note: we generally prepped earlier for multiple expansions, but for now we only use one.
|
||||
top_keyword_chunks_thread = run_in_background(
|
||||
document_index.hybrid_retrieval,
|
||||
query.expanded_queries.keywords_expansions[0],
|
||||
keyword_embeddings[0],
|
||||
query.processed_keywords,
|
||||
query.filters,
|
||||
HYBRID_ALPHA_KEYWORD,
|
||||
query.recency_bias_multiplier,
|
||||
query.num_hits,
|
||||
QueryExpansionType.KEYWORD,
|
||||
query.offset,
|
||||
)
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
assert semantic_embeddings is not None
|
||||
|
||||
top_semantic_chunks_thread = run_in_background(
|
||||
document_index.hybrid_retrieval,
|
||||
query.expanded_queries.semantic_expansions[0],
|
||||
semantic_embeddings[0],
|
||||
query.processed_keywords,
|
||||
query.filters,
|
||||
HYBRID_ALPHA,
|
||||
query.recency_bias_multiplier,
|
||||
query.num_hits,
|
||||
QueryExpansionType.SEMANTIC,
|
||||
query.offset,
|
||||
)
|
||||
|
||||
top_base_chunks_standard_ranking = wait_on_background(
|
||||
top_base_chunks_standard_ranking_thread
|
||||
)
|
||||
|
||||
top_keyword_chunks = wait_on_background(top_keyword_chunks_thread)
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC:
|
||||
assert top_semantic_chunks_thread is not None
|
||||
top_semantic_chunks = wait_on_background(top_semantic_chunks_thread)
|
||||
|
||||
all_top_chunks = top_base_chunks_standard_ranking + top_keyword_chunks
|
||||
|
||||
# use all three retrieval methods to retrieve top chunks
|
||||
|
||||
if query.search_type == SearchType.SEMANTIC and top_semantic_chunks is not None:
|
||||
|
||||
all_top_chunks += top_semantic_chunks
|
||||
|
||||
top_chunks = _dedupe_chunks(all_top_chunks)
|
||||
|
||||
else:
|
||||
|
||||
top_base_chunks_standard_ranking = wait_on_background(
|
||||
top_base_chunks_standard_ranking_thread
|
||||
)
|
||||
|
||||
top_chunks = _dedupe_chunks(top_base_chunks_standard_ranking)
|
||||
|
||||
logger.info(f"Overall number of top initial retrieval chunks: {len(top_chunks)}")
|
||||
|
||||
retrieval_requests: list[VespaChunkRequest] = []
|
||||
normal_chunks: list[InferenceChunk] = []
|
||||
referenced_chunk_scores: dict[tuple[str, int], float] = {}
|
||||
for chunk in top_chunks:
|
||||
if chunk.large_chunk_reference_ids:
|
||||
retrieval_requests.append(
|
||||
VespaChunkRequest(
|
||||
document_id=replace_invalid_doc_id_characters(chunk.document_id),
|
||||
min_chunk_ind=chunk.large_chunk_reference_ids[0],
|
||||
max_chunk_ind=chunk.large_chunk_reference_ids[-1],
|
||||
)
|
||||
)
|
||||
# for each referenced chunk, persist the
|
||||
# highest score to the referenced chunk
|
||||
for chunk_id in chunk.large_chunk_reference_ids:
|
||||
key = (chunk.document_id, chunk_id)
|
||||
referenced_chunk_scores[key] = max(
|
||||
referenced_chunk_scores.get(key, 0), chunk.score or 0
|
||||
)
|
||||
else:
|
||||
normal_chunks.append(chunk)
|
||||
|
||||
# If there are no large chunks, just return the normal chunks
|
||||
if not retrieval_requests:
|
||||
return normal_chunks
|
||||
|
||||
# Retrieve and return the referenced normal chunks from the large chunks
|
||||
retrieved_inference_chunks = document_index.id_based_retrieval(
|
||||
chunk_requests=retrieval_requests,
|
||||
filters=query.filters,
|
||||
batch_retrieval=True,
|
||||
)
|
||||
|
||||
# Apply the scores from the large chunks to the chunks referenced
|
||||
# by each large chunk
|
||||
for chunk in retrieved_inference_chunks:
|
||||
if (chunk.document_id, chunk.chunk_id) in referenced_chunk_scores:
|
||||
chunk.score = referenced_chunk_scores[(chunk.document_id, chunk.chunk_id)]
|
||||
referenced_chunk_scores.pop((chunk.document_id, chunk.chunk_id))
|
||||
else:
|
||||
logger.error(
|
||||
f"Chunk {chunk.document_id} {chunk.chunk_id} not found in referenced chunk scores"
|
||||
)
|
||||
|
||||
# Log any chunks that were not found in the retrieved chunks
|
||||
for reference in referenced_chunk_scores.keys():
|
||||
logger.error(f"Chunk {reference} not found in retrieved chunks")
|
||||
|
||||
unique_chunks: dict[tuple[str, int], InferenceChunk] = {
|
||||
(chunk.document_id, chunk.chunk_id): chunk for chunk in normal_chunks
|
||||
}
|
||||
|
||||
# persist the highest score of each deduped chunk
|
||||
for chunk in retrieved_inference_chunks:
|
||||
key = (chunk.document_id, chunk.chunk_id)
|
||||
# For duplicates, keep the highest score
|
||||
if key not in unique_chunks or (chunk.score or 0) > (
|
||||
unique_chunks[key].score or 0
|
||||
):
|
||||
unique_chunks[key] = chunk
|
||||
|
||||
# Deduplicate the chunks
|
||||
deduped_chunks = list(unique_chunks.values())
|
||||
deduped_chunks.sort(key=lambda chunk: chunk.score or 0, reverse=True)
|
||||
return deduped_chunks
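
The background-thread pattern used throughout this function, as a standalone sketch (semantics inferred from the usage above, not a documented API contract):

handle = run_in_background(get_query_embeddings, ["rollback steps"], db_session)
# do other retrieval work on the main thread here
embeddings = wait_on_background(handle)  # joins the thread and returns its result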
|
||||
|
||||
def _simplify_text(text: str) -> str:
|
||||
return "".join(
|
||||
char for char in text if char not in string.punctuation and not char.isspace()
|
||||
).lower()
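
Illustrative: case, whitespace, and punctuation all collapse, so near-identical rephrasings map to the same key and skip a duplicate search:

_simplify_text("What's new?")  # -> "whatsnew"
_simplify_text("whats  NEW")   # -> "whatsnew"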
|
||||
|
||||
# TODO delete this
|
||||
def retrieve_chunks(
|
||||
query: SearchQuery,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
source_filters = (
|
||||
set(query.filters.source_type) if query.filters.source_type else None
|
||||
)
|
||||
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session,
|
||||
user_id,
|
||||
list(query.filters.source_type) if query.filters.source_type else None,
|
||||
query.filters.document_set,
|
||||
user_file_ids=query.filters.user_file_ids,
|
||||
)
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
for federated_retrieval_info in federated_retrieval_infos
|
||||
)
|
||||
for federated_retrieval_info in federated_retrieval_infos:
|
||||
run_queries.append((federated_retrieval_info.retrieval_function, (query,)))
|
||||
|
||||
# Normal retrieval
|
||||
normal_search_enabled = (source_filters is None) or (
|
||||
len(set(source_filters) - federated_sources) > 0
|
||||
)
|
||||
if normal_search_enabled and (
|
||||
not multilingual_expansion or "\n" in query.query or "\r" in query.query
|
||||
):
|
||||
# Don't do query expansion on complex queries; rephrasings likely would not work well
|
||||
run_queries.append((doc_index_retrieval, (query, document_index, db_session)))
|
||||
elif normal_search_enabled:
|
||||
simplified_queries = set()
|
||||
|
||||
# Currently only uses query expansion on multilingual use cases
|
||||
query_rephrases = multilingual_query_expansion(
|
||||
query.query, multilingual_expansion
|
||||
)
|
||||
# Just to be extra sure, add the original query.
|
||||
query_rephrases.append(query.query)
|
||||
for rephrase in set(query_rephrases):
|
||||
# Sometimes the model rephrases the query in the same language with minor changes
|
||||
# Avoid doing an extra search with the minor changes as this biases the results
|
||||
simplified_rephrase = _simplify_text(rephrase)
|
||||
if simplified_rephrase in simplified_queries:
|
||||
continue
|
||||
simplified_queries.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.model_copy(
|
||||
update={
|
||||
"query": rephrase,
|
||||
# need to recompute for each rephrase
|
||||
# note that `SearchQuery` is a frozen model, so we can't update
|
||||
# it below
|
||||
"precomputed_query_embedding": None,
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
run_queries.append(
|
||||
(doc_index_retrieval, (q_copy, document_index, db_session))
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
logger.warning(
|
||||
f"Hybrid ({query.search_type.value.capitalize()}) search returned no results "
|
||||
f"with filters: {query.filters}"
|
||||
)
|
||||
return []
|
||||
|
||||
if retrieval_metrics_callback is not None:
|
||||
chunk_metrics = [
|
||||
ChunkMetric(
|
||||
document_id=chunk.document_id,
|
||||
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
|
||||
first_link=chunk.source_links[0] if chunk.source_links else None,
|
||||
score=chunk.score if chunk.score is not None else 0,
|
||||
)
|
||||
for chunk in top_chunks
|
||||
]
|
||||
retrieval_metrics_callback(
|
||||
RetrievalMetricsContainer(
|
||||
search_type=query.search_type, metrics=chunk_metrics
|
||||
)
|
||||
)
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
def _embed_and_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
import string
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SavedSearchDocWithContent
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -41,66 +37,6 @@ TSection = TypeVar(
|
||||
)
|
||||
|
||||
|
||||
def dedupe_documents(items: list[T]) -> tuple[list[T], list[int]]:
|
||||
seen_ids = set()
|
||||
deduped_items = []
|
||||
dropped_indices = []
|
||||
for index, item in enumerate(items):
|
||||
if isinstance(item, InferenceSection):
|
||||
document_id = item.center_chunk.document_id
|
||||
else:
|
||||
document_id = item.document_id
|
||||
|
||||
if document_id not in seen_ids:
|
||||
seen_ids.add(document_id)
|
||||
deduped_items.append(item)
|
||||
else:
|
||||
dropped_indices.append(index)
|
||||
return deduped_items, dropped_indices
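
A sketch with hypothetical docs (doc_a_dup shares doc_a's document_id):

deduped, dropped = dedupe_documents([doc_a, doc_b, doc_a_dup])
# deduped -> [doc_a, doc_b]; dropped -> [2], the index of the duplicate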
|
||||
|
||||
def relevant_sections_to_indices(
|
||||
relevance_sections: list[SectionRelevancePiece] | None, items: list[TSection]
|
||||
) -> list[int]:
|
||||
if not relevance_sections:
|
||||
return []
|
||||
|
||||
relevant_set = {
|
||||
(chunk.document_id, chunk.chunk_id)
|
||||
for chunk in relevance_sections
|
||||
if chunk.relevant
|
||||
}
|
||||
|
||||
return [
|
||||
index
|
||||
for index, item in enumerate(items)
|
||||
if (
|
||||
(
|
||||
isinstance(item, InferenceSection)
|
||||
and (item.center_chunk.document_id, item.center_chunk.chunk_id)
|
||||
in relevant_set
|
||||
)
|
||||
or (
|
||||
not isinstance(item, (InferenceSection))
|
||||
and (item.document_id, item.chunk_ind) in relevant_set
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
def drop_llm_indices(
|
||||
llm_indices: list[int],
|
||||
search_docs: Sequence[DBSearchDoc | SavedSearchDoc],
|
||||
dropped_indices: list[int],
|
||||
) -> list[int]:
|
||||
llm_bools = [i in llm_indices for i in range(len(search_docs))]
|
||||
if dropped_indices:
|
||||
llm_bools = [
|
||||
val for ind, val in enumerate(llm_bools) if ind not in dropped_indices
|
||||
]
|
||||
return [i for i, val in enumerate(llm_bools) if val]
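
A worked sketch with hypothetical docs d0..d3:

drop_llm_indices(llm_indices=[0, 2], search_docs=[d0, d1, d2, d3], dropped_indices=[1])
# llm_bools starts as [True, False, True, False]; dropping index 1 leaves
# [True, True, False], so the relevant docs sit at [0, 1] in the deduped list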
|
||||
|
||||
def inference_section_from_chunks(
|
||||
center_chunk: InferenceChunk,
|
||||
chunks: list[InferenceChunk],
|
||||
@@ -128,26 +64,6 @@ def inference_section_from_single_chunk(
|
||||
)
|
||||
|
||||
|
||||
def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
|
||||
try:
|
||||
# Re-tokenize using the NLTK tokenizer for better matching
|
||||
query = " ".join(keywords)
|
||||
stop_words = set(stopwords.words("english"))
|
||||
word_tokens = word_tokenize(query)
|
||||
text_trimmed = [
|
||||
word
|
||||
for word in word_tokens
|
||||
if (word.casefold() not in stop_words and word not in string.punctuation)
|
||||
]
|
||||
return text_trimmed or word_tokens
|
||||
except Exception as e:
|
||||
logger.warning(f"Error removing stop words and punctuation: {e}")
|
||||
return keywords
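
An illustrative call (assumes the NLTK stopwords and punkt data have been downloaded):

remove_stop_words_and_punctuation(["what", "is", "the", "deployment", "status"])
# -> ["deployment", "status"]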
|
||||
|
||||
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
|
||||
@@ -91,59 +91,6 @@ def get_chat_sessions_by_slack_thread_id(
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def get_valid_messages_from_query_sessions(
|
||||
chat_session_ids: list[UUID],
|
||||
db_session: Session,
|
||||
) -> dict[UUID, str]:
|
||||
user_message_subquery = (
|
||||
select(
|
||||
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
|
||||
)
|
||||
.where(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.USER,
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
assistant_message_subquery = (
|
||||
select(
|
||||
ChatMessage.chat_session_id,
|
||||
func.min(ChatMessage.id).label("assistant_msg_id"),
|
||||
)
|
||||
.where(
|
||||
ChatMessage.chat_session_id.in_(chat_session_ids),
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(ChatMessage.chat_session_id)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
query = (
|
||||
select(ChatMessage.chat_session_id, ChatMessage.message)
|
||||
.join(
|
||||
user_message_subquery,
|
||||
ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id,
|
||||
)
|
||||
.join(
|
||||
assistant_message_subquery,
|
||||
ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id,
|
||||
)
|
||||
.join(
|
||||
ChatMessage__SearchDoc,
|
||||
ChatMessage__SearchDoc.chat_message_id
|
||||
== assistant_message_subquery.c.assistant_msg_id,
|
||||
)
|
||||
.where(ChatMessage.id == user_message_subquery.c.user_msg_id)
|
||||
)
|
||||
|
||||
first_messages = db_session.execute(query).all()
|
||||
logger.info(f"Retrieved {len(first_messages)} first messages with documents")
|
||||
|
||||
return {row.chat_session_id: row.message for row in first_messages}
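
In rough terms, the ORM query above joins each session's first USER message to its first ASSISTANT message and keeps only sessions whose first assistant message has at least one attached search doc; an approximate rendering (not the exact generated SQL):

# SELECT cm.chat_session_id, cm.message
# FROM chat_message cm
# JOIN first_user_msgs u      ON cm.chat_session_id = u.chat_session_id
# JOIN first_assistant_msgs a ON cm.chat_session_id = a.chat_session_id
# JOIN chat_message__search_doc sd ON sd.chat_message_id = a.assistant_msg_id
# WHERE cm.id = u.user_msg_id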
|
||||
|
||||
# Retrieves chat sessions by user
|
||||
# Chat sessions do not include onyxbot flows
|
||||
def get_chat_sessions_by_user(
|
||||
@@ -510,21 +457,6 @@ def add_chats_to_session_from_slack_thread(
|
||||
)
|
||||
|
||||
|
||||
def get_search_docs_for_chat_message(
|
||||
chat_message_id: int, db_session: Session
|
||||
) -> list[DBSearchDoc]:
|
||||
stmt = (
|
||||
select(DBSearchDoc)
|
||||
.join(
|
||||
ChatMessage__SearchDoc,
|
||||
ChatMessage__SearchDoc.search_doc_id == DBSearchDoc.id,
|
||||
)
|
||||
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
|
||||
)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def add_search_docs_to_chat_message(
|
||||
chat_message_id: int, search_doc_ids: list[int], db_session: Session
|
||||
) -> None:
|
||||
|
||||
@@ -83,7 +83,6 @@ from onyx.utils.special_types import JSON_ro
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
@@ -91,6 +90,8 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -2044,7 +2045,7 @@ class ChatSession(Base):
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# This chat created by OnyxBot
|
||||
@@ -2332,6 +2333,23 @@ class SearchDoc(Base):
|
||||
)
|
||||
|
||||
|
||||
class SearchQuery(Base):
|
||||
# This table stores search queries made from the Search UI. There are no follow-ups, and less is stored than for chat:
# the reply functionality simply reruns the search query, since results may have changed and rerunning is more common for search.
|
||||
__tablename__ = "search_query"
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), ForeignKey("user.id"))
|
||||
query: Mapped[str] = mapped_column(String)
|
||||
query_expansions: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
Feedback, Logging, Metrics Tables
|
||||
"""
|
||||
|
||||
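For orientation, a minimal usage sketch of the new table (session setup assumed; values hypothetical — the id and created_at columns fill themselves in via their defaults):

    search_query = SearchQuery(
        user_id=user.id,
        query="quarterly revenue",
        query_expansions=["q3 revenue", "quarterly earnings"],
    )
    db_session.add(search_query)
    db_session.commit()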
backend/onyx/document_index/chunk_content_enrichment.py (new file, 101 lines)
@@ -0,0 +1,101 @@
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.indexing.models import DocMetadataAwareIndexChunk


def generate_enriched_content_for_chunk(chunk: DocMetadataAwareIndexChunk) -> str:
    return f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"


def cleanup_content_for_chunks(
    chunks: list[InferenceChunkUncleaned],
) -> list[InferenceChunk]:
    """
    Removes indexing-time content additions from chunks. Inverse of
    generate_enriched_content_for_chunk.

    During indexing, chunks are augmented with additional text to improve search
    quality:
    - Title prepended to content (for better keyword/semantic matching)
    - Metadata suffix appended to content
    - Contextual RAG: doc_summary (beginning) and chunk_context (end)

    This function strips these additions before returning chunks to users,
    restoring the original document content. Cleaning is applied in sequence:
    1. Title removal:
       - Full match: Strips exact title from beginning
       - Partial match: If content starts with title[:BLURB_SIZE], splits on
         RETURN_SEPARATOR to remove title section
    2. Metadata suffix removal:
       - Strips metadata_suffix from end, plus trailing RETURN_SEPARATOR
    3. Contextual RAG removal:
       - Strips doc_summary from beginning (if present)
       - Strips chunk_context from end (if present)

    TODO(andrei): This entire function is not that fantastic, clean it up during
    QA before rolling out OpenSearch.

    Args:
        chunks: Chunks as retrieved from the document index with indexing
            augmentations intact.

    Returns:
        Clean InferenceChunk objects with augmentations removed, containing only
        the original document content that should be shown to users.
    """

    def _remove_title(chunk: InferenceChunkUncleaned) -> str:
        # TODO(andrei): This was ported over from
        # backend/onyx/document_index/vespa/vespa_document_index.py but I don't
        # think this logic is correct. In Vespa at least we set the title field
        # from the output of get_title_for_document_index, which is not
        # necessarily the same data that is prepended to the content; that comes
        # from title_prefix.
        # This was added in
        # https://github.com/onyx-dot-app/onyx/commit/e90c66c1b61c5b7da949652d703f7c906863e6e4#diff-2a2a29d5929de75cdaea77867a397934d9f8b785ce40a861c0d704033e3663ab,
        # see postprocessing.py. At that time the content enrichment logic was
        # also added in that commit, see
        # https://github.com/onyx-dot-app/onyx/commit/e90c66c1b61c5b7da949652d703f7c906863e6e4#diff-d807718aa263a15c1d991a4ab063c360c8419eaad210b4ba70e1e9f47d2aa6d2R77
        # chunker.py.
        if not chunk.title or not chunk.content:
            return chunk.content

        if chunk.content.startswith(chunk.title):
            return chunk.content[len(chunk.title) :].lstrip()

        # BLURB_SIZE is in tokens rather than chars, but each token is at least
        # 1 char. If this prefix matches the content, it's assumed the title was
        # prepended.
        if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
            return (
                chunk.content.split(RETURN_SEPARATOR, 1)[-1]
                if RETURN_SEPARATOR in chunk.content
                else chunk.content
            )
        return chunk.content

    def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
        if not chunk.metadata_suffix:
            return chunk.content
        return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
            RETURN_SEPARATOR
        )

    def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
        # Remove the document summary.
        if chunk.doc_summary and chunk.content.startswith(chunk.doc_summary):
            chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
        # Remove the chunk context.
        if chunk.chunk_context and chunk.content.endswith(chunk.chunk_context):
            chunk.content = chunk.content[
                : len(chunk.content) - len(chunk.chunk_context)
            ].rstrip()
        return chunk.content

    for chunk in chunks:
        chunk.content = _remove_title(chunk)
        chunk.content = _remove_metadata_suffix(chunk)
        chunk.content = _remove_contextual_rag(chunk)

    return [chunk.to_inference_chunk() for chunk in chunks]
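To make the inverse relationship concrete, here is a self-contained sketch of the round trip on bare strings (all values hypothetical; the real functions operate on chunk objects rather than loose variables):

    title = "Doc Title"
    doc_summary = "Summary of the doc. "
    body = "Original chunk text."
    chunk_context = " LLM-generated context."
    metadata_suffix = "owner: alice"
    enriched = f"{title}\n{doc_summary}{body}{chunk_context}{metadata_suffix}"

    # Cleanup mirrors the three steps above, in order: title, metadata
    # suffix, then the contextual-RAG pieces.
    content = enriched
    if content.startswith(title):
        content = content[len(title):].lstrip()
    content = content.removesuffix(metadata_suffix).rstrip()
    if content.startswith(doc_summary):
        content = content[len(doc_summary):].lstrip()
    if content.endswith(chunk_context.rstrip()):
        content = content.removesuffix(chunk_context.rstrip()).rstrip()
    assert content == body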
@@ -167,9 +167,9 @@ class IndexRetrievalFilters(BaseModel):

class SchemaVerifiable(abc.ABC):
    """
-    Class must implement document index schema verification. For example, verify that all of the
-    necessary attributes for indexing, querying, filtering, and fields to return from search are
-    all valid in the schema.
+    Class must implement document index schema verification. For example, verify
+    that all of the necessary attributes for indexing, querying, filtering, and
+    fields to return from search are all valid in the schema.
    """

    @abc.abstractmethod
@@ -179,13 +179,18 @@ class SchemaVerifiable(abc.ABC):
        embedding_precision: EmbeddingPrecision,
    ) -> None:
        """
-        Verify that the document index exists and is consistent with the expectations in the code. For certain search
-        engines, the schema needs to be created before indexing can happen. This call should create the schema if it
-        does not exist.
+        Verifies that the document index exists and is consistent with the
+        expectations in the code.

-        Parameters:
-        - embedding_dim: Vector dimensionality for the vector similarity part of the search
-        - embedding_precision: Precision of the vector similarity part of the search
+        For certain search engines, the schema needs to be created before
+        indexing can happen. This call should create the schema if it does not
+        exist.
+
+        Args:
+            embedding_dim: Vector dimensionality for the vector similarity part
+                of the search.
+            embedding_precision: Precision of the values of the vectors for the
+                similarity part of the search.
        """
        raise NotImplementedError

@@ -238,8 +243,8 @@ class Deletable(abc.ABC):
    @abc.abstractmethod
    def delete(
        self,
-        # TODO(andrei): Fine for now but this can probably be a batch operation that
-        # takes in a list of IDs.
+        # TODO(andrei): Fine for now but this can probably be a batch operation
+        # that takes in a list of IDs.
        document_id: str,
        chunk_count: int | None = None,
        # TODO(andrei): Shouldn't this also have some acl filtering at minimum?
@@ -283,10 +288,7 @@ class Updatable(abc.ABC):
        self,
        update_requests: list[MetadataUpdateRequest],
    ) -> None:
-        """
-        Updates some set of chunks. The document and fields to update are specified in the update
-        requests. Each update request in the list applies its changes to a list of document ids.
-        None values mean that the field does not need an update.
+        """Updates some set of chunks.
+
+        The document and fields to update are specified in the update requests.
+        Each update request in the list applies its changes to a list of
@@ -1,8 +1,11 @@
import logging
from typing import Any
+from typing import Generic
+from typing import TypeVar

from opensearchpy import OpenSearch
from opensearchpy.exceptions import TransportError
+from pydantic import BaseModel

from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD
from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME
@@ -17,10 +20,36 @@ from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
# Set the logging level to WARNING to ignore INFO and DEBUG logs from
# opensearch. By default it emits INFO-level logs for every request.
+# TODO(andrei): I don't think this is working as intended, I still see spam in
+# the logs. The module name is probably wrong, or opensearchpy initializes a
+# logger dynamically along with an instance of a client class. Look at the
+# constructor for OpenSearch.
opensearch_logger = logging.getLogger("opensearchpy")
opensearch_logger.setLevel(logging.WARNING)


+SchemaDocumentModel = TypeVar("SchemaDocumentModel")
+
+
+class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
+    """Represents a hit from OpenSearch in response to a query.
+
+    Templated on the specific document model as defined by a schema.
+    """
+
+    model_config = {"frozen": True}
+
+    # The document chunk source retrieved from OpenSearch.
+    document_chunk: SchemaDocumentModel
+    # The match score for the document chunk as calculated by OpenSearch. Only
+    # relevant for "fuzzy searches"; this will be None for direct queries where
+    # score is not relevant, like direct retrieval by ID.
+    score: float | None = None
+    # Maps schema property name to a list of highlighted snippets with match
+    # terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
+    match_highlights: dict[str, list[str]] = {}

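As a quick illustration of the generic container (a sketch with a toy pydantic model, not code from this diff):

    class ToyChunk(BaseModel):
        document_id: str
        content: str

    hit = SearchHit[ToyChunk](
        document_chunk=ToyChunk(document_id="doc-1", content="some text"),
        score=0.87,
        match_highlights={"content": ["some <hi>text</hi>"]},
    )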
class OpenSearchClient:
    """Client for interacting with OpenSearch.

@@ -230,9 +259,9 @@ class OpenSearchClient:
            )
        result_string: str = result.get("result", "")
        match result_string:
+            # Sanity check.
            case "created":
                return
-            # Sanity check.
            case "updated":
                raise RuntimeError(
                    f'The OpenSearch client returned result "updated" for indexing document chunk "{document_chunk_id}". '
@@ -307,9 +336,49 @@ class OpenSearchClient:

        return num_deleted

-    def update_document(self) -> None:
-        # TODO(andrei): Implement this.
-        raise NotImplementedError("Not implemented.")
+    def update_document(
+        self, document_chunk_id: str, properties_to_update: dict[str, Any]
+    ) -> None:
+        """Updates a document's properties.
+
+        Args:
+            document_chunk_id: The OpenSearch ID of the document chunk to
+                update.
+            properties_to_update: The properties of the document to update. Each
+                property should exist in the schema.
+
+        Raises:
+            Exception: There was an error updating the document.
+        """
+        update_body: dict[str, Any] = {"doc": properties_to_update}
+        result = self._client.update(
+            index=self._index_name,
+            id=document_chunk_id,
+            body=update_body,
+            _source=False,
+        )
+        result_id = result.get("_id", "")
+        # Sanity check.
+        if result_id != document_chunk_id:
+            raise RuntimeError(
+                f'Upon trying to update a document, OpenSearch responded with ID "{result_id}" '
+                f'instead of "{document_chunk_id}" which is the ID it was given.'
+            )
+        result_string: str = result.get("result", "")
+        match result_string:
+            # Sanity check.
+            case "updated":
+                return
+            case "noop":
+                logger.warning(
+                    f'OpenSearch reported a no-op when trying to update document with ID "{document_chunk_id}".'
+                )
+                return
+            case _:
+                raise RuntimeError(
+                    f'The OpenSearch client returned result "{result_string}" for updating document chunk "{document_chunk_id}". '
+                    "This is unexpected."
+                )

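A usage sketch of the new method (the chunk ID and field values here are hypothetical; real IDs come from get_opensearch_doc_chunk_id, and the field-name constants live in the schema module):

    # some_chunk_id is illustrative only.
    client.update_document(
        document_chunk_id=some_chunk_id,
        properties_to_update={HIDDEN_FIELD_NAME: True, GLOBAL_BOOST_FIELD_NAME: 2},
    )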
    def get_document(self, document_chunk_id: str) -> DocumentChunk:
        """Gets a document.
@@ -378,12 +447,13 @@ class OpenSearchClient:

    def search(
        self, body: dict[str, Any], search_pipeline_id: str | None
-    ) -> list[DocumentChunk]:
+    ) -> list[SearchHit[DocumentChunk]]:
        """Searches the index.

        TODO(andrei): Ideally we could check that every field in the body is
        present in the index, to avoid a class of runtime bugs that could easily
-        be caught during development.
+        be caught during development. Or change the function signature to accept
+        a predefined pydantic model of allowed fields.

        Args:
            body: The body of the search request. See the OpenSearch
@@ -395,7 +465,7 @@ class OpenSearchClient:
            Exception: There was an error searching the index.

        Returns:
-            List of document chunks that match the search request.
+            List of search hits that match the search request.
        """
        result: dict[str, Any]
        if search_pipeline_id:
@@ -407,15 +477,22 @@ class OpenSearchClient:

        hits = self._get_hits_from_search_result(result)

-        result_chunks: list[DocumentChunk] = []
+        search_hits: list[SearchHit[DocumentChunk]] = []
        for hit in hits:
            document_chunk_source: dict[str, Any] | None = hit.get("_source")
            if not document_chunk_source:
                raise RuntimeError(
                    f"Document chunk with ID \"{hit.get('_id', '')}\" has no data."
                )
-            result_chunks.append(DocumentChunk.model_validate(document_chunk_source))
-        return result_chunks
+            document_chunk_score = hit.get("_score", None)
+            match_highlights: dict[str, list[str]] = hit.get("highlight", {})
+            search_hit = SearchHit[DocumentChunk](
+                document_chunk=DocumentChunk.model_validate(document_chunk_source),
+                score=document_chunk_score,
+                match_highlights=match_highlights,
+            )
+            search_hits.append(search_hit)
+        return search_hits

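For reference, the body argument follows the standard OpenSearch query DSL; a minimal sketch of a keyword query against the content field (not code from this diff):

    hits = client.search(
        body={"query": {"match": {"content": "keyword"}}, "size": 5},
        search_pipeline_id=None,
    )
    for hit in hits:
        print(hit.score, hit.document_chunk.document_id)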
    def search_for_document_ids(self, body: dict[str, Any]) -> list[str]:
        """Searches the index and returns only document chunk IDs.

@@ -1,4 +1,5 @@
import json
+from typing import Any

import httpx

@@ -6,6 +7,7 @@ from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    get_experts_stores_representations,
)
+from onyx.connectors.models import convert_metadata_list_of_strings_to_dict
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
@@ -13,6 +15,10 @@ from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import QueryExpansionType
from onyx.db.enums import EmbeddingPrecision
from onyx.db.models import DocumentSource
+from onyx.document_index.chunk_content_enrichment import cleanup_content_for_chunks
+from onyx.document_index.chunk_content_enrichment import (
+    generate_enriched_content_for_chunk,
+)
from onyx.document_index.interfaces import DocumentIndex as OldDocumentIndex
from onyx.document_index.interfaces import (
    DocumentInsertionRecord as OldDocumentInsertionRecord,
@@ -29,8 +35,16 @@ from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
+from onyx.document_index.opensearch.client import SearchHit
+from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
+from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
+from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
from onyx.document_index.opensearch.schema import DocumentChunk
from onyx.document_index.opensearch.schema import DocumentSchema
+from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
+from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
+from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
+from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
    MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
@@ -54,14 +68,40 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger(__name__)


-def _convert_opensearch_chunk_to_inference_chunk_uncleaned(
+def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
    chunk: DocumentChunk,
+    score: float | None,
+    highlights: dict[str, list[str]],
) -> InferenceChunkUncleaned:
+    """
+    Generates an inference chunk from an OpenSearch document chunk, its score,
+    and its match highlights.
+
+    Args:
+        chunk: The document chunk returned by OpenSearch.
+        score: The document chunk match score as calculated by OpenSearch. Only
+            relevant for searches like hybrid search. It is acceptable for this
+            value to be None for results from other queries like ID-based
+            retrieval, as a match score makes no sense in those contexts.
+        highlights: Maps schema property name to a list of highlighted snippets
+            with match terms wrapped in tags (e.g. "something <hi>keyword</hi>
+            other thing").
+
+    Returns:
+        An Onyx inference chunk representation.
+    """
    return InferenceChunkUncleaned(
        chunk_id=chunk.chunk_index,
        blurb=chunk.blurb,
+        # Includes extra content prepended/appended during indexing.
        content=chunk.content,
-        source_links=json.loads(chunk.source_links) if chunk.source_links else None,
+        # When we read a string and turn it into a dict the keys will be
+        # strings, but in this case they need to be ints.
+        source_links=(
+            {int(k): v for k, v in json.loads(chunk.source_links).items()}
+            if chunk.source_links
+            else None
+        ),
        image_file_id=chunk.image_file_id,
        # Deprecated. Fill in some reasonable default.
        section_continuation=False,
@@ -70,42 +110,30 @@ def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
        semantic_identifier=chunk.semantic_identifier,
        title=chunk.title,
        boost=chunk.global_boost,
        # TODO(andrei): Do in a followup. We should be able to get this from
        # OpenSearch.
        recency_bias=1.0,
-        # TODO(andrei): This is how good the match is, we need this, key insight
-        # is we can order chunks by this. Should not be hard to plumb this from
-        # a search result, do that in a followup.
-        score=None,
+        score=score,
        hidden=chunk.hidden,
-        metadata=json.loads(chunk.metadata),
-        # TODO(andrei): The vector DB needs to supply this. I vaguely know
-        # OpenSearch can from the documentation I've seen till now, look at this
-        # in a followup.
-        match_highlights=[],
+        metadata=(
+            convert_metadata_list_of_strings_to_dict(chunk.metadata_list)
+            if chunk.metadata_list
+            else {}
+        ),
+        # Extract highlighted snippets from the content field, if available. In
+        # the future we may want to match on other fields too; currently we only
+        # use the content field.
+        match_highlights=highlights.get(CONTENT_FIELD_NAME, []),
        # TODO(andrei) Consider storing a chunk content index instead of a full
        # string when working on chunk content augmentation.
        doc_summary=chunk.doc_summary,
-        # TODO(andrei) Same thing as the contextual retrieval above; the LLM
-        # generates context for each chunk.
+        # TODO(andrei) Same thing as above.
        chunk_context=chunk.chunk_context,
        updated_at=chunk.last_updated,
        primary_owners=chunk.primary_owners,
        secondary_owners=chunk.secondary_owners,
-        # TODO(andrei): This is the suffix appended to the end of the chunk
-        # content to assist querying. There are better ways we can do this, for
-        # ex. keeping an index of where to string split from.
-        metadata_suffix=None,
+        # TODO(andrei) Same thing as chunk_context above.
+        metadata_suffix=chunk.metadata_suffix,
    )


-def _convert_inference_chunk_uncleaned_to_inference_chunk(
-    inference_chunk_uncleaned: InferenceChunkUncleaned,
-) -> InferenceChunk:
-    # TODO(andrei): Implement this.
-    return inference_chunk_uncleaned.to_inference_chunk()


def _convert_onyx_chunk_to_opensearch_document(
    chunk: DocMetadataAwareIndexChunk,
) -> DocumentChunk:
@@ -114,22 +142,35 @@ def _convert_onyx_chunk_to_opensearch_document(
        chunk_index=chunk.chunk_id,
        title=chunk.source_document.title,
        title_vector=chunk.title_embedding,
-        content=chunk.content,
+        content=generate_enriched_content_for_chunk(chunk),
        content_vector=chunk.embeddings.full_embedding,
        source_type=chunk.source_document.source.value,
-        metadata=json.dumps(chunk.source_document.metadata),
+        metadata_list=chunk.source_document.get_metadata_str_attributes(),
+        metadata_suffix=chunk.metadata_suffix_keyword,
        last_updated=chunk.source_document.doc_updated_at,
        public=chunk.access.is_public,
        # TODO(andrei): When going over ACL look very carefully at
        # access_control_list. Notice DocumentAccess::to_acl prepends every
        # string with a type.
        access_control_list=list(chunk.access.to_acl()),
        global_boost=chunk.boost,
        semantic_identifier=chunk.source_document.semantic_identifier,
        image_file_id=chunk.image_file_id,
        # Small optimization: if this list is empty we can supply None to
        # OpenSearch and it will not store any data at all for this field, which
        # is different from supplying an empty list.
        source_links=json.dumps(chunk.source_links) if chunk.source_links else None,
        blurb=chunk.blurb,
        doc_summary=chunk.doc_summary,
        chunk_context=chunk.chunk_context,
        # Small optimization: if this list is empty we can supply None to
        # OpenSearch and it will not store any data at all for this field, which
        # is different from supplying an empty list.
        document_sets=list(chunk.document_sets) if chunk.document_sets else None,
-        project_ids=list(chunk.user_project) if chunk.user_project else None,
+        # Small optimization: if this list is empty we can supply None to
+        # OpenSearch and it will not store any data at all for this field, which
+        # is different from supplying an empty list.
+        user_projects=chunk.user_project or None,
        primary_owners=get_experts_stores_representations(
            chunk.source_document.primary_owners
        ),
@@ -144,23 +185,6 @@ def _convert_onyx_chunk_to_opensearch_document(
    )


-def _enrich_chunk_info() -> None:  # pyright: ignore[reportUnusedFunction]
-    # TODO(andrei): Implement this. Until then, we do not enrich chunk content
-    # with title, etc.
-    raise NotImplementedError(
-        "[ANDREI]: Enrich chunk info is not implemented for OpenSearch."
-    )
-
-
-def _clean_chunk_info() -> None:  # pyright: ignore[reportUnusedFunction]
-    # Analogous to _cleanup_chunks in vespa_document_index.py.
-    # TODO(andrei): Implement this. Until then, we do not enrich chunk content
-    # with title, etc.
-    raise NotImplementedError(
-        "[ANDREI]: Clean chunk info is not implemented for OpenSearch."
-    )


class OpenSearchOldDocumentIndex(OldDocumentIndex):
    """
    Wrapper for OpenSearch to adapt the new DocumentIndex interface with
@@ -186,6 +210,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
            index_name=index_name,
            secondary_index_name=secondary_index_name,
        )
+        if multitenant:
+            raise ValueError(
+                "Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
+            )
        self._real_index = OpenSearchDocumentIndex(
            index_name=index_name,
            # TODO(andrei): Sus. Do not plug this into production until all
@@ -452,7 +480,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
            opensearch_document_chunk = _convert_onyx_chunk_to_opensearch_document(
                chunk
            )
-            # TODO(andrei): Enrich chunk content here.
            # TODO(andrei): After our client supports batch indexing, use that
            # here.
            self._os_client.index_document(opensearch_document_chunk)
@@ -494,15 +521,80 @@ class OpenSearchDocumentIndex(DocumentIndex):
        self,
        update_requests: list[MetadataUpdateRequest],
    ) -> None:
-        logger.info("[ANDREI]: Updating documents...")
-        # TODO(andrei): This needs to be implemented. I explicitly do not raise
-        # here despite this not being implemented because indexing calls this
-        # method, so it is very hard to test other methods of this class if this
-        # raises.
+        """Updates some set of chunks.
+
+        NOTE: Requires the document chunk count to be known; will raise if it is not.
+        NOTE: Each update request must have some field to update; if not, it is
+        assumed there is a bug in the caller and this will raise.
+
+        TODO(andrei): Consider exploring a batch API for OpenSearch for this
+        operation.
+
+        Args:
+            update_requests: A list of update requests, each containing a list
+                of document IDs and the fields to update. The field updates
+                apply to all of the specified documents in each update request.
+
+        Raises:
+            RuntimeError: Failed to update some or all of the chunks for the
+                specified documents.
+        """
+        for update_request in update_requests:
+            properties_to_update: dict[str, Any] = dict()
+            # TODO(andrei): Nit, but consider if we can use DocumentChunk
+            # here so we don't have to think about passing in the
+            # appropriate types into this dict.
+            if update_request.access is not None:
+                properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
+                    update_request.access.to_acl()
+                )
+            if update_request.document_sets is not None:
+                properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
+                    update_request.document_sets
+                )
+            if update_request.boost is not None:
+                properties_to_update[GLOBAL_BOOST_FIELD_NAME] = int(
+                    update_request.boost
+                )
+            if update_request.hidden is not None:
+                properties_to_update[HIDDEN_FIELD_NAME] = update_request.hidden
+            if update_request.project_ids is not None:
+                properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
+                    update_request.project_ids
+                )
+
+            for doc_id in update_request.document_ids:
+                if not properties_to_update:
+                    raise ValueError(
+                        f"Bug: Tried to update document {doc_id} with no updated fields or user fields."
+                    )
+
+                doc_chunk_count = update_request.doc_id_to_chunk_cnt.get(doc_id, -1)
+                if doc_chunk_count < 0:
+                    raise ValueError(
+                        f"Tried to update document {doc_id} but its chunk count is not known. Older versions of the "
+                        "application used to permit this, but it is not a supported state for a document when using OpenSearch."
+                    )
+                if doc_chunk_count == 0:
+                    raise ValueError(
+                        f"Bug: Tried to update document {doc_id} but its chunk count was 0."
+                    )
+
+                for chunk_index in range(doc_chunk_count):
+                    document_chunk_id = get_opensearch_doc_chunk_id(
+                        document_id=doc_id, chunk_index=chunk_index
+                    )
+                    self._os_client.update_document(
+                        document_chunk_id=document_chunk_id,
+                        properties_to_update=properties_to_update,
+                    )

    def id_based_retrieval(
        self,
        chunk_requests: list[DocumentSectionRequest],
        # TODO(andrei): When going over ACL look very carefully at
        # access_control_list. Notice DocumentAccess::to_acl prepends every
        # string with a type.
        filters: IndexFilters,
        # TODO(andrei): Remove this from the new interface at some point; we
        # should not be exposing this.
@@ -514,7 +606,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
        """
        results: list[InferenceChunk] = []
        for chunk_request in chunk_requests:
-            document_chunks: list[DocumentChunk] = []
+            search_hits: list[SearchHit[DocumentChunk]] = []
            query_body = DocumentQuery.get_from_document_id_query(
                document_id=chunk_request.document_id,
                tenant_state=self._tenant_state,
@@ -522,22 +614,20 @@ class OpenSearchDocumentIndex(DocumentIndex):
                min_chunk_index=chunk_request.min_chunk_ind,
                max_chunk_index=chunk_request.max_chunk_ind,
            )
-            document_chunks = self._os_client.search(
+            search_hits = self._os_client.search(
                body=query_body,
                search_pipeline_id=None,
            )
-            inference_chunks_uncleaned = [
-                _convert_opensearch_chunk_to_inference_chunk_uncleaned(document_chunk)
-                for document_chunk in document_chunks
-            ]
-            inference_chunks = [
-                _convert_inference_chunk_uncleaned_to_inference_chunk(
-                    inference_chunk_uncleaned
-                )
-                for inference_chunk_uncleaned in inference_chunks_uncleaned
-            ]
+            inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
+                _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
+                    search_hit.document_chunk, None, {}
+                )
+                for search_hit in search_hits
+            ]
+            inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
+                inference_chunks_uncleaned
+            )
            results.extend(inference_chunks)
-        # TODO(andrei): Clean chunk content here.
        return results

    def hybrid_retrieval(
@@ -546,6 +636,9 @@ class OpenSearchDocumentIndex(DocumentIndex):
        query_embedding: Embedding,
        final_keywords: list[str] | None,
        query_type: QueryType,
+        # TODO(andrei): When going over ACL look very carefully at
+        # access_control_list. Notice DocumentAccess::to_acl prepends every
+        # string with a type.
        filters: IndexFilters,
        num_to_retrieve: int,
        offset: int = 0,
@@ -557,25 +650,27 @@ class OpenSearchDocumentIndex(DocumentIndex):
            num_hits=num_to_retrieve,
            tenant_state=self._tenant_state,
        )
-        document_chunks = self._os_client.search(
+        search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
            body=query_body,
            search_pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
        )
-        # TODO(andrei): Clean chunk content here.
-        inference_chunks_uncleaned = [
-            _convert_opensearch_chunk_to_inference_chunk_uncleaned(document_chunk)
-            for document_chunk in document_chunks
-        ]
-        inference_chunks = [
-            _convert_inference_chunk_uncleaned_to_inference_chunk(
-                inference_chunk_uncleaned
-            )
-            for inference_chunk_uncleaned in inference_chunks_uncleaned
-        ]
+        inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
+            _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
+                search_hit.document_chunk, search_hit.score, search_hit.match_highlights
+            )
+            for search_hit in search_hits
+        ]
+        inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
+            inference_chunks_uncleaned
+        )

        return inference_chunks

    def random_retrieval(
        self,
        # TODO(andrei): When going over ACL look very carefully at
        # access_control_list. Notice DocumentAccess::to_acl prepends every
        # string with a type.
        filters: IndexFilters,
        num_to_retrieve: int = 100,
        dirty: bool | None = None,
@@ -25,7 +25,7 @@ TITLE_VECTOR_FIELD_NAME = "title_vector"
CONTENT_FIELD_NAME = "content"
CONTENT_VECTOR_FIELD_NAME = "content_vector"
SOURCE_TYPE_FIELD_NAME = "source_type"
-METADATA_FIELD_NAME = "metadata"
+METADATA_LIST_FIELD_NAME = "metadata_list"
LAST_UPDATED_FIELD_NAME = "last_updated"
PUBLIC_FIELD_NAME = "public"
ACCESS_CONTROL_LIST_FIELD_NAME = "access_control_list"
@@ -35,7 +35,7 @@ SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier"
IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
-PROJECT_IDS_FIELD_NAME = "project_ids"
+USER_PROJECTS_FIELD_NAME = "user_projects"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
@@ -43,6 +43,7 @@ TENANT_ID_FIELD_NAME = "tenant_id"
BLURB_FIELD_NAME = "blurb"
DOC_SUMMARY_FIELD_NAME = "doc_summary"
CHUNK_CONTEXT_FIELD_NAME = "chunk_context"
+METADATA_SUFFIX_FIELD_NAME = "metadata_suffix"
PRIMARY_OWNERS_FIELD_NAME = "primary_owners"
SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
@@ -101,12 +102,9 @@ class DocumentChunk(BaseModel):
    content_vector: list[float]

    source_type: str
-    # Contains a string representation of a dict which maps string key to either
-    # string value or list of string values.
-    # TODO(andrei): When we augment content with metadata this can just be an
-    # index pointer, and when we support metadata list that will just be a list
-    # of strings.
-    metadata: str
+    # A list of key-value pairs separated by INDEX_SEPARATOR. See
+    # convert_metadata_dict_to_list_of_strings.
+    metadata_list: list[str] | None = None
    # If it exists, the time zone should always be UTC.
    last_updated: datetime | None = None

@@ -123,12 +121,16 @@ class DocumentChunk(BaseModel):
    # chunk text to the link corresponding to that point.
    source_links: str | None = None
    blurb: str
+    # doc_summary, chunk_context, and metadata_suffix are all stored simply to
+    # reverse the augmentations to content. Ideally these would just be start
+    # and stop indices into the content string. For legacy reasons they are not
+    # right now.
    doc_summary: str
    chunk_context: str
+    metadata_suffix: str | None = None

    document_sets: list[str] | None = None
    # User projects.
-    project_ids: list[int] | None = None
+    user_projects: list[int] | None = None
    primary_owners: list[str] | None = None
    secondary_owners: list[str] | None = None

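For context on the source_links round trip (hypothetical value): the field is persisted as a JSON string whose keys are character offsets into the chunk, and since JSON object keys are always strings, readers must convert them back to ints, as the new converter in opensearch_document_index.py does:

    import json

    stored = json.dumps({0: "https://example.com/doc#s1"})  # keys become strings
    parsed = {int(k): v for k, v in json.loads(stored).items()}
    assert parsed == {0: "https://example.com/doc#s1"}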
@@ -283,6 +285,12 @@ class DocumentSchema:
        full-text searches.
    - "store": True fields are stored and can be returned on their own,
        independent of the parent document.
+    - "index": True fields can be queried on.
+    - "doc_values": True fields can be sorted and aggregated efficiently.
+        Not supported for "text" type fields.
+    - "store": True fields are stored separately from the source document
+        and can thus be returned from a query separately from _source.
+        Generally this is not necessary.

    Args:
        vector_dimension: The dimension of vector embeddings. Must be a
@@ -309,10 +317,18 @@ class DocumentSchema:
                    # TODO(andrei): Ask Yuhong do we want this?
                    "keyword": {"type": "keyword", "ignore_above": 256}
                },
+                # This makes highlighting text during queries more efficient
+                # at the cost of disk space. See
+                # https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
+                "index_options": "offsets",
            },
            CONTENT_FIELD_NAME: {
                "type": "text",
                "store": True,
+                # This makes highlighting text during queries more efficient
+                # at the cost of disk space. See
+                # https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
+                "index_options": "offsets",
            },
            TITLE_VECTOR_FIELD_NAME: {
                "type": "knn_vector",
@@ -337,7 +353,7 @@ class DocumentSchema:
                },
            },
            SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
-            METADATA_FIELD_NAME: {"type": "keyword"},
+            METADATA_LIST_FIELD_NAME: {"type": "keyword"},
            # TODO(andrei): Check if Vespa stores seconds, we may want to do
            # seconds here, not millis.
            LAST_UPDATED_FIELD_NAME: {
@@ -362,11 +378,13 @@ class DocumentSchema:
            GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
            # This field is only used for displaying a useful name for the
            # doc in the UI and is not used for searching. Disabling these
-            # features to increase perf.
+            # features to increase perf. This field is therefore essentially
+            # just metadata.
            SEMANTIC_IDENTIFIER_FIELD_NAME: {
                "type": "keyword",
                "index": False,
                "doc_values": False,
+                # Generally False by default; just making sure.
                "store": False,
            },
            # Same as above; used to display an image along with the doc.
@@ -374,6 +392,7 @@ class DocumentSchema:
                "type": "keyword",
                "index": False,
                "doc_values": False,
+                # Generally False by default; just making sure.
                "store": False,
            },
            # Same as above; used to link to the source doc.
@@ -381,6 +400,7 @@ class DocumentSchema:
                "type": "keyword",
                "index": False,
                "doc_values": False,
+                # Generally False by default; just making sure.
                "store": False,
            },
            # Same as above; used to quickly summarize the doc in the UI.
@@ -388,6 +408,7 @@ class DocumentSchema:
                "type": "keyword",
                "index": False,
                "doc_values": False,
+                # Generally False by default; just making sure.
                "store": False,
            },
            # Same as above.
@@ -397,12 +418,21 @@ class DocumentSchema:
                "type": "keyword",
                "index": False,
                "doc_values": False,
+                # Generally False by default; just making sure.
                "store": False,
            },
            # Same as above.
+            # TODO(andrei): If we want to search on this, this needs to be
+            # changed.
            CHUNK_CONTEXT_FIELD_NAME: {
                "type": "keyword",
                "index": False,
                "doc_values": False,
+                # Generally False by default; just making sure.
                "store": False,
            },
+            # Same as above.
+            METADATA_SUFFIX_FIELD_NAME: {
+                "type": "keyword",
+                "index": False,
+                "doc_values": False,
@@ -410,7 +440,7 @@ class DocumentSchema:
            },
            # Product-specific fields.
            DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
-            PROJECT_IDS_FIELD_NAME: {"type": "integer"},
+            USER_PROJECTS_FIELD_NAME: {"type": "integer"},
            PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
            SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
            # OpenSearch metadata fields.

@@ -244,6 +244,9 @@ class DocumentQuery:
            query_text, query_vector, num_candidates
        )
        hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
+        match_highlights_configuration = (
+            DocumentQuery._get_match_highlights_configuration()
+        )

        hybrid_search_query: dict[str, Any] = {
            "bool": {
@@ -254,6 +257,8 @@ class DocumentQuery:
                    }
                }
            ],
+                # TODO(andrei): When revisiting our hybrid query logic see if
+                # this needs to be nested one level down.
                "filter": hybrid_search_filters,
            }
        }
@@ -261,6 +266,7 @@ class DocumentQuery:
        final_hybrid_search_body: dict[str, Any] = {
            "query": hybrid_search_query,
            "size": num_hits,
+            "highlight": match_highlights_configuration,
        }
        return final_hybrid_search_body

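Putting the pieces together, the returned request body has roughly the following shape (a sketch only; the list-valued match clauses and filter terms are elided, and the exact key for the clause list is an assumption, since it is not shown in this hunk):

    final_body = {
        "query": {"bool": {"should": [...], "filter": [...]}},
        "size": 25,
        "highlight": {"fields": {"content": {...}}},
    }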
@@ -346,3 +352,30 @@ class DocumentQuery:
            {"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
        )
        return hybrid_search_filters
+
+    @staticmethod
+    def _get_match_highlights_configuration() -> dict[str, Any]:
+        """
+        Gets configuration for returning match highlights for a hit.
+        """
+        match_highlights_configuration: dict[str, Any] = {
+            "fields": {
+                CONTENT_FIELD_NAME: {
+                    # See https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#highlighter-types
+                    "type": "unified",
+                    # The length in chars of a match snippet. Somewhat
+                    # arbitrarily chosen. The Vespa codepath limited total
+                    # highlights length to 400 chars. fragment_size *
+                    # number_of_fragments = 400 should be good enough.
+                    "fragment_size": 100,
+                    # The number of snippets to return per field per document
+                    # hit.
+                    "number_of_fragments": 4,
+                    # These tags wrap matched keywords and they match what Vespa
+                    # used to return. Use them to minimize changes to our code.
+                    "pre_tags": ["<hi>"],
+                    "post_tags": ["</hi>"],
+                }
+            }
+        }
+        return match_highlights_configuration

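Given this configuration, each hit in an OpenSearch response carries its snippets under a "highlight" key, roughly like the following (a hedged sketch of the shape, not captured output):

    hit = {
        "_source": {...},
        "_score": 1.0,
        # Field name -> up to number_of_fragments snippets of about
        # fragment_size chars each.
        "highlight": {"content": ["something <hi>keyword</hi> other thing"]},
    }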
@@ -41,7 +41,6 @@ from onyx.document_index.vespa_constants import MAX_OR_CONDITIONS
from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
-from onyx.document_index.vespa_constants import RECENCY_BIAS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
@@ -142,7 +141,6 @@ def _vespa_hit_to_inference_chunk(
        title=fields.get(TITLE),
        semantic_identifier=fields[SEMANTIC_IDENTIFIER],
        boost=fields.get(BOOST, 1),
-        recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
        score=None if null_score else hit.get("relevance", 0),
        hidden=fields.get(HIDDEN, False),
        primary_owners=fields.get(PRIMARY_OWNERS),

@@ -71,6 +71,7 @@ from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
+from shared_configs.contextvars import get_current_tenant_id
from shared_configs.model_server_models import Embedding

logger = setup_logger()

@@ -479,12 +480,24 @@ class VespaIndex(DocumentIndex):
        indexing_metadata = IndexingMetadata(
            doc_id_to_chunk_cnt_diff=doc_id_to_chunk_cnt_diff,
        )
+        tenant_state = TenantState(
+            tenant_id=get_current_tenant_id(),
+            multitenant=MULTI_TENANT,
+        )
+        if tenant_state.multitenant != self.multitenant:
+            raise ValueError(
+                f"Bug: Multitenant mismatch. Expected {tenant_state.multitenant}, got {self.multitenant}."
+            )
+        if (
+            tenant_state.multitenant
+            and tenant_state.tenant_id != index_batch_params.tenant_id
+        ):
+            raise ValueError(
+                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {index_batch_params.tenant_id}."
+            )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
-            tenant_state=TenantState(
-                tenant_id=index_batch_params.tenant_id,
-                multitenant=self.multitenant,
-            ),
+            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
@@ -655,18 +668,23 @@ class VespaIndex(DocumentIndex):
            raise ValueError(
                f"Bug: Tried to update document {doc_id} with no updated fields or user fields."
            )
        # TODO(andrei): Very temporary, reinstate this soon.
        # if fields is not None and fields.document_id is not None:
        #     raise ValueError(
        #         "The new vector db interface does not support updating the document ID field."
        #     )

+        tenant_state = TenantState(
+            tenant_id=get_current_tenant_id(),
+            multitenant=MULTI_TENANT,
+        )
+        if tenant_state.multitenant != self.multitenant:
+            raise ValueError(
+                f"Bug: Multitenant mismatch. Expected {tenant_state.multitenant}, got {self.multitenant}."
+            )
+        if tenant_state.multitenant and tenant_state.tenant_id != tenant_id:
+            raise ValueError(
+                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
+            )

        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
-            tenant_state=TenantState(
-                tenant_id=tenant_id,
-                multitenant=self.multitenant,
-            ),
+            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
@@ -695,12 +713,21 @@ class VespaIndex(DocumentIndex):
        tenant_id: str,
        chunk_count: int | None,
    ) -> int:
+        tenant_state = TenantState(
+            tenant_id=get_current_tenant_id(),
+            multitenant=MULTI_TENANT,
+        )
+        if tenant_state.multitenant != self.multitenant:
+            raise ValueError(
+                f"Bug: Multitenant mismatch. Expected {tenant_state.multitenant}, got {self.multitenant}."
+            )
+        if tenant_state.multitenant and tenant_state.tenant_id != tenant_id:
+            raise ValueError(
+                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
+            )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
-            tenant_state=TenantState(
-                tenant_id=tenant_id,
-                multitenant=self.multitenant,
-            ),
+            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
@@ -713,13 +740,13 @@ class VespaIndex(DocumentIndex):
        batch_retrieval: bool = False,
        get_large_chunks: bool = False,
    ) -> list[InferenceChunk]:
-        tenant_id = filters.tenant_id if filters.tenant_id is not None else ""
+        tenant_state = TenantState(
+            tenant_id=get_current_tenant_id(),
+            multitenant=MULTI_TENANT,
+        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
-            tenant_state=TenantState(
-                tenant_id=tenant_id,
-                multitenant=self.multitenant,
-            ),
+            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
@@ -752,13 +779,13 @@ class VespaIndex(DocumentIndex):
        offset: int = 0,
        title_content_ratio: float | None = TITLE_CONTENT_RATIO,
    ) -> list[InferenceChunk]:
-        tenant_id = filters.tenant_id if filters.tenant_id is not None else ""
+        tenant_state = TenantState(
+            tenant_id=get_current_tenant_id(),
+            multitenant=MULTI_TENANT,
+        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
-            tenant_state=TenantState(
-                tenant_id=tenant_id,
-                multitenant=self.multitenant,
-            ),
+            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
@@ -1025,13 +1052,13 @@ class VespaIndex(DocumentIndex):
        This method is currently used for random chunk retrieval in the context of
        assistant starter message creation (passed as sample context for usage by the assistant).
        """
-        tenant_id = filters.tenant_id if filters.tenant_id is not None else ""
+        tenant_state = TenantState(
+            tenant_id=get_current_tenant_id(),
+            multitenant=MULTI_TENANT,
+        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
-            tenant_state=TenantState(
-                tenant_id=tenant_id,
-                multitenant=self.multitenant,
-            ),
+            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )

@@ -16,6 +16,9 @@ from retry import retry
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    get_experts_stores_representations,
)
+from onyx.document_index.chunk_content_enrichment import (
+    generate_enriched_content_for_chunk,
+)
from onyx.document_index.document_index_utils import get_uuid_from_chunk
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
@@ -183,7 +186,7 @@ def _index_vespa_chunk(
        # For the BM25 index, the keyword suffix is used, the vector is already generated with the more
        # natural language representation of the metadata section
        CONTENT: remove_invalid_unicode_chars(
-            f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
+            generate_enriched_content_for_chunk(chunk)
        ),
        # This duplication of `content` is needed for keyword highlighting
        # Note that it's not exactly the same as the actual content

@@ -7,18 +7,16 @@ import httpx
from pydantic import BaseModel
from retry import retry

-from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
-from onyx.configs.chat_configs import HYBRID_ALPHA
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
-from onyx.configs.constants import RETURN_SEPARATOR
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
+from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA
from onyx.db.enums import EmbeddingPrecision
+from onyx.document_index.chunk_content_enrichment import cleanup_content_for_chunks
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
@@ -130,79 +128,6 @@ def _enrich_basic_chunk_info(
    return enriched_doc_info


-def _cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
-    """Removes indexing-time content additions from chunks retrieved from Vespa.
-
-    During indexing, chunks are augmented with additional text to improve search
-    quality:
-    - Title prepended to content (for better keyword/semantic matching)
-    - Metadata suffix appended to content
-    - Contextual RAG: doc_summary (beginning) and chunk_context (end)
-
-    This function strips these additions before returning chunks to users,
-    restoring the original document content. Cleaning is applied in sequence:
-    1. Title removal:
-       - Full match: Strips exact title from beginning
-       - Partial match: If content starts with title[:BLURB_SIZE], splits on
-         RETURN_SEPARATOR to remove title section
-    2. Metadata suffix removal:
-       - Strips metadata_suffix from end, plus trailing RETURN_SEPARATOR
-    3. Contextual RAG removal:
-       - Strips doc_summary from beginning (if present)
-       - Strips chunk_context from end (if present)
-
-    Args:
-        chunks: Chunks as retrieved from Vespa with indexing augmentations
-            intact.
-
-    Returns:
-        Clean InferenceChunk objects with augmentations removed, containing only
-        the original document content that should be shown to users.
-    """
-
-    def _remove_title(chunk: InferenceChunkUncleaned) -> str:
-        if not chunk.title or not chunk.content:
-            return chunk.content
-
-        if chunk.content.startswith(chunk.title):
-            return chunk.content[len(chunk.title) :].lstrip()
-
-        # BLURB_SIZE is in tokens rather than chars, but each token is at least 1 char.
-        # If this prefix matches the content, it's assumed the title was prepended.
-        if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
-            return (
-                chunk.content.split(RETURN_SEPARATOR, 1)[-1]
-                if RETURN_SEPARATOR in chunk.content
-                else chunk.content
-            )
-        return chunk.content
-
-    def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
-        if not chunk.metadata_suffix:
-            return chunk.content
-        return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
-            RETURN_SEPARATOR
-        )
-
-    def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
-        # Remove the document summary.
-        if chunk.doc_summary and chunk.content.startswith(chunk.doc_summary):
-            chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
-        # Remove the chunk context.
-        if chunk.chunk_context and chunk.content.endswith(chunk.chunk_context):
-            chunk.content = chunk.content[
-                : len(chunk.content) - len(chunk.chunk_context)
-            ].rstrip()
-        return chunk.content
-
-    for chunk in chunks:
-        chunk.content = _remove_title(chunk)
-        chunk.content = _remove_metadata_suffix(chunk)
-        chunk.content = _remove_contextual_rag(chunk)
-
-    return [chunk.to_inference_chunk() for chunk in chunks]


@retry(
    tries=3,
    delay=1,
@@ -590,7 +515,7 @@ class VespaDocumentIndex(DocumentIndex):
        ]

        if batch_retrieval:
-            return _cleanup_chunks(
+            return cleanup_content_for_chunks(
                batch_search_api_retrieval(
                    index_name=self._index_name,
                    chunk_requests=sanitized_chunk_requests,
@@ -600,7 +525,7 @@ class VespaDocumentIndex(DocumentIndex):
                    get_large_chunks=False,
                )
            )
-        return _cleanup_chunks(
+        return cleanup_content_for_chunks(
            parallel_visit_api_retrieval(
                index_name=self._index_name,
                chunk_requests=sanitized_chunk_requests,
@@ -670,7 +595,7 @@ class VespaDocumentIndex(DocumentIndex):
            "timeout": VESPA_TIMEOUT,
        }

-        return _cleanup_chunks(query_vespa(params))
+        return cleanup_content_for_chunks(query_vespa(params))

    def random_retrieval(
        self,
@@ -692,4 +617,4 @@ class VespaDocumentIndex(DocumentIndex):
            "ranking.properties.random.seed": random_seed,
        }

-        return _cleanup_chunks(query_vespa(params))
+        return cleanup_content_for_chunks(query_vespa(params))

@@ -11,14 +11,13 @@ from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import SessionTransaction

-from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import StreamingError
from onyx.chat.process_message import AnswerStream
+from onyx.chat.process_message import handle_stream_message_objects
from onyx.chat.process_message import remove_answer_citations
-from onyx.chat.process_message import stream_chat_message_objects
+from onyx.configs.constants import DEFAULT_PERSONA_ID
-from onyx.context.search.models import RetrievalDetails
from onyx.db.chat import create_chat_session
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
from onyx.db.users import get_user_by_email
@@ -33,7 +32,10 @@ from onyx.evals.models import ToolAssertion
from onyx.evals.provider import get_provider
from onyx.llm.override_models import LLMOverride
+from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
+from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import RetrievalDetails
+from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -399,22 +401,19 @@ def _get_answer_with_tools(
        else None
    )

-    request = prepare_chat_message_request(
-        message_text=eval_input["message"],
-        user=user,
-        persona_id=None,
-        persona_override_config=full_configuration.persona_override_config,
-        message_ts_to_respond_to=None,
-        retrieval_details=RetrievalDetails(),
-        rerank_settings=None,
-        db_session=db_session,
-        skip_gen_ai_answer_generation=False,
+    forced_tool_id = forced_tool_ids[0] if forced_tool_ids else None
+    request = SendMessageRequest(
+        message=eval_input["message"],
+        llm_override=llm_override,
+        allowed_tool_ids=full_configuration.allowed_tool_ids,
+        forced_tool_ids=forced_tool_ids or None,
+        forced_tool_id=forced_tool_id,
+        chat_session_info=ChatSessionCreationRequest(
+            persona_id=DEFAULT_PERSONA_ID,
+            description="Eval session",
+        ),
    )

-    packets = stream_chat_message_objects(
+    packets = handle_stream_message_objects(
        new_msg_req=request,
        user=user,
        db_session=db_session,

0
backend/onyx/image_gen/__init__.py
Normal file
6
backend/onyx/image_gen/exceptions.py
Normal file
@@ -0,0 +1,6 @@
class ImageProviderError(Exception):
    pass


class ImageProviderCredentialsError(ImageProviderError):
    pass
44
backend/onyx/image_gen/factory.py
Normal file
@@ -0,0 +1,44 @@
from enum import Enum

from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.providers.azure_img_gen import AzureImageGenerationProvider
from onyx.image_gen.providers.openai_img_gen import OpenAIImageGenerationProvider
from onyx.image_gen.providers.vertex_img_gen import VertexImageGenerationProvider


class ImageGenerationProviderName(str, Enum):
    AZURE = "azure"
    OPENAI = "openai"
    VERTEX_AI = "vertex_ai"


PROVIDERS: dict[ImageGenerationProviderName, type[ImageGenerationProvider]] = {
    ImageGenerationProviderName.AZURE: AzureImageGenerationProvider,
    ImageGenerationProviderName.OPENAI: OpenAIImageGenerationProvider,
    ImageGenerationProviderName.VERTEX_AI: VertexImageGenerationProvider,
}


def get_image_generation_provider(
    provider: str,
    credentials: ImageGenerationProviderCredentials,
) -> ImageGenerationProvider:
    provider_cls = _get_provider_cls(provider)
    return provider_cls.build_from_credentials(credentials)


def validate_credentials(
    provider: str,
    credentials: ImageGenerationProviderCredentials,
) -> bool:
    provider_cls = _get_provider_cls(provider)
    return provider_cls.validate_credentials(credentials)


def _get_provider_cls(provider: str) -> type[ImageGenerationProvider]:
    try:
        provider_enum = ImageGenerationProviderName(provider)
    except ValueError:
        raise ValueError(f"Invalid image generation provider: {provider}")
    return PROVIDERS[provider_enum]
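The factory above is the single entry point for resolving a provider by name. A minimal usage sketch follows; the API key and model name here are hypothetical placeholders, not values from this diff:

```python
from onyx.image_gen.factory import get_image_generation_provider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials

# Hypothetical OpenAI credentials; only api_key is required for "openai".
credentials = ImageGenerationProviderCredentials(api_key="sk-example-placeholder")
provider = get_image_generation_provider("openai", credentials)

# generate_image returns litellm's ImageResponse
response = provider.generate_image(
    prompt="a watercolor fox",
    model="gpt-image-1",  # illustrative model name
    size="1024x1024",
    n=1,
)
```

An unknown provider string raises `ValueError`, and insufficient credentials raise `ImageProviderCredentialsError` before any network call is made.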
69
backend/onyx/image_gen/interfaces.py
Normal file
@@ -0,0 +1,69 @@
from __future__ import annotations

import abc
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel

from onyx.image_gen.exceptions import ImageProviderCredentialsError

if TYPE_CHECKING:
    from litellm.types.utils import ImageResponse as ImageGenerationResponse


class ImageGenerationProviderCredentials(BaseModel):
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None
    custom_config: dict[str, str] | None = None


class ImageGenerationProvider(abc.ABC):
    @classmethod
    @abc.abstractmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        """Returns True if sufficient credentials are given to build this provider."""
        raise NotImplementedError("validate_credentials not implemented")

    @classmethod
    def build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> ImageGenerationProvider:
        if not cls.validate_credentials(credentials):
            raise ImageProviderCredentialsError(
                f"Invalid image generation credentials: {credentials}"
            )
        return cls._build_from_credentials(credentials)

    @classmethod
    @abc.abstractmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> ImageGenerationProvider:
        """
        Given credentials, builds an instance of the provider.
        Should NOT be called directly - use build_from_credentials instead.

        Raises AssertionError if credentials are invalid.
        """
        raise NotImplementedError("_build_from_credentials not implemented")

    @abc.abstractmethod
    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        """Generates an image based on a prompt."""
        raise NotImplementedError("generate_image not implemented")
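The base class is a template method: the public `build_from_credentials` validates first and only then delegates to the subclass hook, so the per-provider `assert` statements stay internal invariants rather than user-facing error handling. A sketch of the resulting control flow (illustrative only):

```python
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.providers.openai_img_gen import OpenAIImageGenerationProvider

# A missing api_key fails validation, so build_from_credentials raises
# ImageProviderCredentialsError before _build_from_credentials ever runs.
try:
    OpenAIImageGenerationProvider.build_from_credentials(
        ImageGenerationProviderCredentials()
    )
except ImageProviderCredentialsError as err:
    print(f"rejected as expected: {err}")
```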
79
backend/onyx/image_gen/providers/azure_img_gen.py
Normal file
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING

from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials

if TYPE_CHECKING:
    from onyx.image_gen.interfaces import ImageGenerationResponse


class AzureImageGenerationProvider(ImageGenerationProvider):
    def __init__(
        self,
        api_key: str,
        api_base: str,
        api_version: str,
        deployment_name: str | None = None,
    ):
        self._api_key = api_key
        self._api_base = api_base
        self._api_version = api_version
        self._deployment_name = deployment_name

    @classmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        return all(
            [
                credentials.api_key,
                credentials.api_base,
                credentials.api_version,
            ]
        )

    @classmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> AzureImageGenerationProvider:
        assert credentials.api_key
        assert credentials.api_base
        assert credentials.api_version

        return cls(
            api_key=credentials.api_key,
            api_base=credentials.api_base,
            api_version=credentials.api_version,
            deployment_name=credentials.deployment_name,
        )

    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        from litellm import image_generation

        deployment = self._deployment_name or model
        model_name = f"azure/{deployment}"

        return image_generation(
            prompt=prompt,
            model=model_name,
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
            size=size,
            n=n,
            quality=quality,
            **kwargs,
        )
61
backend/onyx/image_gen/providers/openai_img_gen.py
Normal file
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING

from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials

if TYPE_CHECKING:
    from onyx.image_gen.interfaces import ImageGenerationResponse


class OpenAIImageGenerationProvider(ImageGenerationProvider):
    def __init__(
        self,
        api_key: str,
        api_base: str | None = None,
    ):
        self._api_key = api_key
        self._api_base = api_base

    @classmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        return bool(credentials.api_key)

    @classmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> OpenAIImageGenerationProvider:
        assert credentials.api_key

        return cls(
            api_key=credentials.api_key,
            api_base=credentials.api_base,
        )

    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        from litellm import image_generation

        return image_generation(
            prompt=prompt,
            model=model,
            api_key=self._api_key,
            api_base=self._api_base,
            size=size,
            n=n,
            quality=quality,
            **kwargs,
        )
105
backend/onyx/image_gen/providers/vertex_img_gen.py
Normal file
@@ -0,0 +1,105 @@
from __future__ import annotations

import json
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel

from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials

if TYPE_CHECKING:
    from onyx.image_gen.interfaces import ImageGenerationResponse


class VertexCredentials(BaseModel):
    vertex_credentials: str
    vertex_location: str
    project_id: str


class VertexImageGenerationProvider(ImageGenerationProvider):
    def __init__(
        self,
        vertex_credentials: VertexCredentials,
    ):
        self._vertex_credentials = vertex_credentials.vertex_credentials
        self._vertex_location = vertex_credentials.vertex_location
        self._vertex_project = vertex_credentials.project_id

    @classmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        try:
            _parse_to_vertex_credentials(credentials)
            return True
        except ImageProviderCredentialsError:
            return False

    @classmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> VertexImageGenerationProvider:
        vertex_credentials = _parse_to_vertex_credentials(credentials)

        return cls(
            vertex_credentials=vertex_credentials,
        )

    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        from litellm import image_generation

        return image_generation(
            prompt=prompt,
            model=model,
            size=size,
            n=n,
            quality=quality,
            vertex_location=self._vertex_location,
            vertex_credentials=self._vertex_credentials,
            vertex_project=self._vertex_project,
            **kwargs,
        )


def _parse_to_vertex_credentials(
    credentials: ImageGenerationProviderCredentials,
) -> VertexCredentials:
    custom_config = credentials.custom_config

    if not custom_config:
        raise ImageProviderCredentialsError("Custom config is required")

    vertex_credentials = custom_config.get("vertex_credentials")
    vertex_location = custom_config.get("vertex_location")

    if not vertex_credentials:
        raise ImageProviderCredentialsError("Vertex credentials are required")

    if not vertex_location:
        raise ImageProviderCredentialsError("Vertex location is required")

    vertex_json = json.loads(vertex_credentials)
    vertex_project = vertex_json.get("project_id")

    if not vertex_project:
        raise ImageProviderCredentialsError("Project ID is required")

    return VertexCredentials(
        vertex_credentials=vertex_credentials,
        vertex_location=vertex_location,
        project_id=vertex_project,
    )
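Unlike the other providers, Vertex carries everything inside the generic `custom_config` dict: a JSON service-account string plus a location, with `project_id` read out of the JSON blob. A sketch with placeholder values (not real credentials):

```python
import json

from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.providers.vertex_img_gen import VertexImageGenerationProvider

# Placeholder service-account payload; project_id is the only field
# _parse_to_vertex_credentials reads out of the JSON blob.
service_account_json = json.dumps({"project_id": "example-project"})

credentials = ImageGenerationProviderCredentials(
    custom_config={
        "vertex_credentials": service_account_json,
        "vertex_location": "us-central1",
    }
)
assert VertexImageGenerationProvider.validate_credentials(credentials)
```

One caveat visible in the code as written: `json.loads` failures are not wrapped in `ImageProviderCredentialsError`, so a malformed `vertex_credentials` blob would surface from `validate_credentials` as a `JSONDecodeError` rather than returning `False`.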
@@ -51,6 +51,7 @@ from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_bac
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.llm.interfaces import LLM
from onyx.llm.models import UserMessage
from onyx.llm.multi_llm import LLMRateLimitError
from onyx.llm.utils import llm_response_to_string
from onyx.llm.utils import MAX_CONTEXT_TOKENS
@@ -492,7 +493,7 @@ def add_document_summaries(
    # So we just pass the full prompt without caching
    summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
    doc_summary = llm_response_to_string(
        llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
        llm.invoke(UserMessage(content=summary_prompt), max_tokens=MAX_CONTEXT_TOKENS)
    )

    for chunk in chunks_by_doc:
@@ -534,7 +535,9 @@ def add_chunk_summaries(
    # In this case we compute a doc summary using the LLM
    doc_info = llm_response_to_string(
        llm.invoke(
            DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
            UserMessage(
                content=DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
            ),
            max_tokens=MAX_CONTEXT_TOKENS,
        )
    )
@@ -550,8 +553,8 @@ def add_chunk_summaries(
    # For string inputs with continuation=True, the result will be a concatenated string
    processed_prompt, _ = process_with_prompt_cache(
        llm_config=llm.config,
        cacheable_prefix=context_prompt1,
        suffix=context_prompt2,
        cacheable_prefix=UserMessage(content=context_prompt1),
        suffix=UserMessage(content=context_prompt2),
        continuation=True,  # Append chunk to the document context
    )

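These hunks are instances of a single recurring change in this commit set: `llm.invoke` no longer accepts a bare string, so every call site wraps its prompt in `UserMessage`. A before/after sketch of the pattern, using the names from the diffs above:

```python
from onyx.llm.models import UserMessage

# Before: LanguageModelInput accepted a raw string
# response = llm.invoke("Summarize this document")

# After: the string is wrapped explicitly as a user message
response = llm.invoke(UserMessage(content="Summarize this document"))
```

The same substitution appears below in the KG extraction code and in `test_llm`.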
@@ -41,6 +41,7 @@ class BaseChunk(BaseModel):
    image_file_id: str | None
    # True if this Chunk's start is not at the start of a Section
    # TODO(andrei): This is deprecated as of the OpenSearch migration. Remove.
    # Do not use.
    section_continuation: bool

@@ -29,6 +29,7 @@ from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.kg.utils.formatting_utils import make_relationship_type_id
from onyx.kg.vespa.vespa_interactions import get_document_vespa_contents
from onyx.llm.factory import get_default_llm
from onyx.llm.models import UserMessage
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT
from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT
@@ -417,7 +418,9 @@ def kg_classify_document(
    # classify with LLM
    llm = get_default_llm()
    try:
        raw_classification_result = llm_response_to_string(llm.invoke(prompt))
        raw_classification_result = llm_response_to_string(
            llm.invoke(UserMessage(content=prompt))
        )
        classification_result = (
            raw_classification_result.replace("```json", "").replace("```", "").strip()
        )
@@ -481,7 +484,9 @@ def kg_deep_extract_chunks(
    # extract with LLM
    llm = get_default_llm()
    try:
        raw_extraction_result = llm_response_to_string(llm.invoke(prompt))
        raw_extraction_result = llm_response_to_string(
            llm.invoke(UserMessage(content=prompt))
        )
        cleaned_response = (
            raw_extraction_result.replace("{{", "{")
            .replace("}}", "}")

@@ -96,6 +96,7 @@ def get_llm_config_for_persona(
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        deployment_name=llm_provider.deployment_name,
        custom_config=llm_provider.custom_config,
        max_input_tokens=max_input_tokens,
    )


@@ -27,7 +27,7 @@ class LLMConfig(BaseModel):
    api_base: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None
    credentials_file: str | None = None
    custom_config: dict[str, str] | None = None
    max_input_tokens: int
    # This disables the "model_" protected namespace for pydantic
    model_config = {"protected_namespaces": ()}

@@ -369,8 +369,6 @@ def _patch_openai_responses_chunk_parser() -> None:
            # New output item added
            output_item = parsed_chunk.get("item", {})
            if output_item.get("type") == "function_call":
                # Track that we've received tool calls via streaming
                self._has_streamed_tool_calls = True
                return GenericStreamingChunk(
                    text="",
                    tool_use=ChatCompletionToolCallChunk(
@@ -396,8 +394,6 @@ def _patch_openai_responses_chunk_parser() -> None:
        elif event_type == "response.function_call_arguments.delta":
            content_part: Optional[str] = parsed_chunk.get("delta", None)
            if content_part:
                # Track that we've received tool calls via streaming
                self._has_streamed_tool_calls = True
                return GenericStreamingChunk(
                    text="",
                    tool_use=ChatCompletionToolCallChunk(
@@ -495,71 +491,21 @@ def _patch_openai_responses_chunk_parser() -> None:

        elif event_type == "response.completed":
            # Final event signaling all output items (including parallel tool calls) are done
            # Check if we already received tool calls via streaming events
            # There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
            # But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
            # response.completed event so we need to throw it out here or there are duplicate tool calls.
            has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)

            response_data = parsed_chunk.get("response", {})
            output_items = response_data.get("output", [])

            # Check if there are function_call items in the output
            has_function_calls = any(
                isinstance(item, dict) and item.get("type") == "function_call"
                for item in output_items
            )

            if has_function_calls and not has_streamed_tool_calls:
                # Azure's Responses API returns all tool calls in response.completed
                # without streaming them incrementally. Extract them here.
                from litellm.types.utils import (
                    Delta,
                    ModelResponseStream,
                    StreamingChoices,
                )

                tool_calls = []
                for idx, item in enumerate(output_items):
            # Determine finish reason based on response content
            finish_reason = "stop"
            if response_data.get("output"):
                for item in response_data["output"]:
                    if isinstance(item, dict) and item.get("type") == "function_call":
                        tool_calls.append(
                            ChatCompletionToolCallChunk(
                                id=item.get("call_id"),
                                index=idx,
                                type="function",
                                function=ChatCompletionToolCallFunctionChunk(
                                    name=item.get("name"),
                                    arguments=item.get("arguments", ""),
                                ),
                            )
                        )

                return ModelResponseStream(
                    choices=[
                        StreamingChoices(
                            index=0,
                            delta=Delta(tool_calls=tool_calls),
                            finish_reason="tool_calls",
                        )
                    ]
                )
            elif has_function_calls:
                # Tool calls were already streamed, just signal completion
                return GenericStreamingChunk(
                    text="",
                    tool_use=None,
                    is_finished=True,
                    finish_reason="tool_calls",
                    usage=None,
                )
            else:
                return GenericStreamingChunk(
                    text="",
                    tool_use=None,
                    is_finished=True,
                    finish_reason="stop",
                    usage=None,
                )
                        finish_reason = "tool_calls"
                        break
            return GenericStreamingChunk(
                text="",
                tool_use=None,
                is_finished=True,
                finish_reason=finish_reason,
                usage=None,
            )

        else:
            pass
@@ -685,40 +631,6 @@ def _patch_openai_responses_transform_response() -> None:
    LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response  # type: ignore[method-assign]


def _patch_azure_responses_should_fake_stream() -> None:
    """
    Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.

    By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
    not in its database. This causes Azure custom model deployments to buffer the entire
    response before yielding, resulting in poor time-to-first-token.

    Azure's Responses API supports native streaming, so we override this to always use
    real streaming (SyncResponsesAPIStreamingIterator).
    """
    from litellm.llms.azure.responses.transformation import (
        AzureOpenAIResponsesAPIConfig,
    )

    if (
        getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
        == "_patched_should_fake_stream"
    ):
        return

    def _patched_should_fake_stream(
        self: Any,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Azure Responses API supports native streaming - never fake it
        return False

    _patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
    AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream  # type: ignore[method-assign]


def apply_monkey_patches() -> None:
    """
    Apply all necessary monkey patches to LiteLLM for compatibility.
@@ -728,13 +640,12 @@ def apply_monkey_patches() -> None:
    - Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
    - Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
    - Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
    - Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
    - Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
    """
    _patch_ollama_transform_request()
    _patch_ollama_chunk_parser()
    _patch_openai_responses_chunk_parser()
    _patch_openai_responses_transform_response()
    _patch_azure_responses_should_fake_stream()


def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:

@@ -104,4 +104,4 @@ class ToolMessage(CacheableMessage):
# Union type for all OpenAI Chat Completions messages
ChatCompletionMessage = SystemMessage | UserMessage | AssistantMessage | ToolMessage
# Allows for passing in a string directly. This is provided for convenience and is wrapped as a UserMessage.
LanguageModelInput = list[ChatCompletionMessage] | str
LanguageModelInput = list[ChatCompletionMessage] | ChatCompletionMessage

@@ -64,9 +64,9 @@ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]:
    LiteLLM expects messages to be dictionaries (with .get() method),
    not Pydantic models. This function serializes the messages.
    """
    if isinstance(prompt, str):
        return [{"role": "user", "content": prompt}]
    return [msg.model_dump(exclude_none=True) for msg in prompt]
    if isinstance(prompt, list):
        return [msg.model_dump(exclude_none=True) for msg in prompt]
    return [prompt.model_dump(exclude_none=True)]


def _prompt_as_json(prompt: LanguageModelInput) -> JSON_ro:
@@ -164,9 +164,13 @@ class LitellmLLM(LLM):
    def _safe_model_config(self) -> dict:
        dump = self.config.model_dump()
        dump["api_key"] = mask_string(dump.get("api_key") or "")
        credentials_file = dump.get("credentials_file")
        if isinstance(credentials_file, str) and credentials_file:
            dump["credentials_file"] = mask_string(credentials_file)
        custom_config = dump.get("custom_config")
        if isinstance(custom_config, dict):
            # Mask sensitive values in custom_config
            masked_config = {}
            for k, v in custom_config.items():
                masked_config[k] = mask_string(v) if v else v
            dump["custom_config"] = masked_config
        return dump

    def _record_call(
@@ -402,12 +406,6 @@ class LitellmLLM(LLM):

    @property
    def config(self) -> LLMConfig:
        credentials_file: str | None = (
            self._custom_config.get(VERTEX_CREDENTIALS_FILE_KWARG, None)
            if self._custom_config
            else None
        )

        return LLMConfig(
            model_provider=self._model_provider,
            model_name=self._model_version,
@@ -416,7 +414,7 @@ class LitellmLLM(LLM):
            api_base=self._api_base,
            api_version=self._api_version,
            deployment_name=self._deployment_name,
            credentials_file=credentials_file,
            custom_config=self._custom_config,
            max_input_tokens=self._max_input_tokens,
        )

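The net effect of these two hunks is that the dedicated `credentials_file` field on `LLMConfig` goes away: Vertex credentials now ride along in `custom_config`, and `_safe_model_config` masks every `custom_config` value generically. A rough sketch of the masking behavior (illustrative; `mask_string` stands in for the project's own helper and is approximated here):

```python
# Hypothetical custom_config as it might appear on an LLMConfig
custom_config = {
    "vertex_credentials": '{"project_id": "example-project"}',
    "vertex_location": "us-central1",
}

# _safe_model_config masks every non-empty value, so serialized
# configs never contain raw credential material:
masked = {k: ("****" if v else v) for k, v in custom_config.items()}
```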
@@ -16,7 +16,6 @@ from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider
from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider
from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider
from onyx.llm.prompt_cache.utils import combine_messages_with_continuation
from onyx.llm.prompt_cache.utils import normalize_language_model_input
from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform

__all__ = [
@@ -26,7 +25,6 @@ __all__ = [
    "combine_messages_with_continuation",
    "generate_cache_key_hash",
    "get_provider_adapter",
    "normalize_language_model_input",
    "NoOpPromptCacheProvider",
    "OpenAIPromptCacheProvider",
    "prepare_messages_with_cacheable_transform",

@@ -9,7 +9,6 @@ from onyx.configs.model_configs import PROMPT_CACHE_REDIS_TTL_MULTIPLIER
from onyx.key_value_store.store import PgRedisKVStore
from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.prompt_cache.models import CacheMetadata
from onyx.llm.prompt_cache.utils import normalize_language_model_input
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

@@ -196,7 +195,7 @@ def generate_cache_key_hash(
    """Generate a deterministic cache key hash from the cacheable prefix.

    Args:
        cacheable_prefix: LanguageModelInput (str or Sequence[ChatCompletionMessage])
        cacheable_prefix: Single message or list of messages to hash
        provider: LLM provider name
        model_name: Model name
        tenant_id: Tenant ID
@@ -204,10 +203,11 @@ def generate_cache_key_hash(
    Returns:
        SHA256 hash as hex string
    """
    # Normalize to Sequence[ChatCompletionMessage] for consistent hashing
    messages = normalize_language_model_input(cacheable_prefix)
    # Convert to list of dicts for serialization, handling nested Pydantic models
    messages_dict = [_make_json_serializable(dict(msg)) for msg in messages]
    # Normalize to list for consistent hashing; _make_json_serializable handles Pydantic models
    messages = (
        cacheable_prefix if isinstance(cacheable_prefix, list) else [cacheable_prefix]
    )
    messages_dict = [_make_json_serializable(msg) for msg in messages]

    # Serialize messages in a deterministic way
    # Include only content, roles, and order - exclude timestamps or dynamic fields

@@ -8,51 +8,29 @@ from typing import Any

from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import LanguageModelInput
from onyx.llm.models import UserMessage
from onyx.utils.logger import setup_logger


logger = setup_logger()


def normalize_language_model_input(
    input: LanguageModelInput,
) -> Sequence[ChatCompletionMessage]:
    """Normalize LanguageModelInput to Sequence[ChatCompletionMessage].

    Args:
        input: LanguageModelInput (str or Sequence[ChatCompletionMessage])

    Returns:
        Sequence of ChatCompletionMessage objects
    """
    if isinstance(input, str):
        # Convert string to user message
        return [UserMessage(role="user", content=input)]
    return input


def combine_messages_with_continuation(
    prefix_msgs: Sequence[ChatCompletionMessage],
    suffix_msgs: Sequence[ChatCompletionMessage],
    continuation: bool,
    was_prefix_string: bool,
) -> list[ChatCompletionMessage]:
    """Combine prefix and suffix messages, handling continuation flag.

    Args:
        prefix_msgs: Normalized cacheable prefix messages
        suffix_msgs: Normalized suffix messages
        continuation: If True and prefix is not a string, append suffix content
            to the last message of prefix
        was_prefix_string: Whether the original prefix was a string (strings
            remain in their own content block even if continuation=True)
        continuation: If True, append suffix content to the last message of prefix
        was_prefix_string: Deprecated, no longer used

    Returns:
        Combined messages
    """
    if not continuation or not prefix_msgs or was_prefix_string:
        # Simple concatenation (or prefix was a string, so keep separate)
    if not continuation or not prefix_msgs:
        return list(prefix_msgs) + list(suffix_msgs)
    # Append suffix content to last message of prefix
    result = list(prefix_msgs)
@@ -130,16 +108,15 @@ def prepare_messages_with_cacheable_transform(
    if cacheable_prefix is None:
        return suffix

    prefix_msgs = normalize_language_model_input(cacheable_prefix)
    suffix_msgs = normalize_language_model_input(suffix)
    prefix_msgs = (
        cacheable_prefix if isinstance(cacheable_prefix, list) else [cacheable_prefix]
    )
    suffix_msgs = suffix if isinstance(suffix, list) else [suffix]

    # Apply transformation to cacheable messages if provided
    if transform_cacheable is not None:
        prefix_msgs = transform_cacheable(prefix_msgs)

    # Handle continuation flag
    was_prefix_string = isinstance(cacheable_prefix, str)
    prefix_msgs = list(transform_cacheable(prefix_msgs))

    return combine_messages_with_continuation(
        prefix_msgs, suffix_msgs, continuation, was_prefix_string
        prefix_msgs=prefix_msgs, suffix_msgs=suffix_msgs, continuation=continuation
    )

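The continuation flag is the subtle part of this API: with `continuation=True`, the suffix's content is appended into the last prefix message instead of becoming a new message, which keeps the cacheable prefix boundary stable for prompt caching. A toy sketch of the two behaviors (illustrative contents; the exact join of the appended text follows the implementation detail cut off in the hunk above):

```python
from onyx.llm.models import UserMessage
from onyx.llm.prompt_cache import combine_messages_with_continuation

prefix = [UserMessage(content="Document: ...long cached text...")]
suffix = [UserMessage(content="Chunk: section 3")]

# continuation=False -> prefix and suffix stay separate messages:
#   [UserMessage("Document: ..."), UserMessage("Chunk: section 3")]
separate = combine_messages_with_continuation(
    prefix_msgs=prefix, suffix_msgs=suffix, continuation=False, was_prefix_string=False
)

# continuation=True -> the chunk text is appended onto the cached
# document message, producing a single combined message.
merged = combine_messages_with_continuation(
    prefix_msgs=prefix, suffix_msgs=suffix, continuation=True, was_prefix_string=False
)
```

Note that `was_prefix_string` survives in the signature only as a deprecated parameter, so existing positional callers keep working during the migration.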
@@ -23,6 +23,7 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.model_response import ModelResponse
from onyx.llm.models import UserMessage
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_TOKEN_ESTIMATE
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
from onyx.utils.logger import setup_logger
@@ -320,7 +321,7 @@ def test_llm(llm: LLM) -> str | None:
    error_msg = None
    for _ in range(2):
        try:
            llm.invoke("Do not respond")
            llm.invoke(UserMessage(content="Do not respond"))
            return None
        except Exception as e:
            error_msg = str(e)

@@ -117,7 +117,6 @@ from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.pat.api import router as pat_router
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.chat_backend_v0 import router as chat_v0_router
from onyx.server.query_and_chat.query_backend import (
    admin_router as admin_query_router,
)
@@ -365,7 +364,6 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:

    include_router_with_global_prefix_prepended(application, password_router)
    include_router_with_global_prefix_prepended(application, chat_router)
    include_router_with_global_prefix_prepended(application, chat_v0_router)
    include_router_with_global_prefix_prepended(application, query_router)
    include_router_with_global_prefix_prepended(application, document_router)
    include_router_with_global_prefix_prepended(application, user_router)

@@ -151,8 +151,6 @@ Expected response:
- `MCP_SERVER_CORS_ORIGINS`: Comma-separated CORS origins (optional)

**API Server Connection:**
- `API_SERVER_BASE_URL`: Full API base URL (e.g., `https://cloud.onyx.app/api`). If set, overrides protocol/host/port below.
- `ONYX_URL`: Alternative to `API_SERVER_BASE_URL` (same purpose, either can be used)
- `API_SERVER_PROTOCOL`: Protocol for internal API calls (default: "http")
- `API_SERVER_HOST`: Host for internal API calls (default: "127.0.0.1")
- `API_SERVER_PORT`: Port for internal API calls (default: 8080)
- `API_SERVER_PROTOCOL`: Protocol for API server connection (default: "http")
- `API_SERVER_HOST`: Hostname for API server connection (default: "127.0.0.1")
- `API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS`: Optional override URL. If set, takes precedence over the protocol/host variables. Used for self-hosting the MCP server with Onyx Cloud as the backend.
@@ -5,9 +5,9 @@ from typing import Optional
from fastmcp.server.auth.auth import AccessToken
from fastmcp.server.auth.auth import TokenVerifier

from onyx.mcp_server.utils import get_api_server_url
from onyx.mcp_server.utils import get_http_client
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests

logger = setup_logger()

@@ -19,7 +19,7 @@ class OnyxTokenVerifier(TokenVerifier):
    """Call API /me to verify the token, return minimal AccessToken on success."""
    try:
        response = await get_http_client().get(
            f"{get_api_server_url()}/me",
            f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/me",
            headers={"Authorization": f"Bearer {token}"},
        )
    except Exception as exc:

@@ -4,17 +4,12 @@
from datetime import datetime
from typing import Any

from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import RetrievalDetails
from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_api_server_url
from onyx.mcp_server.utils import get_http_client
from onyx.mcp_server.utils import get_indexed_sources
from onyx.mcp_server.utils import require_access_token
from onyx.server.query_and_chat.models import DocumentSearchRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests

logger = setup_logger()

@@ -107,32 +102,44 @@ async def search_indexed_documents(
            f"Onyx MCP Server: Invalid source type '{src}' - will be ignored by server"
        )

    search_request = DocumentSearchRequest(
        message=query,
        search_type=SearchType.SEMANTIC,
        retrieval_options=RetrievalDetails(
            filters=IndexFilters(
                source_type=source_type_enums,
                time_cutoff=time_cutoff_dt,
                access_control_list=None,  # Server handles ACL using the access token
            ),
            enable_auto_detect_filters=False,
            offset=0,
            limit=limit,
        ),
        evaluation_type=LLMEvaluationType.SKIP,
    )
    # Build filters dict only with non-None values
    filters: dict[str, Any] | None = None
    if source_type_enums or time_cutoff_dt:
        filters = {}
        if source_type_enums:
            filters["source_type"] = [src.value for src in source_type_enums]
        if time_cutoff_dt:
            filters["time_cutoff"] = time_cutoff_dt.isoformat()

    # Call the API server
    # Build the search request using the new SendSearchQueryRequest format
    search_request = {
        "search_query": query,
        "filters": filters,
        "num_docs_fed_to_llm_selection": limit,
        "run_query_expansion": False,
        "include_content": True,
        "stream": False,
    }

    # Call the API server using the new send-search-message route
    try:
        response = await get_http_client().post(
            f"{get_api_server_url()}/query/document-search",
            json=search_request.model_dump(mode="json"),
            f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/search/send-search-message",
            json=search_request,
            headers={"Authorization": f"Bearer {access_token.token}"},
        )
        response.raise_for_status()
        result = response.json()

        # Check for error in response
        if result.get("error"):
            return {
                "documents": [],
                "total_results": 0,
                "query": query,
                "error": result.get("error"),
            }

        # Return simplified format for MCP clients
        fields_to_return = [
@@ -143,7 +150,7 @@ async def search_indexed_documents(
        ]
        documents = [
            {key: doc.get(key) for key in fields_to_return}
            for doc in result.get("top_documents", [])
            for doc in result.get("search_docs", [])
        ]

        logger.info(
@@ -153,6 +160,7 @@ async def search_indexed_documents(
            "documents": documents,
            "total_results": len(documents),
            "query": query,
            "executed_queries": result.get("all_executed_queries", [query]),
        }
    except Exception as e:
        logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True)
@@ -190,7 +198,7 @@ async def search_web(
    try:
        request_payload = {"queries": [query], "max_results": limit}
        response = await get_http_client().post(
            f"{get_api_server_url()}/web-search/search-lite",
            f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/search-lite",
            json=request_payload,
            headers={"Authorization": f"Bearer {access_token.token}"},
        )
@@ -236,7 +244,7 @@ async def open_urls(

    try:
        response = await get_http_client().post(
            f"{get_api_server_url()}/web-search/open-urls",
            f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/open-urls",
            json={"urls": urls},
            headers={"Authorization": f"Bearer {access_token.token}"},
        )

@@ -2,15 +2,12 @@

from __future__ import annotations

import os

import httpx
from fastmcp.server.auth.auth import AccessToken
from fastmcp.server.dependencies import get_access_token

from onyx.configs.app_configs import APP_API_PREFIX
from onyx.configs.app_configs import APP_PORT
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests

logger = setup_logger()

@@ -36,21 +33,6 @@ def require_access_token() -> AccessToken:
    return access_token


def get_api_server_url() -> str:
    """Construct the API server base URL for internal or external requests."""
    override = os.getenv("API_SERVER_BASE_URL") or os.getenv("ONYX_URL")
    if override:
        return override.rstrip("/")

    protocol = os.getenv("API_SERVER_PROTOCOL", "http")
    host = os.getenv("API_SERVER_HOST", "127.0.0.1")
    port = os.getenv("API_SERVER_PORT", str(APP_PORT))
    prefix = (APP_API_PREFIX or "").strip("/")

    base = f"{protocol}://{host}:{port}"
    return f"{base}/{prefix}" if prefix else base


def get_http_client() -> httpx.AsyncClient:
    """Return a shared async HTTP client."""
    global _http_client
@@ -79,7 +61,7 @@ async def get_indexed_sources(
    headers = {"Authorization": f"Bearer {access_token.token}"}
    try:
        response = await get_http_client().get(
            f"{get_api_server_url()}/manage/indexed-sources",
            f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/manage/indexed-sources",
            headers=headers,
        )
        response.raise_for_status()

@@ -7,21 +7,16 @@ from typing import TypeVar
from retry import retry
from slack_sdk import WebClient

from onyx.chat.chat_utils import prepare_chat_message_request
from onyx.chat.models import ChatBasicResponse
from onyx.chat.models import ThreadMessage
from onyx.chat.process_message import gather_stream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.chat.process_message import handle_stream_message_objects
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from onyx.configs.onyxbot_configs import ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER
from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
@@ -30,12 +25,14 @@ from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.models import ThreadMessage
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.utils.logger import OnyxLoggingAdapter

srl = SlackRateLimiter()
@@ -97,7 +94,6 @@ def handle_regular_answer(
    logger: OnyxLoggingAdapter,
    feedback_reminder_id: str | None,
    num_retries: int = ONYX_BOT_NUM_RETRIES,
    thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
    should_respond_with_error_msgs: bool = ONYX_BOT_DISPLAY_ERROR_MSGS,
    disable_docs_only_answer: bool = ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER,
) -> bool:
@@ -181,13 +177,13 @@ def handle_regular_answer(
    )
    @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
    def _get_slack_answer(
        new_message_request: CreateChatMessageRequest,
        new_message_request: SendMessageRequest,
        slack_context_str: str | None,
        # pass in `None` to make the answer based on public documents only
        onyx_user: User | None,
    ) -> ChatBasicResponse:
        with get_session_with_current_tenant() as db_session:
            packets = stream_chat_message_objects(
            packets = handle_stream_message_objects(
                new_msg_req=new_message_request,
                user=onyx_user,
                db_session=db_session,
@@ -211,40 +207,24 @@ def handle_regular_answer(
        time_cutoff=None,
    )

    # Default True because no other ways to apply filters in Slack (no nice UI)
    # Commenting this out because this is only available to the slackbot for now
    # later we plan to implement this at the persona level where this will get
    # commented back in
    # auto_detect_filters = (
    #     persona.llm_filter_extraction if persona is not None else True
    # )
    auto_detect_filters = slack_channel_config.enable_auto_filters
    retrieval_details = RetrievalDetails(
        run_search=OptionalSearchSetting.ALWAYS,
        real_time=False,
        filters=filters,
        enable_auto_detect_filters=auto_detect_filters,
    )

    with get_session_with_current_tenant() as db_session:
        answer_request = prepare_chat_message_request(
            message_text=user_message.message,
            user=user,
    new_message_request = SendMessageRequest(
        message=user_message.message,
        allowed_tool_ids=None,
        forced_tool_id=None,
        file_descriptors=[],
        internal_search_filters=filters,
        deep_research=False,
        origin=MessageOrigin.SLACKBOT,
        chat_session_info=ChatSessionCreationRequest(
            persona_id=persona.id,
            # This is not used in the Slack flow, only in the answer API
            persona_override_config=None,
            message_ts_to_respond_to=message_ts_to_respond_to,
            retrieval_details=retrieval_details,
            rerank_settings=None,  # Rerank customization supported in Slack flow
            db_session=db_session,
            origin=MessageOrigin.SLACKBOT,
        )
        ),
    )

    # if it's a DM or ephemeral message, answer based on private documents.
    # otherwise, answer based on public documents ONLY as to not leak information.
    can_search_over_private_docs = message_info.is_bot_dm or send_as_ephemeral
    answer = _get_slack_answer(
        new_message_request=answer_request,
        new_message_request=new_message_request,
        onyx_user=user if can_search_over_private_docs else None,
        slack_context_str=slack_context_str,
    )

@@ -25,14 +25,12 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session

from onyx.chat.models import ThreadMessage
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import POD_NAME
from onyx.configs.app_configs import POD_NAMESPACE
from onyx.configs.constants import MessageType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.onyxbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from onyx.configs.onyxbot_configs import ONYX_BOT_REPHRASE_MESSAGE
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.context.search.retrieval.search_runner import (
    download_nltk_data,
@@ -83,6 +81,7 @@ from onyx.onyxbot.slack.handlers.handle_message import (
from onyx.onyxbot.slack.handlers.handle_message import schedule_feedback_reminder
from onyx.onyxbot.slack.models import SlackContext
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.models import ThreadMessage
from onyx.onyxbot.slack.utils import check_message_limit
from onyx.onyxbot.slack.utils import decompose_action_id
from onyx.onyxbot.slack.utils import get_channel_name_from_id
@@ -90,7 +89,6 @@ from onyx.onyxbot.slack.utils import get_channel_type_from_id
from onyx.onyxbot.slack.utils import get_onyx_bot_auth_ids
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
from onyx.onyxbot.slack.utils import rephrase_slack_message
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import TenantSocketModeClient
from onyx.redis.redis_pool import get_redis_client
@@ -842,15 +840,7 @@ def build_request_details(

    msg = remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)

    if ONYX_BOT_REPHRASE_MESSAGE:
        logger.info(f"Rephrasing Slack message. Original message: {msg}")
        try:
            msg = rephrase_slack_message(msg)
            logger.info(f"Rephrased message: {msg}")
        except Exception as e:
            logger.error(f"Error while trying to rephrase the Slack message: {e}")
    else:
        logger.info(f"Received Slack message: {msg}")
    logger.info(f"Received Slack message: {msg}")

    event_type = event.get("type")
    if event_type == "app_mention":
@@ -880,7 +870,7 @@ def build_request_details(
        )

        if thread_ts != message_ts and thread_ts is not None:
            thread_messages = read_slack_thread(
            thread_messages: list[ThreadMessage] = read_slack_thread(
                tenant_id=tenant_id,
                channel=channel,
                thread=thread_ts,

@@ -3,7 +3,7 @@ from typing import Literal

from pydantic import BaseModel

from onyx.chat.models import ThreadMessage
from onyx.configs.constants import MessageType


class ChannelType(str, Enum):
@@ -25,6 +25,12 @@ class SlackContext(BaseModel):
    message_ts: str | None = None  # Used as request ID for log correlation


class ThreadMessage(BaseModel):
    message: str
    sender: str | None = None
    role: MessageType = MessageType.USER


class SlackMessageInfo(BaseModel):
    thread_messages: list[ThreadMessage]
    channel_to_respond: str

@@ -34,12 +34,9 @@ from onyx.configs.onyxbot_configs import (
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
from onyx.llm.factory import get_default_llm
from onyx.llm.utils import llm_response_to_string
from onyx.onyxbot.slack.constants import FeedbackVisibility
from onyx.onyxbot.slack.models import ChannelType
from onyx.onyxbot.slack.models import ThreadMessage
from onyx.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -140,15 +137,6 @@ def check_message_limit() -> bool:
    return True


def rephrase_slack_message(msg: str) -> str:
    llm = get_default_llm(timeout=5)
    prompt = SLACK_LANGUAGE_REPHRASE_PROMPT.format(query=msg)
    model_output = llm_response_to_string(llm.invoke(prompt))
    logger.debug(model_output)

    return model_output


def update_emote_react(
    emoji: str,
    channel: str,

@@ -1,61 +0,0 @@
# The following prompts are used for verifying the LLM answer after it is already produced.
# Reflexion flow essentially. This feature can be toggled on/off
from onyx.configs.app_configs import CUSTOM_ANSWER_VALIDITY_CONDITIONS
from onyx.prompts.constants import ANSWER_PAT
from onyx.prompts.constants import QUESTION_PAT

ANSWER_VALIDITY_CONDITIONS = (
    """
1. Query is asking for information that varies by person or is subjective. If there is not a \
globally true answer, the language model should not respond, therefore any answer is invalid.
2. Answer addresses a related but different query. To be helpful, the model may provide \
related information about a query but it won't match what the user is asking, this is invalid.
3. Answer is just some form of "I don\'t know" or "not enough information" without significant \
additional useful information. Explaining why it does not know or cannot answer is invalid.
"""
    if not CUSTOM_ANSWER_VALIDITY_CONDITIONS
    else "\n".join(
        [
            f"{indice+1}. {condition}"
            for indice, condition in enumerate(CUSTOM_ANSWER_VALIDITY_CONDITIONS)
        ]
    )
)

ANSWER_FORMAT = (
    """
1. True or False
2. True or False
3. True or False
"""
    if not CUSTOM_ANSWER_VALIDITY_CONDITIONS
    else "\n".join(
        [
            f"{indice+1}. True or False"
            for indice, _ in enumerate(CUSTOM_ANSWER_VALIDITY_CONDITIONS)
        ]
    )
)

ANSWER_VALIDITY_PROMPT = f"""
You are an assistant to identify invalid query/answer pairs coming from a large language model.
The query/answer pair is invalid if any of the following are True:
{ANSWER_VALIDITY_CONDITIONS}

{QUESTION_PAT} {{user_query}}
{ANSWER_PAT} {{llm_answer}}

------------------------
You MUST answer in EXACTLY the following format:
```
{ANSWER_FORMAT}
Final Answer: Valid or Invalid
```

Hint: Remember, if ANY of the conditions are True, it is Invalid.
""".strip()


# Use the following for easy viewing of prompts
if __name__ == "__main__":
    print(ANSWER_VALIDITY_PROMPT)

@@ -1,5 +1,3 @@
|
||||
from onyx.prompts.constants import GENERAL_SEP_PAT
|
||||
|
||||
# ruff: noqa: E501, W605 start
|
||||
|
||||
DATETIME_REPLACEMENT_PAT = "{{CURRENT_DATETIME}}"
|
||||
@@ -12,7 +10,9 @@ ALT_CITATION_GUIDANCE_REPLACEMENT_PAT = "[[CITATION_GUIDANCE]]"
|
||||
# This is editable by the user in the admin UI.
|
||||
# The first line is intended to help guide the general feel/behavior of the system.
|
||||
DEFAULT_SYSTEM_PROMPT = f"""
|
||||
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.
|
||||
You are an expert assistant who is truthful, nuanced, insightful, and efficient. \
|
||||
Your goal is to deeply understand the user's intent, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. \
|
||||
Whenever there is any ambiguity around the user's query (or more information would be helpful), you use available tools (if any) to get more context.
|
||||
|
||||
The current date is {DATETIME_REPLACEMENT_PAT}.{CITATION_GUIDANCE_REPLACEMENT_PAT}
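The replacement patterns are literal markers left in the rendered prompt and swapped in later. A sketch of that substitution step, assuming a simple string replace at request time (the helper name and datetime format are assumptions, not the actual Onyx implementation):

```python
from datetime import datetime

# Illustrative substitution of the literal markers; helper name and datetime
# format are assumptions for demonstration only.
def render_system_prompt(prompt: str, citation_guidance: str = "") -> str:
    rendered = prompt.replace(
        DATETIME_REPLACEMENT_PAT,  # the literal "{{CURRENT_DATETIME}}" marker
        datetime.now().strftime("%A, %B %d, %Y %H:%M"),
    )
    return rendered.replace(CITATION_GUIDANCE_REPLACEMENT_PAT, citation_guidance)
```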

@@ -93,17 +93,18 @@ This tool call completed but the results are no longer accessible.
# date and time but the replacement pattern is not present in the prompt.
ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}."

CHAT_NAMING = f"""
Given the following conversation, provide a SHORT name for the conversation.{{language_hint_or_empty}}
IMPORTANT: TRY NOT TO USE MORE THAN 5 WORDS, MAKE IT AS CONCISE AS POSSIBLE.
Focus the name on the important keywords to convey the topic of the conversation.

Chat History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
CHAT_NAMING_SYSTEM_PROMPT = """
Given the conversation history, provide a SHORT name for the conversation. Focus the name on the important keywords to convey the topic of the conversation. \
Make sure the name is in the same language as the user's language.

Based on the above, what is a short name to convey the topic of the conversation?
IMPORTANT: DO NOT OUTPUT ANYTHING ASIDE FROM THE NAME. MAKE IT AS CONCISE AS POSSIBLE. NEVER USE MORE THAN 5 WORDS, LESS IS FINE.
""".strip()


CHAT_NAMING_REMINDER = """
Provide a short name for the conversation. Refer to other messages in the conversation (not including this one) to determine the language of the name.

IMPORTANT: DO NOT OUTPUT ANYTHING ASIDE FROM THE NAME. MAKE IT AS CONCISE AS POSSIBLE. NEVER USE MORE THAN 5 WORDS, LESS IS FINE.
""".strip()
# ruff: noqa: E501, W605 end
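A sketch of how the new naming prompts might be assembled into a chat request; the message dict shape and the `llm.chat` call are assumptions for illustration, not Onyx's actual chat interface:

```python
# Hypothetical assembly of a chat-naming request; `llm.chat` and the message
# dict shape are illustrative stand-ins.
def name_chat(llm, history: list[dict[str, str]]) -> str:
    messages = [{"role": "system", "content": CHAT_NAMING_SYSTEM_PROMPT}]
    messages.extend(history)  # prior user/assistant turns
    messages.append({"role": "user", "content": CHAT_NAMING_REMINDER})
    # The prompts insist on a bare, <=5 word name, so the reply is used as-is.
    return llm.chat(messages).strip()
```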

@@ -19,6 +19,7 @@ If you need to ask questions, follow these guidelines:
- Be concise and do not ask more than 5 questions.
- If there are ambiguous terms or questions, ask the user to clarify.
- Your questions should be a numbered list for clarity.
- Respond in the user's language.
- Make sure to gather all the information needed to carry out the research task in a concise, well-structured manner.{{internal_search_clarification_guidance}}
- Wrap up with a quick sentence on what the clarification will help with, it's ok to reference the user query closely here.
""".strip()
@@ -45,7 +46,7 @@ The research plan should be formatted as a numbered list of steps and have 6 or

Each step should be a standalone exploration question or topic that can be researched independently but may build on previous steps.

Output only the numbered list of steps with no additional prefix or suffix.
Output only the numbered list of steps with no additional prefix or suffix. Respond in the user's language.
""".strip()


@@ -78,7 +79,7 @@ The research task provided to the {RESEARCH_AGENT_TOOL_NAME} should be reasonabl
It should not be a single short query, rather it should be 1 (or 2 if necessary) descriptive sentences that outline the direction of the investigation.

CRITICAL - the {RESEARCH_AGENT_TOOL_NAME} only receives the task and has no additional context about the user's query, research plan, other research agents, or message history. \
You absolutely must provide all of the context needed to complete the task in the argument to the {RESEARCH_AGENT_TOOL_NAME}.{{internal_search_research_task_guidance}}
You absolutely must provide all of the context needed to complete the task in the argument to the {RESEARCH_AGENT_TOOL_NAME}. The research task should be in the user's language.{{internal_search_research_task_guidance}}

You should call the {RESEARCH_AGENT_TOOL_NAME} MANY times before completing with the {GENERATE_REPORT_TOOL_NAME} tool.
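To make the CRITICAL requirement above concrete, here is a hypothetical pair of task arguments; the task text is invented purely for illustration, since the research agent receives nothing but this string:

```python
# The research agent sees ONLY the task string: no chat history, no research
# plan, no sibling agents. All names, scope, and timeframe must be inlined.
good_task = (
    "Investigate how ExampleCo's 2024 checkout-latency regression was "
    "diagnosed and resolved, including the profiling tools used and the "
    "configuration changes that fixed it."
)

# Too terse: the agent cannot recover whose latency, when, or why from this.
bad_task = "latency regression"
```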

@@ -128,7 +129,7 @@ For context, the date is {current_datetime}.

Users have explicitly selected the deep research mode and will expect a long and detailed answer. It is ok and encouraged that your response is several pages long.

You use different text styles and formatting to make the response easier to read. You may use markdown rarely when necessary to make the response more digestible.
You use different text styles and formatting to make the response easier to read. You may use markdown rarely when necessary to make the response more digestible. Respond in the user's language.

Not every fact retrieved will be relevant to the user's query.

@@ -167,7 +168,7 @@ The research task provided to the {RESEARCH_AGENT_TOOL_NAME} should be reasonabl
It should not be a single short query, rather it should be 1 (or 2 if necessary) descriptive sentences that outline the direction of the investigation.

CRITICAL - the {RESEARCH_AGENT_TOOL_NAME} only receives the task and has no additional context about the user's query, research plan, or message history. \
You absolutely must provide all of the context needed to complete the task in the argument to the {RESEARCH_AGENT_TOOL_NAME}.{{internal_search_research_task_guidance}}
You absolutely must provide all of the context needed to complete the task in the argument to the {RESEARCH_AGENT_TOOL_NAME}. The research task should be in the user's language.{{internal_search_research_task_guidance}}

You should call the {RESEARCH_AGENT_TOOL_NAME} MANY times before completing with the {GENERATE_REPORT_TOOL_NAME} tool.


@@ -51,6 +51,8 @@ Remove any obviously irrelevant or duplicative information.

If a statement seems not trustworthy or is contradictory to other statements, it is important to flag it.

Write the report in the same language as the provided task.

Cite all sources INLINE using the format [1], [2], [3], etc. based on the `document` field of the source. \
Cite inline as opposed to leaving all citations until the very end of the response.
"""
@@ -61,7 +63,8 @@ Please write me a comprehensive report on the research topic given the context a
{research_topic}

Remember to include AS MUCH INFORMATION AS POSSIBLE and as faithful to the original sources as possible. \
Keep it free of formatting and focus on the facts only. Be sure to include all context for each fact to avoid misinterpretation or misattribution.
Keep it free of formatting and focus on the facts only. Be sure to include all context for each fact to avoid misinterpretation or misattribution. \
Respond in the same language as the topic provided above.

Cite every fact INLINE using the format [1], [2], [3], etc. based on the `document` field of the source.


@@ -1,39 +1,30 @@
from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS

SLACK_QUERY_EXPANSION_PROMPT = f"""
Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search.
Rewrite the user's query and, if helpful, split it into at most {MAX_SLACK_QUERY_EXPANSIONS} \
keyword-only queries, so that Slack's keyword search yields the best matches.

Slack search behavior:
- Pure keyword AND search (no semantics)
- More words = fewer matches, so keep queries concise (1-3 words)
Keep in mind the Slack's search behavior:
- Pure keyword AND search (no semantics).
- Word order matters.
- More words = fewer matches, so keep each query concise.
- IMPORTANT: Prefer simple 1-2 word queries over longer multi-word queries.

ALWAYS include:
- Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people
- Project/product names, technical terms, proper nouns
- Actual content words: "performance", "bug", "deployment", "API", "error"
Critical: Extract ONLY keywords that would actually appear in Slack message content.

DO NOT include:
- Meta-words: "topics", "conversations", "discussed", "summary", "messages"
- Temporal: "today", "yesterday", "week", "month", "recent", "last"
- Channel names: "general", "eng-general", "random"
- Meta-words: "topics", "conversations", "discussed", "summary", "messages", "big", "main", "talking"
- Temporal: "today", "yesterday", "week", "month", "recent", "past", "last"
- Channels/Users: "general", "eng-general", "engineering", "@username"

DO include:
- Actual content: "performance", "bug", "deployment", "API", "database", "error", "feature"

Examples:

Query: "what are the big topics in eng-general this week?"
Output:

Query: "messages with Sarah about the deployment"
Output:
Sarah deployment
Sarah
deployment

Query: "what did Mike say about the budget?"
Output:
Mike budget
Mike
budget

Query: "performance issues in eng-general"
Output:
performance issues
@@ -50,7 +41,7 @@ Now process this query:

{{query}}

Output (keywords only, one per line, NO explanations or commentary):
Output:
"""

SLACK_DATE_EXTRACTION_PROMPT = """

@@ -1,33 +0,0 @@
# The following prompts are used to pass each chunk to the LLM (the cheap/fast one)
# to determine if the chunk is useful towards the user query. This is used as part
# of the reranking flow

USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
SECTION_FILTER_PROMPT = f"""
Determine if the following section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
If the section contains ANY useful information, that is good enough, \
it does not need to fully answer the every part of the user query.


Title: {{title}}
{{optional_metadata}}
Reference Section:
```
{{chunk_text}}
```

User Query:
```
{{user_query}}
```

Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
""".strip()


# Use the following for easy viewing of prompts
if __name__ == "__main__":
    print(SECTION_FILTER_PROMPT)
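The deleted filter above ran per chunk with a cheap model. A sketch of one such check, with `llm.invoke` again standing in for the real client and brace-escaping of the chunk text ignored for brevity:

```python
# Hypothetical per-chunk usefulness check; `llm.invoke` is an assumed client.
def chunk_is_useful(llm, title: str, chunk_text: str, user_query: str) -> bool:
    prompt = SECTION_FILTER_PROMPT.format(
        title=title,
        optional_metadata="",  # metadata omitted in this sketch
        chunk_text=chunk_text,
        user_query=user_query,
    )
    # The prompt demands exactly "Yes useful" or "Not useful".
    return USEFUL_PAT.lower() in llm.invoke(prompt).strip().lower()
```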
@@ -1,29 +0,0 @@
# Prompts that aren't part of a particular configurable feature

LANGUAGE_REPHRASE_PROMPT = """
Translate query to {target_language}.
If the query at the end is already in {target_language}, simply repeat the ORIGINAL query back to me, EXACTLY as is with no edits.
If the query below is not in {target_language}, translate it into {target_language}.

Query:
{query}
""".strip()

SLACK_LANGUAGE_REPHRASE_PROMPT = """
As an AI assistant employed by an organization, \
your role is to transform user messages into concise \
inquiries suitable for a Large Language Model (LLM) that \
retrieves pertinent materials within a Retrieval-Augmented \
Generation (RAG) framework. Ensure to reply in the identical \
language as the original request. When faced with multiple \
questions within a single query, distill them into a singular, \
unified question, disregarding any direct mentions.

Query:
{query}
""".strip()


# Use the following for easy viewing of prompts
if __name__ == "__main__":
    print(LANGUAGE_REPHRASE_PROMPT)
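Since LANGUAGE_REPHRASE_PROMPT is a plain (non-f) string, its `{target_language}` and `{query}` fields are ordinary `str.format` slots. A sketch of a call, with `llm.invoke` an assumed client method:

```python
# Illustrative translation pass; `llm.invoke` is an assumed client method.
def rephrase_to_language(llm, query: str, target_language: str) -> str:
    prompt = LANGUAGE_REPHRASE_PROMPT.format(
        target_language=target_language, query=query
    )
    # Per the prompt, an already-translated query comes back verbatim.
    return llm.invoke(prompt).strip()
```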
Some files were not shown because too many files have changed in this diff.