mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-24 17:12:44 +00:00
Compare commits
109 Commits
right-side
...
fix/chat-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d8b672bb8e | ||
|
|
f3e38a7ef7 | ||
|
|
a4c9926eb1 | ||
|
|
8c63831fff | ||
|
|
c48a77c644 | ||
|
|
26d70ab16b | ||
|
|
f8a2f3ac93 | ||
|
|
5186356a26 | ||
|
|
7b826e2a4e | ||
|
|
c175dc8f6a | ||
|
|
aa11813cc0 | ||
|
|
6235f49b49 | ||
|
|
fd6a110794 | ||
|
|
bd42c459d6 | ||
|
|
aede532e63 | ||
|
|
068ac543ad | ||
|
|
30e7a831a5 | ||
|
|
276261c96d | ||
|
|
205f1410e4 | ||
|
|
a93d154c27 | ||
|
|
1361879bd0 | ||
|
|
c58cc320b2 | ||
|
|
461350958a | ||
|
|
50dde0be1a | ||
|
|
199e1df453 | ||
|
|
996b674840 | ||
|
|
5413723ccc | ||
|
|
9660056a51 | ||
|
|
3105177238 | ||
|
|
24bb4bda8b | ||
|
|
9532af4ceb | ||
|
|
0a913f6af5 | ||
|
|
fe30c55199 | ||
|
|
2cf0a65dd3 | ||
|
|
659416f363 | ||
|
|
40aecbc4b9 | ||
|
|
710b39074f | ||
|
|
8fe2f67d38 | ||
|
|
f00aaf9fc0 | ||
|
|
5b2426b002 | ||
|
|
ba6ab0245b | ||
|
|
b64ebb57e1 | ||
|
|
2fcfdbabde | ||
|
|
ea1a2749c1 | ||
|
|
73c4e22588 | ||
|
|
fceaac6e13 | ||
|
|
e8bf45cfd2 | ||
|
|
13ff648fcd | ||
|
|
ae8268afb1 | ||
|
|
b338bd9e97 | ||
|
|
0dcc90a042 | ||
|
|
0f6a6693d3 | ||
|
|
e32cc450b2 | ||
|
|
732fb71edf | ||
|
|
ca3320c0e0 | ||
|
|
d7c554aca7 | ||
|
|
69e5c19695 | ||
|
|
b4ce1c7a97 | ||
|
|
cd64a91154 | ||
|
|
c282cdc096 | ||
|
|
b1de1c59b6 | ||
|
|
64d484039f | ||
|
|
0530095b71 | ||
|
|
23280d5b91 | ||
|
|
229442679c | ||
|
|
95a192fb0f | ||
|
|
6bd96ec906 | ||
|
|
a1ec88269f | ||
|
|
b929518c34 | ||
|
|
479220e774 | ||
|
|
d3e0acf905 | ||
|
|
cbd1a344f2 | ||
|
|
e79264b69b | ||
|
|
1e0a8e9a0e | ||
|
|
b7841a513d | ||
|
|
c779bf722d | ||
|
|
a5aff0d199 | ||
|
|
8ed170b070 | ||
|
|
c890cd4767 | ||
|
|
2b2df18463 | ||
|
|
11cfc92f15 | ||
|
|
c7da99cfd7 | ||
|
|
b384c77863 | ||
|
|
b0f31cd46b | ||
|
|
323eb9bbba | ||
|
|
708e310849 | ||
|
|
c25509e212 | ||
|
|
6af0da41bd | ||
|
|
b94da25d7c | ||
|
|
7d443c1b53 | ||
|
|
d6b7b3c68f | ||
|
|
f5073d331e | ||
|
|
64c9f6a0d5 | ||
|
|
f5a494f790 | ||
|
|
8598e9f25d | ||
|
|
3ef8aecc54 | ||
|
|
eb311c7550 | ||
|
|
13284d9def | ||
|
|
aaa99fcb60 | ||
|
|
5f628da4e8 | ||
|
|
e40f80cfe1 | ||
|
|
ca6ba2cca9 | ||
|
|
98ef5006ff | ||
|
|
dfd168cde9 | ||
|
|
6c7ae243d0 | ||
|
|
c4a2ff2593 | ||
|
|
4b74a6dc76 | ||
|
|
eea5f5b380 | ||
|
|
ae428ba684 |
25
.github/actions/slack-notify/action.yml
vendored
25
.github/actions/slack-notify/action.yml
vendored
@@ -10,6 +10,9 @@ inputs:
|
||||
failed-jobs:
|
||||
description: "Deprecated alias for details"
|
||||
required: false
|
||||
mention:
|
||||
description: "GitHub username to resolve to a Slack @-mention. Replaces {mention} in details."
|
||||
required: false
|
||||
title:
|
||||
description: "Title for the notification"
|
||||
required: false
|
||||
@@ -26,6 +29,7 @@ runs:
|
||||
SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }}
|
||||
DETAILS: ${{ inputs.details }}
|
||||
FAILED_JOBS: ${{ inputs.failed-jobs }}
|
||||
MENTION_USER: ${{ inputs.mention }}
|
||||
TITLE: ${{ inputs.title }}
|
||||
REF_NAME: ${{ inputs.ref-name }}
|
||||
REPO: ${{ github.repository }}
|
||||
@@ -52,6 +56,27 @@ runs:
|
||||
DETAILS="$FAILED_JOBS"
|
||||
fi
|
||||
|
||||
# Resolve {mention} placeholder if a GitHub username was provided.
|
||||
# Looks up the username in user-mappings.json (co-located with this action)
|
||||
# and replaces {mention} with <@SLACK_ID> for a Slack @-mention.
|
||||
# Falls back to the plain GitHub username if not found in the mapping.
|
||||
if [ -n "$MENTION_USER" ]; then
|
||||
MAPPINGS_FILE="${GITHUB_ACTION_PATH}/user-mappings.json"
|
||||
slack_id="$(jq -r --arg gh "$MENTION_USER" 'to_entries[] | select(.value | ascii_downcase == ($gh | ascii_downcase)) | .key' "$MAPPINGS_FILE" 2>/dev/null | head -1)"
|
||||
|
||||
if [ -n "$slack_id" ]; then
|
||||
mention_text="<@${slack_id}>"
|
||||
else
|
||||
mention_text="${MENTION_USER}"
|
||||
fi
|
||||
|
||||
DETAILS="${DETAILS//\{mention\}/$mention_text}"
|
||||
TITLE="${TITLE//\{mention\}/}"
|
||||
else
|
||||
DETAILS="${DETAILS//\{mention\}/}"
|
||||
TITLE="${TITLE//\{mention\}/}"
|
||||
fi
|
||||
|
||||
normalize_multiline() {
|
||||
printf '%s' "$1" | awk 'BEGIN { ORS=""; first=1 } { if (!first) printf "\\n"; printf "%s", $0; first=0 }'
|
||||
}
|
||||
|
||||
18
.github/actions/slack-notify/user-mappings.json
vendored
Normal file
18
.github/actions/slack-notify/user-mappings.json
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"U05SAGZPEA1": "yuhongsun96",
|
||||
"U05SAH6UGUD": "Weves",
|
||||
"U07PWEQB7A5": "evan-onyx",
|
||||
"U07V1SM68KF": "joachim-danswer",
|
||||
"U08JZ9N3QNN": "raunakab",
|
||||
"U08L24NCLJE": "Subash-Mohan",
|
||||
"U090B9M07B2": "wenxi-onyx",
|
||||
"U094RASDP0Q": "duo-onyx",
|
||||
"U096L8ZQ85B": "justin-tahara",
|
||||
"U09AHV8UBQX": "jessicasingh7",
|
||||
"U09KAL5T3C2": "nmgarza5",
|
||||
"U09KPGVQ70R": "acaprau",
|
||||
"U09QR8KTSJH": "rohoswagger",
|
||||
"U09RB4NTXA4": "jmelahman",
|
||||
"U0A6K9VCY6A": "Danelegend",
|
||||
"U0AGC4KH71A": "Bo-Onyx"
|
||||
}
|
||||
30
.github/workflows/deployment.yml
vendored
30
.github/workflows/deployment.yml
vendored
@@ -455,7 +455,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -529,7 +529,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -607,7 +607,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -668,7 +668,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -750,7 +750,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -836,7 +836,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -894,7 +894,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -967,7 +967,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1044,7 +1044,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1105,7 +1105,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1178,7 +1178,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1256,7 +1256,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1317,7 +1317,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1397,7 +1397,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -1480,7 +1480,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
|
||||
@@ -207,7 +207,7 @@ jobs:
|
||||
CHERRY_PICK_PR_URL: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_pr_url }}
|
||||
run: |
|
||||
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
|
||||
details="*Cherry-pick PR opened successfully.*\\n• source PR: ${source_pr_url}"
|
||||
details="*Cherry-pick PR opened successfully.*\\n• author: {mention}\\n• source PR: ${source_pr_url}"
|
||||
if [ -n "${CHERRY_PICK_PR_URL}" ]; then
|
||||
details="${details}\\n• cherry-pick PR: ${CHERRY_PICK_PR_URL}"
|
||||
fi
|
||||
@@ -221,6 +221,7 @@ jobs:
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
mention: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
|
||||
details: ${{ steps.success-summary.outputs.details }}
|
||||
title: "✅ Automated Cherry-Pick PR Opened"
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
@@ -275,20 +276,21 @@ jobs:
|
||||
else
|
||||
failed_job_label="cherry-pick-to-latest-release"
|
||||
fi
|
||||
failed_jobs="• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
details="• author: {mention}\\n• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
|
||||
if [ -n "${MERGE_COMMIT_SHA}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
|
||||
details="${details}\\n• merge SHA: ${MERGE_COMMIT_SHA}"
|
||||
fi
|
||||
if [ -n "${details_excerpt}" ]; then
|
||||
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
|
||||
details="${details}\\n• excerpt: ${details_excerpt}"
|
||||
fi
|
||||
|
||||
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
|
||||
echo "details=${details}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Notify #cherry-pick-prs about cherry-pick failure
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
|
||||
details: ${{ steps.failure-summary.outputs.jobs }}
|
||||
mention: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }}
|
||||
details: ${{ steps.failure-summary.outputs.details }}
|
||||
title: "🚨 Automated Cherry-Pick Failed"
|
||||
ref-name: ${{ github.event.pull_request.base.ref }}
|
||||
|
||||
2
.github/workflows/pr-desktop-build.yml
vendored
2
.github/workflows/pr-desktop-build.yml
vendored
@@ -105,7 +105,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: desktop-build-${{ matrix.platform }}-${{ github.run_id }}
|
||||
path: |
|
||||
|
||||
@@ -174,7 +174,7 @@ jobs:
|
||||
|
||||
- name: Upload Docker logs
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.test-dir }}
|
||||
path: docker-logs/
|
||||
|
||||
4
.github/workflows/pr-golang-tests.yml
vendored
4
.github/workflows/pr-golang-tests.yml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
outputs:
|
||||
modules: ${{ steps.set-modules.outputs.modules }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
persist-credentials: false
|
||||
- id: set-modules
|
||||
@@ -39,7 +39,7 @@ jobs:
|
||||
matrix:
|
||||
modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
|
||||
steps:
|
||||
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
|
||||
|
||||
6
.github/workflows/pr-integration-tests.yml
vendored
6
.github/workflows/pr-integration-tests.yml
vendored
@@ -466,7 +466,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -587,7 +587,7 @@ jobs:
|
||||
|
||||
- name: Upload logs (onyx-lite)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-all-logs-onyx-lite
|
||||
path: ${{ github.workspace }}/docker-compose-onyx-lite.log
|
||||
@@ -725,7 +725,7 @@ jobs:
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
|
||||
14
.github/workflows/pr-playwright-tests.yml
vendored
14
.github/workflows/pr-playwright-tests.yml
vendored
@@ -445,7 +445,7 @@ jobs:
|
||||
run: |
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
- uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
if: always()
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
@@ -454,7 +454,7 @@ jobs:
|
||||
retention-days: 30
|
||||
|
||||
- name: Upload screenshots
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }}
|
||||
@@ -534,7 +534,7 @@ jobs:
|
||||
"s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/"
|
||||
|
||||
- name: Upload visual diff summary
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-summary-${{ matrix.project }}
|
||||
@@ -543,7 +543,7 @@ jobs:
|
||||
retention-days: 5
|
||||
|
||||
- name: Upload visual diff report artifact
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
if: always()
|
||||
with:
|
||||
name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }}
|
||||
@@ -590,7 +590,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -674,7 +674,7 @@ jobs:
|
||||
working-directory: ./web
|
||||
run: npx playwright test --project lite
|
||||
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
- uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-test-results-lite-${{ github.run_id }}
|
||||
@@ -692,7 +692,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-logs-lite-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -122,7 +122,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
@@ -319,7 +319,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
|
||||
with:
|
||||
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
|
||||
path: |
|
||||
|
||||
6
.github/workflows/sandbox-deployment.yml
vendored
6
.github/workflows/sandbox-deployment.yml
vendored
@@ -125,7 +125,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -195,7 +195,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
@@ -268,7 +268,7 @@ jobs:
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
|
||||
uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
flavor: |
|
||||
|
||||
279
AGENTS.md
279
AGENTS.md
@@ -167,284 +167,7 @@ web/
|
||||
|
||||
## Frontend Standards
|
||||
|
||||
### 1. Import Standards
|
||||
|
||||
**Always use absolute imports with the `@` prefix.**
|
||||
|
||||
**Reason:** Moving files around becomes easier since you don't also have to update those import statements. This makes modifications to the codebase much nicer.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useAuth } from "@/hooks/useAuth";
|
||||
import { Text } from "@/refresh-components/texts/Text";
|
||||
|
||||
// ❌ Bad
|
||||
import { Button } from "../../../components/ui/button";
|
||||
import { useAuth } from "./hooks/useAuth";
|
||||
```
|
||||
|
||||
### 2. React Component Functions
|
||||
|
||||
**Prefer regular functions over arrow functions for React components.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
function UserProfile({ userId }: UserProfileProps) {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
const UserProfile = ({ userId }: UserProfileProps) => {
|
||||
return <div>User Profile</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Props Interface Extraction
|
||||
|
||||
**Extract prop types into their own interface definitions.**
|
||||
|
||||
**Reason:** Functions just become easier to read.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
interface UserCardProps {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}
|
||||
|
||||
function UserCard({ user, showActions = false, onEdit }: UserCardProps) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({
|
||||
user,
|
||||
showActions = false,
|
||||
onEdit
|
||||
}: {
|
||||
user: User
|
||||
showActions?: boolean
|
||||
onEdit?: (userId: string) => void
|
||||
}) {
|
||||
return <div>User Card</div>
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Spacing Guidelines
|
||||
|
||||
**Prefer padding over margins for spacing.**
|
||||
|
||||
**Reason:** We want to consolidate usage to paddings instead of margins.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
<div className="p-4 space-y-2">
|
||||
<div className="p-2">Content</div>
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className="m-4 space-y-2">
|
||||
<div className="m-2">Content</div>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 5. Tailwind Dark Mode
|
||||
|
||||
**Strictly forbid using the `dark:` modifier in Tailwind classes, except for logo icon handling.**
|
||||
|
||||
**Reason:** The `colors.css` file already, VERY CAREFULLY, defines what the exact opposite colour of each light-mode colour is. Overriding this behaviour is VERY bad and will lead to horrible UI breakages.
|
||||
|
||||
**Exception:** The `createLogoIcon` helper in `web/src/components/icons/icons.tsx` uses `dark:` modifiers (`dark:invert`, `dark:hidden`, `dark:block`) to handle third-party logo icons that cannot automatically adapt through `colors.css`. This is the ONLY acceptable use of dark mode modifiers.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Standard components use `tailwind-themes/tailwind.config.js` / `src/app/css/colors.css`
|
||||
<div className="bg-background-neutral-03 text-text-02">
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ✅ Good - Logo icons with dark mode handling via createLogoIcon
|
||||
export const GithubIcon = createLogoIcon(githubLightIcon, {
|
||||
monochromatic: true, // Will apply dark:invert internally
|
||||
});
|
||||
|
||||
export const GitbookIcon = createLogoIcon(gitbookLightIcon, {
|
||||
darkSrc: gitbookDarkIcon, // Will use dark:hidden/dark:block internally
|
||||
});
|
||||
|
||||
// ❌ Bad - Manual dark mode overrides
|
||||
<div className="bg-white dark:bg-black text-black dark:text-white">
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 6. Class Name Utilities
|
||||
|
||||
**Use the `cn` utility instead of raw string formatting for classNames.**
|
||||
|
||||
**Reason:** `cn`s are easier to read. They also allow for more complex types (i.e., string-arrays) to get formatted properly (it flattens each element in that string array down). As a result, it can allow things such as conditionals (i.e., `myCondition && "some-tailwind-class"`, which evaluates to `false` when `myCondition` is `false`) to get filtered out.
|
||||
|
||||
```typescript
|
||||
import { cn } from '@/lib/utils'
|
||||
|
||||
// ✅ Good
|
||||
<div className={cn(
|
||||
'base-class',
|
||||
isActive && 'active-class',
|
||||
className
|
||||
)}>
|
||||
Content
|
||||
</div>
|
||||
|
||||
// ❌ Bad
|
||||
<div className={`base-class ${isActive ? 'active-class' : ''} ${className}`}>
|
||||
Content
|
||||
</div>
|
||||
```
|
||||
|
||||
### 7. Custom Hooks Organization
|
||||
|
||||
**Follow a "hook-per-file" layout. Each hook should live in its own file within `web/src/hooks`.**
|
||||
|
||||
**Reason:** This is just a layout preference. Keeps code clean.
|
||||
|
||||
```typescript
|
||||
// web/src/hooks/useUserData.ts
|
||||
export function useUserData(userId: string) {
|
||||
// hook implementation
|
||||
}
|
||||
|
||||
// web/src/hooks/useLocalStorage.ts
|
||||
export function useLocalStorage<T>(key: string, initialValue: T) {
|
||||
// hook implementation
|
||||
}
|
||||
```
|
||||
|
||||
### 8. Icon Usage
|
||||
|
||||
**ONLY use icons from the `web/src/icons` directory. Do NOT use icons from `react-icons`, `lucide`, or other external libraries.**
|
||||
|
||||
**Reason:** We have a very carefully curated selection of icons that match our Onyx guidelines. We do NOT want to muddy those up with different aesthetic stylings.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import SvgX from "@/icons/x";
|
||||
import SvgMoreHorizontal from "@/icons/more-horizontal";
|
||||
|
||||
// ❌ Bad
|
||||
import { User } from "lucide-react";
|
||||
import { FiSearch } from "react-icons/fi";
|
||||
```
|
||||
|
||||
**Missing Icons**: If an icon is needed but doesn't exist in the `web/src/icons` directory, import it from Figma using the Figma MCP tool and add it to the icons directory.
|
||||
If you need help with this step, reach out to `raunak@onyx.app`.
|
||||
|
||||
### 9. Text Rendering
|
||||
|
||||
**Prefer using the `refresh-components/texts/Text` component for all text rendering. Avoid "naked" text nodes.**
|
||||
|
||||
**Reason:** The `Text` component is fully compliant with the stylings provided in Figma. It provides easy utilities to specify the text-colour and font-size in the form of flags. Super duper easy.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import { Text } from '@/refresh-components/texts/Text'
|
||||
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<Text
|
||||
{/* The `text03` flag makes the text it renders to be coloured the 3rd-scale grey */}
|
||||
text03
|
||||
{/* The `mainAction` flag makes the text it renders to be "main-action" font + line-height + weightage, as described in the Figma */}
|
||||
mainAction
|
||||
>
|
||||
{name}
|
||||
</Text>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function UserCard({ name }: { name: string }) {
|
||||
return (
|
||||
<div>
|
||||
<h2>{name}</h2>
|
||||
<p>User details</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 10. Component Usage
|
||||
|
||||
**Heavily avoid raw HTML input components. Always use components from the `web/src/refresh-components` or `web/lib/opal/src` directory.**
|
||||
|
||||
**Reason:** We've put in a lot of effort to unify the components that are rendered in the Onyx app. Using raw components breaks the entire UI of the application, and leaves it in a muddier state than before.
|
||||
|
||||
```typescript
|
||||
// ✅ Good
|
||||
import Button from '@/refresh-components/buttons/Button'
|
||||
import InputTypeIn from '@/refresh-components/inputs/InputTypeIn'
|
||||
import SvgPlusCircle from '@/icons/plus-circle'
|
||||
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<InputTypeIn placeholder="Search..." />
|
||||
<Button type="submit" leftIcon={SvgPlusCircle}>Submit</Button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
// ❌ Bad
|
||||
function ContactForm() {
|
||||
return (
|
||||
<form>
|
||||
<input placeholder="Name" />
|
||||
<textarea placeholder="Message" />
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### 11. Colors
|
||||
|
||||
**Always use custom overrides for colors and borders rather than built in Tailwind CSS colors. These overrides live in `web/tailwind-themes/tailwind.config.js`.**
|
||||
|
||||
**Reason:** Our custom color system uses CSS variables that automatically handle dark mode and maintain design consistency across the app. Standard Tailwind colors bypass this system.
|
||||
|
||||
**Available color categories:**
|
||||
|
||||
- **Text:** `text-01` through `text-05`, `text-inverted-XX`
|
||||
- **Backgrounds:** `background-neutral-XX`, `background-tint-XX` (and inverted variants)
|
||||
- **Borders:** `border-01` through `border-05`, `border-inverted-XX`
|
||||
- **Actions:** `action-link-XX`, `action-danger-XX`
|
||||
- **Status:** `status-info-XX`, `status-success-XX`, `status-warning-XX`, `status-error-XX`
|
||||
- **Theme:** `theme-primary-XX`, `theme-red-XX`, `theme-blue-XX`, etc.
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Use custom Onyx color classes
|
||||
<div className="bg-background-neutral-01 border border-border-02" />
|
||||
<div className="bg-background-tint-02 border border-border-01" />
|
||||
<div className="bg-status-success-01" />
|
||||
<div className="bg-action-link-01" />
|
||||
<div className="bg-theme-primary-05" />
|
||||
|
||||
// ❌ Bad - Do NOT use standard Tailwind colors
|
||||
<div className="bg-gray-100 border border-gray-300 text-gray-600" />
|
||||
<div className="bg-white border border-slate-200" />
|
||||
<div className="bg-green-100 text-green-700" />
|
||||
<div className="bg-blue-100 text-blue-600" />
|
||||
<div className="bg-indigo-500" />
|
||||
```
|
||||
|
||||
### 12. Data Fetching
|
||||
|
||||
**Prefer using `useSWR` for data fetching. Data should generally be fetched on the client side. Components that need data should display a loader / placeholder while waiting for that data. Prefer loading data within the component that needs it rather than at the top level and passing it down.**
|
||||
|
||||
**Reason:** Client side fetching allows us to load the skeleton of the page without waiting for data to load, leading to a snappier UX. Loading data where needed reduces dependencies between a component and its parent component(s).
|
||||
Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`.
|
||||
|
||||
## Database & Migrations
|
||||
|
||||
|
||||
@@ -47,6 +47,8 @@ RUN apt-get update && \
|
||||
gcc \
|
||||
nano \
|
||||
vim \
|
||||
# Install procps so kubernetes exec sessions can use ps aux for debugging
|
||||
procps \
|
||||
libjemalloc2 \
|
||||
&& \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
"""add_hook_and_hook_execution_log_tables
|
||||
|
||||
Revision ID: 689433b0d8de
|
||||
Revises: 93a2e195e25c
|
||||
Create Date: 2026-03-13 11:25:06.547474
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "689433b0d8de"
|
||||
down_revision = "93a2e195e25c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"hook",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"hook_point",
|
||||
sa.Enum("document_ingestion", "query_processing", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("endpoint_url", sa.Text(), nullable=True),
|
||||
sa.Column("api_key", sa.LargeBinary(), nullable=True),
|
||||
sa.Column("is_reachable", sa.Boolean(), nullable=True),
|
||||
sa.Column(
|
||||
"fail_strategy",
|
||||
sa.Enum("hard", "soft", native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("timeout_seconds", sa.Float(), nullable=False),
|
||||
sa.Column(
|
||||
"is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column(
|
||||
"deleted", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("creator_id", PGUUID(as_uuid=True), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["creator_id"], ["user.id"], ondelete="SET NULL"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_hook_one_non_deleted_per_point",
|
||||
"hook",
|
||||
["hook_point"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("deleted = false"),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"hook_execution_log",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("hook_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"is_success",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_message", sa.Text(), nullable=True),
|
||||
sa.Column("status_code", sa.Integer(), nullable=True),
|
||||
sa.Column("duration_ms", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(["hook_id"], ["hook.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("ix_hook_execution_log_hook_id", "hook_execution_log", ["hook_id"])
|
||||
op.create_index(
|
||||
"ix_hook_execution_log_created_at", "hook_execution_log", ["created_at"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_hook_execution_log_created_at", table_name="hook_execution_log")
|
||||
op.drop_index("ix_hook_execution_log_hook_id", table_name="hook_execution_log")
|
||||
op.drop_table("hook_execution_log")
|
||||
|
||||
op.drop_index("ix_hook_one_non_deleted_per_point", table_name="hook")
|
||||
op.drop_table("hook")
|
||||
@@ -25,9 +25,6 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
@@ -58,7 +55,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
@@ -74,9 +71,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
num_available_tenants = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
num_minimum_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
if num_available_tenants < num_minimum_available_tenants:
|
||||
@@ -98,7 +93,12 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
lock_check.release()
|
||||
try:
|
||||
lock_check.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release check lock (likely expired), continuing"
|
||||
)
|
||||
|
||||
|
||||
def pre_provision_tenant() -> None:
|
||||
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
|
||||
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
@@ -185,4 +185,9 @@ def pre_provision_tenant() -> None:
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
lock_provision.release()
|
||||
try:
|
||||
lock_provision.release()
|
||||
except Exception:
|
||||
task_logger.warning(
|
||||
"Could not release provision lock (likely expired), continuing"
|
||||
)
|
||||
|
||||
@@ -118,9 +118,7 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", "[]"))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
# The posthog client does not accept empty API keys or hosts however it fails silently
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY")
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
POSTHOG_DEBUG_LOGS_ENABLED = (
|
||||
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
|
||||
|
||||
@@ -34,6 +34,9 @@ class PostHogFeatureFlagProvider(FeatureFlagProvider):
|
||||
Returns:
|
||||
True if the feature is enabled for the user, False otherwise.
|
||||
"""
|
||||
if not posthog:
|
||||
return False
|
||||
|
||||
try:
|
||||
posthog.set(
|
||||
distinct_id=user_id,
|
||||
|
||||
@@ -29,7 +29,6 @@ from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS
|
||||
from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
@@ -59,7 +58,6 @@ from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
@@ -71,7 +69,9 @@ logger = setup_logger()
|
||||
|
||||
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
email: str,
|
||||
referral_source: str | None = None,
|
||||
request: Request | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get existing tenant ID for an email or create a new tenant if none exists.
|
||||
@@ -693,12 +693,6 @@ async def assign_tenant_to_user(
|
||||
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=email,
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
@@ -9,6 +9,7 @@ from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -18,12 +19,19 @@ def posthog_on_error(error: Any, items: Any) -> None:
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
posthog: Posthog | None = None
|
||||
if POSTHOG_API_KEY:
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
elif MULTI_TENANT:
|
||||
logger.warning(
|
||||
"POSTHOG_API_KEY is not set but MULTI_TENANT is enabled — "
|
||||
"PostHog telemetry and feature flags will be disabled"
|
||||
)
|
||||
|
||||
# For cross referencing between cloud and www Onyx sites
|
||||
# NOTE: These clients are separate because they are separate posthog projects.
|
||||
@@ -60,7 +68,7 @@ def capture_and_sync_with_alternate_posthog(
|
||||
logger.error(f"Error capturing marketing posthog event: {e}")
|
||||
|
||||
try:
|
||||
if cloud_user_id := props.get("onyx_cloud_user_id"):
|
||||
if posthog and (cloud_user_id := props.get("onyx_cloud_user_id")):
|
||||
cloud_props = props.copy()
|
||||
cloud_props.pop("onyx_cloud_user_id", None)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.utils.posthog_client import posthog
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -5,12 +7,27 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
distinct_id: str, event: str, properties: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Capture and send an event to PostHog, flushing immediately."""
|
||||
if not posthog:
|
||||
return
|
||||
|
||||
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
|
||||
try:
|
||||
posthog.capture(distinct_id, event, properties)
|
||||
posthog.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing PostHog event: {e}")
|
||||
|
||||
|
||||
def identify_user(distinct_id: str, properties: dict[str, Any] | None = None) -> None:
|
||||
"""Create/update a PostHog person profile, flushing immediately."""
|
||||
if not posthog:
|
||||
return
|
||||
|
||||
try:
|
||||
posthog.identify(distinct_id, properties)
|
||||
posthog.flush()
|
||||
except Exception as e:
|
||||
logger.error(f"Error identifying PostHog user: {e}")
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -134,6 +135,7 @@ from onyx.redis.redis_pool import retrieve_ws_token_data
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_identify
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
@@ -792,6 +794,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except Exception:
|
||||
logger.exception("Error deleting anonymous user cookie")
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
mt_cloud_identify(
|
||||
distinct_id=str(user.id),
|
||||
properties={"email": user.email, "tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
@@ -810,12 +818,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_count = await get_user_count()
|
||||
logger.debug(f"Current tenant user count: {user_count}")
|
||||
|
||||
# Ensure a PostHog person profile exists for this user.
|
||||
mt_cloud_identify(
|
||||
distinct_id=str(user.id),
|
||||
properties={"email": user.email, "tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.USER_SIGNED_UP,
|
||||
)
|
||||
|
||||
if user_count == 1:
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -1652,6 +1673,33 @@ async def _get_user_from_token_data(token_data: dict) -> User | None:
|
||||
return user
|
||||
|
||||
|
||||
_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1"})
|
||||
|
||||
|
||||
def _is_same_origin(actual: str, expected: str) -> bool:
|
||||
"""Compare two origins for the WebSocket CSWSH check.
|
||||
|
||||
Scheme and hostname must match exactly. Port must also match, except
|
||||
when the hostname is a loopback address (localhost / 127.0.0.1 / ::1),
|
||||
where port is ignored. On loopback, all ports belong to the same
|
||||
operator, so port differences carry no security significance — the
|
||||
CSWSH threat is remote origins, not local ones.
|
||||
"""
|
||||
a = urlparse(actual.rstrip("/"))
|
||||
e = urlparse(expected.rstrip("/"))
|
||||
|
||||
if a.scheme != e.scheme or a.hostname != e.hostname:
|
||||
return False
|
||||
|
||||
if a.hostname in _LOOPBACK_HOSTNAMES:
|
||||
return True
|
||||
|
||||
actual_port = a.port or (443 if a.scheme == "https" else 80)
|
||||
expected_port = e.port or (443 if e.scheme == "https" else 80)
|
||||
|
||||
return actual_port == expected_port
|
||||
|
||||
|
||||
async def current_user_from_websocket(
|
||||
websocket: WebSocket,
|
||||
token: str = Query(..., description="WebSocket authentication token"),
|
||||
@@ -1671,19 +1719,15 @@ async def current_user_from_websocket(
|
||||
|
||||
This applies the same auth checks as current_user() for HTTP endpoints.
|
||||
"""
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH)
|
||||
# Browsers always send Origin on WebSocket connections
|
||||
# Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH).
|
||||
# Browsers always send Origin on WebSocket connections.
|
||||
origin = websocket.headers.get("origin")
|
||||
expected_origin = WEB_DOMAIN.rstrip("/")
|
||||
if not origin:
|
||||
logger.warning("WS auth: missing Origin header")
|
||||
raise BasicAuthenticationError(detail="Access denied. Missing origin.")
|
||||
|
||||
actual_origin = origin.rstrip("/")
|
||||
if actual_origin != expected_origin:
|
||||
logger.warning(
|
||||
f"WS auth: origin mismatch. Expected {expected_origin}, got {actual_origin}"
|
||||
)
|
||||
if not _is_same_origin(origin, WEB_DOMAIN):
|
||||
logger.warning(f"WS auth: origin mismatch. Expected {WEB_DOMAIN}, got {origin}")
|
||||
raise BasicAuthenticationError(detail="Access denied. Invalid origin.")
|
||||
|
||||
# Validate WS token in Redis (single-use, deleted after retrieval)
|
||||
|
||||
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.hooks",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
@@ -361,6 +362,19 @@ if not MULTI_TENANT:
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
if HOOKS_AVAILABLE:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "hook-execution-log-cleanup",
|
||||
"task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
"schedule": timedelta(days=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
|
||||
@@ -29,6 +29,8 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.factory import ConnectorMissingException
|
||||
from onyx.connectors.factory import identify_connector_class
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import HierarchyConnector
|
||||
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
|
||||
@@ -55,6 +57,26 @@ logger = setup_logger()
|
||||
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60
|
||||
|
||||
|
||||
def _connector_supports_hierarchy_fetching(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> bool:
|
||||
"""Return True only for connectors whose class implements HierarchyConnector."""
|
||||
try:
|
||||
connector_class = identify_connector_class(
|
||||
cc_pair.connector.source,
|
||||
)
|
||||
except ConnectorMissingException as e:
|
||||
task_logger.warning(
|
||||
"Skipping hierarchy fetching enqueue for source=%s input_type=%s: %s",
|
||||
cc_pair.connector.source,
|
||||
cc_pair.connector.input_type,
|
||||
str(e),
|
||||
)
|
||||
return False
|
||||
|
||||
return issubclass(connector_class, HierarchyConnector)
|
||||
|
||||
|
||||
def _is_hierarchy_fetching_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if hierarchy fetching is due for this connector.
|
||||
|
||||
@@ -186,7 +208,10 @@ def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
if not cc_pair or not _is_hierarchy_fetching_due(cc_pair):
|
||||
if not cc_pair or not _connector_supports_hierarchy_fetching(cc_pair):
|
||||
continue
|
||||
|
||||
if not _is_hierarchy_fetching_due(cc_pair):
|
||||
continue
|
||||
|
||||
task_id = _try_creating_hierarchy_fetching_task(
|
||||
|
||||
35
backend/onyx/background/celery/tasks/hooks/tasks.py
Normal file
35
backend/onyx/background/celery/tasks/hooks/tasks.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from celery import shared_task
|
||||
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import cleanup_old_execution_logs__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_HOOK_EXECUTION_LOG_RETENTION_DAYS: int = 30
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
)
|
||||
def hook_execution_log_cleanup_task(*, tenant_id: str) -> None: # noqa: ARG001
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
deleted: int = cleanup_old_execution_logs__no_commit(
|
||||
db_session=db_session,
|
||||
max_age_days=_HOOK_EXECUTION_LOG_RETENTION_DAYS,
|
||||
)
|
||||
db_session.commit()
|
||||
if deleted:
|
||||
logger.info(
|
||||
f"Deleted {deleted} hook execution log(s) older than "
|
||||
f"{_HOOK_EXECUTION_LOG_RETENTION_DAYS} days."
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to clean up hook execution logs")
|
||||
raise
|
||||
@@ -24,6 +24,7 @@ from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
@@ -33,6 +34,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
@@ -91,6 +93,17 @@ def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_delete_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a delete_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
|
||||
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
|
||||
return celery_get_queue_length(
|
||||
@@ -546,7 +559,23 @@ def process_single_user_file(
|
||||
ignore_result=True,
|
||||
)
|
||||
def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with DELETING status and enqueue per-file tasks."""
|
||||
"""Scan for user files with DELETING status and enqueue per-file tasks.
|
||||
|
||||
Three mechanisms prevent queue runaway (mirrors check_user_file_processing):
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still DELETING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
"""
|
||||
task_logger.info("check_for_user_file_delete - Starting")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
@@ -555,8 +584,23 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
if not lock.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
# NOTE: must use the broker's Redis client (not redis_client) because
|
||||
# Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys.
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery)
|
||||
if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_for_user_file_delete - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -568,23 +612,40 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
.all()
|
||||
)
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_delete_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"check_for_user_file_delete - Error enqueuing deletes - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_for_user_file_delete - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -602,6 +663,9 @@ def delete_user_file_impl(
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
# Clear the queued guard so the beat can re-enqueue if deletion fails
|
||||
# and the file remains in DELETING status.
|
||||
redis_client.delete(_user_file_delete_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
|
||||
4
backend/onyx/cache/postgres_backend.py
vendored
4
backend/onyx/cache/postgres_backend.py
vendored
@@ -297,7 +297,9 @@ class PostgresCacheBackend(CacheBackend):
|
||||
|
||||
def _lock_id_for(self, name: str) -> int:
|
||||
"""Map *name* to a 64-bit signed int for ``pg_advisory_lock``."""
|
||||
h = hashlib.md5(f"{self._tenant_id}:{name}".encode()).digest()
|
||||
h = hashlib.md5(
|
||||
f"{self._tenant_id}:{name}".encode(), usedforsecurity=False
|
||||
).digest()
|
||||
return struct.unpack("q", h[:8])[0]
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,8 @@ from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.utils import plaintext_file_name_for_id
|
||||
from onyx.file_store.utils import store_plaintext
|
||||
from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
@@ -289,6 +291,33 @@ def process_kg_commands(
|
||||
raise KGException("KG setup done")
|
||||
|
||||
|
||||
def _get_or_extract_plaintext(
|
||||
file_id: str,
|
||||
extract_fn: Callable[[], str],
|
||||
) -> str:
|
||||
"""Load cached plaintext for a file, or extract and store it.
|
||||
|
||||
Tries to read pre-stored plaintext from the file store. On a miss,
|
||||
calls extract_fn to produce the text, then stores the result so
|
||||
future calls skip the expensive extraction.
|
||||
"""
|
||||
file_store = get_default_file_store()
|
||||
plaintext_key = plaintext_file_name_for_id(file_id)
|
||||
|
||||
# Try cached plaintext first.
|
||||
try:
|
||||
plaintext_io = file_store.read_file(plaintext_key, mode="b")
|
||||
return plaintext_io.read().decode("utf-8")
|
||||
except Exception:
|
||||
logger.exception(f"Error when reading file, id={file_id}")
|
||||
|
||||
# Cache miss — extract and store.
|
||||
content_text = extract_fn()
|
||||
if content_text:
|
||||
store_plaintext(file_id, content_text)
|
||||
return content_text
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
@@ -303,12 +332,23 @@ def load_chat_file(
|
||||
file_type = ChatFileType(file_descriptor["type"])
|
||||
|
||||
if file_type.is_text_file():
|
||||
try:
|
||||
content_text = extract_file_text(
|
||||
file_id = file_descriptor["id"]
|
||||
|
||||
def _extract() -> str:
|
||||
return extract_file_text(
|
||||
file=file_io,
|
||||
file_name=file_descriptor.get("name") or "",
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
|
||||
# Use the user_file_id as cache key when available (matches what
|
||||
# the celery indexing worker stores), otherwise fall back to the
|
||||
# file store id (covers code-interpreter-generated files, etc.).
|
||||
user_file_id_str = file_descriptor.get("user_file_id")
|
||||
cache_key = user_file_id_str or file_id
|
||||
|
||||
try:
|
||||
content_text = _get_or_extract_plaintext(cache_key, _extract)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
|
||||
|
||||
@@ -36,9 +36,11 @@ from onyx.db.memory import add_memory
|
||||
from onyx.db.memory import update_memory_at_index
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.utils import is_true_openai_model
|
||||
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
|
||||
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
@@ -72,6 +74,70 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class EmptyLLMResponseError(RuntimeError):
|
||||
"""Raised when the streamed LLM response completes without a usable answer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
model: str,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
client_error_msg: str,
|
||||
error_code: str = "EMPTY_LLM_RESPONSE",
|
||||
is_retryable: bool = True,
|
||||
) -> None:
|
||||
super().__init__(client_error_msg)
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.tool_choice = tool_choice
|
||||
self.client_error_msg = client_error_msg
|
||||
self.error_code = error_code
|
||||
self.is_retryable = is_retryable
|
||||
|
||||
|
||||
def _build_empty_llm_response_error(
|
||||
llm: LLM,
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
) -> EmptyLLMResponseError:
|
||||
provider = llm.config.model_provider
|
||||
model = llm.config.model_name
|
||||
|
||||
# OpenAI quota exhaustion has reached us as a streamed "stop" with zero content.
|
||||
# When the stream is completely empty and there is no reasoning/tool output, surface
|
||||
# the likely account-level cause instead of a generic tool-calling error.
|
||||
if (
|
||||
not llm_step_result.reasoning
|
||||
and provider == LlmProviderNames.OPENAI
|
||||
and is_true_openai_model(provider, model)
|
||||
):
|
||||
return EmptyLLMResponseError(
|
||||
provider=provider,
|
||||
model=model,
|
||||
tool_choice=tool_choice,
|
||||
client_error_msg=(
|
||||
"The selected OpenAI model returned an empty streamed response "
|
||||
"before producing any tokens. This commonly happens when the API "
|
||||
"key or project has no remaining quota or billing is not enabled. "
|
||||
"Verify quota and billing for this key and try again."
|
||||
),
|
||||
error_code="BUDGET_EXCEEDED",
|
||||
is_retryable=False,
|
||||
)
|
||||
|
||||
return EmptyLLMResponseError(
|
||||
provider=provider,
|
||||
model=model,
|
||||
tool_choice=tool_choice,
|
||||
client_error_msg=(
|
||||
"The selected model returned no final answer before the stream "
|
||||
"completed. No text or tool calls were received from the upstream "
|
||||
"provider."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
"""Detect XML-style marshaled tool calls emitted as plain text."""
|
||||
if not text:
|
||||
@@ -613,7 +679,12 @@ def run_llm_loop(
|
||||
)
|
||||
citation_processor.update_citation_mapping(project_citation_mapping)
|
||||
|
||||
llm_step_result: LlmStepResult | None = None
|
||||
llm_step_result = LlmStepResult(
|
||||
reasoning=None,
|
||||
answer=None,
|
||||
tool_calls=None,
|
||||
raw_answer=None,
|
||||
)
|
||||
|
||||
# Pass the total budget to construct_message_history, which will handle token allocation
|
||||
available_tokens = llm.config.max_input_tokens
|
||||
@@ -1084,12 +1155,18 @@ def run_llm_loop(
|
||||
# As long as 1 tool with citeable documents is called at any point, we ask the LLM to try to cite
|
||||
should_cite_documents = True
|
||||
|
||||
if not llm_step_result or not llm_step_result.answer:
|
||||
if not llm_step_result.answer and not llm_step_result.tool_calls:
|
||||
raise _build_empty_llm_response_error(
|
||||
llm=llm,
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
if not llm_step_result.answer:
|
||||
raise RuntimeError(
|
||||
"The LLM did not return an answer. "
|
||||
"Typically this is an issue with LLMs that do not support tool calling natively, "
|
||||
"or the model serving API is not configured correctly. "
|
||||
"This may also happen with models that are lower quality outputting invalid tool calls."
|
||||
"The LLM did not return a final answer after tool execution. "
|
||||
"Typically this indicates invalid tool-call output, a model/provider mismatch, "
|
||||
"or serving API misconfiguration."
|
||||
)
|
||||
|
||||
emitter.emit(
|
||||
|
||||
@@ -1013,6 +1013,10 @@ def run_llm_step_pkt_generator(
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
accumulated_raw_answer = ""
|
||||
stream_chunk_count = 0
|
||||
actionable_chunk_count = 0
|
||||
empty_chunk_count = 0
|
||||
finish_reasons: set[str] = set()
|
||||
xml_tool_call_content_filter = _XmlToolCallContentFilter()
|
||||
|
||||
processor_state: Any = None
|
||||
@@ -1145,6 +1149,7 @@ def run_llm_step_pkt_generator(
|
||||
user_identity=user_identity,
|
||||
timeout_override=timeout_override,
|
||||
):
|
||||
stream_chunk_count += 1
|
||||
if packet.usage:
|
||||
usage = packet.usage
|
||||
span_generation.span_data.usage = {
|
||||
@@ -1154,16 +1159,21 @@ def run_llm_step_pkt_generator(
|
||||
"cache_creation_input_tokens": usage.cache_creation_input_tokens,
|
||||
}
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
finish_reason = packet.choice.finish_reason
|
||||
if finish_reason:
|
||||
finish_reasons.add(str(finish_reason))
|
||||
delta = packet.choice.delta
|
||||
|
||||
# Weird behavior from some model providers, just log and ignore for now
|
||||
if (
|
||||
delta.content is None
|
||||
not delta.content
|
||||
and delta.reasoning_content is None
|
||||
and delta.tool_calls is None
|
||||
and not delta.tool_calls
|
||||
):
|
||||
empty_chunk_count += 1
|
||||
logger.warning(
|
||||
f"LLM packet is empty (no contents, reasoning or tool calls). Skipping: {packet}"
|
||||
"LLM packet is empty (no content, reasoning, or tool calls). "
|
||||
f"finish_reason={finish_reason}. Skipping: {packet}"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -1172,6 +1182,8 @@ def run_llm_step_pkt_generator(
|
||||
time.monotonic() - stream_start_time
|
||||
)
|
||||
first_action_recorded = True
|
||||
if _delta_has_action(delta):
|
||||
actionable_chunk_count += 1
|
||||
|
||||
if custom_token_processor:
|
||||
# The custom token processor can modify the deltas for specific custom logic
|
||||
@@ -1307,6 +1319,15 @@ def run_llm_step_pkt_generator(
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
|
||||
if actionable_chunk_count == 0:
|
||||
logger.warning(
|
||||
"LLM stream completed with no actionable deltas. "
|
||||
f"chunks={stream_chunk_count}, empty_chunks={empty_chunk_count}, "
|
||||
f"finish_reasons={sorted(finish_reasons)}, "
|
||||
f"provider={llm.config.model_provider}, model={llm.config.model_name}, "
|
||||
f"tool_choice={tool_choice}, tools_sent={len(tool_definitions)}"
|
||||
)
|
||||
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=accumulated_reasoning if accumulated_reasoning else None,
|
||||
|
||||
@@ -177,8 +177,8 @@ class ExtractedContextFiles(BaseModel):
|
||||
class SearchParams(BaseModel):
|
||||
"""Resolved search filter IDs and search-tool usage for a chat turn."""
|
||||
|
||||
search_project_id: int | None
|
||||
search_persona_id: int | None
|
||||
project_id_filter: int | None
|
||||
persona_id_filter: int | None
|
||||
search_usage: SearchToolUsage
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.chat.compression import compress_chat_history
|
||||
from onyx.chat.compression import find_summary_for_branch
|
||||
from onyx.chat.compression import get_compression_params
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.chat.llm_loop import EmptyLLMResponseError
|
||||
from onyx.chat.llm_loop import run_llm_loop
|
||||
from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
@@ -398,13 +399,13 @@ def determine_search_params(
|
||||
"""
|
||||
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
|
||||
|
||||
search_project_id: int | None = None
|
||||
search_persona_id: int | None = None
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
if extracted_context_files.use_as_search_filter:
|
||||
if is_custom_persona:
|
||||
search_persona_id = persona_id
|
||||
persona_id_filter = persona_id
|
||||
else:
|
||||
search_project_id = project_id
|
||||
project_id_filter = project_id
|
||||
|
||||
search_usage = SearchToolUsage.AUTO
|
||||
if not is_custom_persona and project_id:
|
||||
@@ -417,8 +418,8 @@ def determine_search_params(
|
||||
search_usage = SearchToolUsage.DISABLED
|
||||
|
||||
return SearchParams(
|
||||
search_project_id=search_project_id,
|
||||
search_persona_id=search_persona_id,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
search_usage=search_usage,
|
||||
)
|
||||
|
||||
@@ -473,11 +474,18 @@ def handle_stream_message_objects(
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
@@ -490,13 +498,13 @@ def handle_stream_message_objects(
|
||||
# Milestone tracking, most devs using the API don't need to understand this
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email if not user.is_anonymous else tenant_id,
|
||||
distinct_id=str(user.id) if not user.is_anonymous else tenant_id,
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email if not user.is_anonymous else tenant_id,
|
||||
distinct_id=str(user.id) if not user.is_anonymous else tenant_id,
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
@@ -710,8 +718,8 @@ def handle_stream_message_objects(
|
||||
llm=llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
user_selected_filters=new_msg_req.internal_search_filters,
|
||||
project_id=search_params.search_project_id,
|
||||
persona_id=search_params.search_persona_id,
|
||||
project_id_filter=search_params.project_id_filter,
|
||||
persona_id_filter=search_params.persona_id_filter,
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
@@ -925,9 +933,28 @@ def handle_stream_message_objects(
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except EmptyLLMResponseError as e:
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
logger.warning(
|
||||
"LLM returned an empty response "
|
||||
f"(provider={e.provider}, model={e.model}, tool_choice={e.tool_choice})"
|
||||
)
|
||||
|
||||
yield StreamingError(
|
||||
error=e.client_error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code=e.error_code,
|
||||
is_retryable=e.is_retryable,
|
||||
details={
|
||||
"model": e.model,
|
||||
"provider": e.provider,
|
||||
"tool_choice": e.tool_choice.value,
|
||||
},
|
||||
)
|
||||
db_session.rollback()
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to process chat message due to {e}")
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if llm:
|
||||
@@ -1046,10 +1073,46 @@ def llm_loop_completion_handle(
|
||||
)
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
|
||||
_CITATION_LINK_START_PATTERN = re.compile(r"\s*\[\[\d+\]\]\(")
|
||||
|
||||
return re.sub(pattern, "", answer)
|
||||
|
||||
def _find_markdown_link_end(text: str, destination_start: int) -> int | None:
|
||||
depth = 0
|
||||
i = destination_start
|
||||
|
||||
while i < len(text):
|
||||
curr = text[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return i
|
||||
depth -= 1
|
||||
|
||||
i += 1
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def remove_answer_citations(answer: str) -> str:
|
||||
stripped_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _CITATION_LINK_START_PATTERN.search(answer, cursor):
|
||||
stripped_parts.append(answer[cursor : match.start()])
|
||||
link_end = _find_markdown_link_end(answer, match.end())
|
||||
if link_end is None:
|
||||
stripped_parts.append(answer[match.start() :])
|
||||
return "".join(stripped_parts)
|
||||
|
||||
cursor = link_end + 1
|
||||
|
||||
stripped_parts.append(answer[cursor:])
|
||||
return "".join(stripped_parts)
|
||||
|
||||
|
||||
@log_function_time()
|
||||
@@ -1087,8 +1150,11 @@ def gather_stream(
|
||||
raise ValueError("Message ID is required")
|
||||
|
||||
if answer is None:
|
||||
# This should never be the case as these non-streamed flows do not have a stop-generation signal
|
||||
raise RuntimeError("Answer was not generated")
|
||||
if error_msg is not None:
|
||||
answer = ""
|
||||
else:
|
||||
# This should never be the case as these non-streamed flows do not have a stop-generation signal
|
||||
raise RuntimeError("Answer was not generated")
|
||||
|
||||
return ChatBasicResponse(
|
||||
answer=answer,
|
||||
|
||||
@@ -278,14 +278,17 @@ USING_AWS_MANAGED_OPENSEARCH = (
|
||||
OPENSEARCH_PROFILING_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Whether to disable match highlights for OpenSearch. Defaults to True for now
|
||||
# as we investigate query performance.
|
||||
OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED = (
|
||||
os.environ.get("OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED", "true").lower() == "true"
|
||||
)
|
||||
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
|
||||
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
|
||||
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
|
||||
OPENSEARCH_EXPLAIN_ENABLED = (
|
||||
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
|
||||
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
|
||||
# existing indices need reindexing after a change.
|
||||
@@ -318,8 +321,16 @@ VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
|
||||
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
|
||||
os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
|
||||
)
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = int(
|
||||
os.environ.get("OPENSEARCH_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES") or 0
|
||||
# If set, will override the default number of shards and replicas for the index.
|
||||
OPENSEARCH_INDEX_NUM_SHARDS: int | None = (
|
||||
int(os.environ["OPENSEARCH_INDEX_NUM_SHARDS"])
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_SHARDS", None) is not None
|
||||
else None
|
||||
)
|
||||
OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
|
||||
int(os.environ["OPENSEARCH_INDEX_NUM_REPLICAS"])
|
||||
if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
|
||||
else None
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
@@ -957,7 +968,7 @@ ENTERPRISE_EDITION_ENABLED = (
|
||||
#####
|
||||
# Image Generation Configuration (DEPRECATED)
|
||||
# These environment variables will be deprecated soon.
|
||||
# To configure image generation, please visit the Image Generation page in the Admin Settings.
|
||||
# To configure image generation, please visit the Image Generation page in the Admin Panel.
|
||||
#####
|
||||
# Azure Image Configurations
|
||||
AZURE_IMAGE_API_VERSION = os.environ.get("AZURE_IMAGE_API_VERSION") or os.environ.get(
|
||||
@@ -1046,6 +1057,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
HOOK_ENABLED = os.environ.get("HOOK_ENABLED", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
#####
|
||||
|
||||
@@ -177,6 +177,14 @@ USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file-delete task is valid before workers discard it.
|
||||
# Mirrors the processing task expiry to prevent indefinite queue growth when
|
||||
# files are stuck in DELETING status and the beat keeps re-enqueuing them.
|
||||
CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Max queue depth before the delete beat stops enqueuing more delete tasks.
|
||||
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -469,6 +477,9 @@ class OnyxRedisLocks:
|
||||
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
# Short-lived key set when a delete task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a delete task is already queued.
|
||||
USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
@@ -597,6 +608,9 @@ class OnyxCeleryTask:
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
# Hook execution log retention
|
||||
HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
|
||||
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
|
||||
|
||||
@@ -157,9 +157,7 @@ def _execute_single_retrieval(
|
||||
logger.error(f"Error executing request: {e}")
|
||||
raise e
|
||||
elif _is_rate_limit_error(e):
|
||||
results = _execute_with_retry(
|
||||
lambda: retrieval_function(**request_kwargs).execute()
|
||||
)
|
||||
results = _execute_with_retry(retrieval_function(**request_kwargs))
|
||||
elif e.resp.status == 404 or e.resp.status == 403:
|
||||
if continue_on_404_or_403:
|
||||
logger.debug(f"Error executing request: {e}")
|
||||
|
||||
@@ -2,7 +2,6 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
@@ -70,9 +69,13 @@ class BaseFilters(BaseModel):
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[UUID] | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
# Scopes search to user files tagged with a given project/persona in Vespa.
|
||||
# These are NOT simply the IDs of the current project or persona — they are
|
||||
# only set when the persona's/project's user files overflowed the LLM
|
||||
# context window and must be searched via vector DB instead of being loaded
|
||||
# directly into the prompt.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
|
||||
|
||||
class AssistantKnowledgeFilters(BaseModel):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -39,9 +38,8 @@ logger = setup_logger()
|
||||
def _build_index_filters(
|
||||
user_provided_filters: BaseFilters | None,
|
||||
user: User, # Used for ACLs, anonymous users only see public docs
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
user_file_ids: list[UUID] | None,
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session | None = None,
|
||||
@@ -97,16 +95,6 @@ def _build_index_filters(
|
||||
if not source_filter and detected_source_filter:
|
||||
source_filter = detected_source_filter
|
||||
|
||||
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
|
||||
# source type is included in the filter, otherwise user files will be excluded!
|
||||
if user_file_ids and source_filter:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# Add user_file to the source filter if not already present
|
||||
if DocumentSource.USER_FILE not in source_filter:
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
@@ -117,9 +105,8 @@ def _build_index_filters(
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
source_type=source_filter,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
@@ -265,19 +252,16 @@ def search_pipeline(
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# If a persona_id is provided, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't fit
|
||||
# in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None = None,
|
||||
persona_id_filter: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
)
|
||||
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
@@ -302,9 +286,8 @@ def search_pipeline(
|
||||
filters = _build_index_filters(
|
||||
user_provided_filters=chunk_search_request.user_selected_filters,
|
||||
user=user,
|
||||
project_id=project_id,
|
||||
persona_id=persona_id,
|
||||
user_file_ids=user_uploaded_persona_files,
|
||||
project_id_filter=project_id_filter,
|
||||
persona_id_filter=persona_id_filter,
|
||||
persona_document_sets=persona_document_sets,
|
||||
persona_time_cutoff=persona_time_cutoff,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -110,7 +110,6 @@ def search_chunks(
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessage__SearchDoc
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import ChatSessionSharedStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchDoc as DBSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -53,9 +54,17 @@ def get_chat_session_by_id(
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
is_shared: bool = False,
|
||||
eager_load_persona: bool = False,
|
||||
) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
|
||||
if eager_load_persona:
|
||||
stmt = stmt.options(
|
||||
selectinload(ChatSession.persona).selectinload(Persona.tools),
|
||||
selectinload(ChatSession.persona).selectinload(Persona.user_files),
|
||||
selectinload(ChatSession.project),
|
||||
)
|
||||
|
||||
if is_shared:
|
||||
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
|
||||
else:
|
||||
|
||||
@@ -511,7 +511,7 @@ def add_credential_to_connector(
|
||||
user: User,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair_name: str | None,
|
||||
cc_pair_name: str,
|
||||
access_type: AccessType,
|
||||
groups: list[int] | None,
|
||||
auto_sync_options: dict | None = None,
|
||||
|
||||
@@ -304,3 +304,13 @@ class LLMModelFlowType(str, PyEnum):
|
||||
CHAT = "chat"
|
||||
VISION = "vision"
|
||||
CONTEXTUAL_RAG = "contextual_rag"
|
||||
|
||||
|
||||
class HookPoint(str, PyEnum):
|
||||
DOCUMENT_INGESTION = "document_ingestion"
|
||||
QUERY_PROCESSING = "query_processing"
|
||||
|
||||
|
||||
class HookFailStrategy(str, PyEnum):
|
||||
HARD = "hard" # exception propagates, pipeline aborts
|
||||
SOFT = "soft" # log error, return original input, pipeline continues
|
||||
|
||||
233
backend/onyx/db/hook.py
Normal file
233
backend/onyx/db/hook.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.models import Hook
|
||||
from onyx.db.models import HookExecutionLog
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ── Hook CRUD ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_hook_by_id(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_deleted: bool = False,
|
||||
include_creator: bool = False,
|
||||
) -> Hook | None:
|
||||
stmt = select(Hook).where(Hook.id == hook_id)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Hook.deleted.is_(False))
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_non_deleted_hook_by_hook_point(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
include_creator: bool = False,
|
||||
) -> Hook | None:
|
||||
stmt = (
|
||||
select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False))
|
||||
)
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def get_hooks(
|
||||
*,
|
||||
db_session: Session,
|
||||
include_deleted: bool = False,
|
||||
include_creator: bool = False,
|
||||
) -> list[Hook]:
|
||||
stmt = select(Hook)
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(Hook.deleted.is_(False))
|
||||
if include_creator:
|
||||
stmt = stmt.options(selectinload(Hook.creator))
|
||||
stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc())
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def create_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
name: str,
|
||||
hook_point: HookPoint,
|
||||
endpoint_url: str | None = None,
|
||||
api_key: str | None = None,
|
||||
fail_strategy: HookFailStrategy,
|
||||
timeout_seconds: float,
|
||||
is_active: bool = False,
|
||||
creator_id: UUID | None = None,
|
||||
) -> Hook:
|
||||
"""Create a new hook for the given hook point.
|
||||
|
||||
At most one non-deleted hook per hook point is allowed. Raises
|
||||
OnyxError(CONFLICT) if a hook already exists, including under concurrent
|
||||
duplicate creates where the partial unique index fires an IntegrityError.
|
||||
"""
|
||||
existing = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if existing:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
f"A hook for '{hook_point.value}' already exists (id={existing.id}).",
|
||||
)
|
||||
|
||||
hook = Hook(
|
||||
name=name,
|
||||
hook_point=hook_point,
|
||||
endpoint_url=endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=fail_strategy,
|
||||
timeout_seconds=timeout_seconds,
|
||||
is_active=is_active,
|
||||
creator_id=creator_id,
|
||||
)
|
||||
# Use a savepoint so that a failed insert only rolls back this operation,
|
||||
# not the entire outer transaction.
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(hook)
|
||||
savepoint.commit()
|
||||
except IntegrityError as exc:
|
||||
savepoint.rollback()
|
||||
if "ix_hook_one_non_deleted_per_point" in str(exc.orig):
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.CONFLICT,
|
||||
f"A hook for '{hook_point.value}' already exists.",
|
||||
)
|
||||
raise # re-raise unrelated integrity errors (FK violations, etc.)
|
||||
return hook
|
||||
|
||||
|
||||
def update_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
name: str | None = None,
|
||||
endpoint_url: str | None | UnsetType = UNSET,
|
||||
api_key: str | None | UnsetType = UNSET,
|
||||
fail_strategy: HookFailStrategy | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
is_active: bool | None = None,
|
||||
is_reachable: bool | None = None,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
"""Update hook fields.
|
||||
|
||||
Sentinel conventions:
|
||||
- endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear.
|
||||
- name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged.
|
||||
"""
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session, hook_id=hook_id, include_creator=include_creator
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
|
||||
|
||||
if name is not None:
|
||||
hook.name = name
|
||||
if not isinstance(endpoint_url, UnsetType):
|
||||
hook.endpoint_url = endpoint_url
|
||||
if not isinstance(api_key, UnsetType):
|
||||
hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level
|
||||
if fail_strategy is not None:
|
||||
hook.fail_strategy = fail_strategy
|
||||
if timeout_seconds is not None:
|
||||
hook.timeout_seconds = timeout_seconds
|
||||
if is_active is not None:
|
||||
hook.is_active = is_active
|
||||
if is_reachable is not None:
|
||||
hook.is_reachable = is_reachable
|
||||
|
||||
db_session.flush()
|
||||
return hook
|
||||
|
||||
|
||||
def delete_hook__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
) -> None:
|
||||
hook = get_hook_by_id(db_session=db_session, hook_id=hook_id)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.")
|
||||
|
||||
hook.deleted = True
|
||||
hook.is_active = False
|
||||
db_session.flush()
|
||||
|
||||
|
||||
# ── HookExecutionLog CRUD ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def create_hook_execution_log__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
is_success: bool,
|
||||
error_message: str | None = None,
|
||||
status_code: int | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> HookExecutionLog:
|
||||
log = HookExecutionLog(
|
||||
hook_id=hook_id,
|
||||
is_success=is_success,
|
||||
error_message=error_message,
|
||||
status_code=status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
db_session.add(log)
|
||||
db_session.flush()
|
||||
return log
|
||||
|
||||
|
||||
def get_hook_execution_logs(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
limit: int,
|
||||
) -> list[HookExecutionLog]:
|
||||
stmt = (
|
||||
select(HookExecutionLog)
|
||||
.where(HookExecutionLog.hook_id == hook_id)
|
||||
.order_by(HookExecutionLog.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def cleanup_old_execution_logs__no_commit(
|
||||
*,
|
||||
db_session: Session,
|
||||
max_age_days: int,
|
||||
) -> int:
|
||||
"""Delete execution logs older than max_age_days. Returns the number of rows deleted."""
|
||||
cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=max_age_days
|
||||
)
|
||||
result: CursorResult = db_session.execute( # type: ignore[assignment]
|
||||
delete(HookExecutionLog)
|
||||
.where(HookExecutionLog.created_at < cutoff)
|
||||
.execution_options(synchronize_session=False)
|
||||
)
|
||||
return result.rowcount
|
||||
@@ -64,6 +64,8 @@ from onyx.db.enums import (
|
||||
BuildSessionStatus,
|
||||
EmbeddingPrecision,
|
||||
HierarchyNodeType,
|
||||
HookFailStrategy,
|
||||
HookPoint,
|
||||
IndexingMode,
|
||||
OpenSearchDocumentMigrationStatus,
|
||||
OpenSearchTenantMigrationStatus,
|
||||
@@ -5178,3 +5180,90 @@ class CacheStore(Base):
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class Hook(Base):
|
||||
"""Pairs a HookPoint with a customer-provided API endpoint.
|
||||
|
||||
At most one non-deleted Hook per HookPoint is allowed, enforced by a
|
||||
partial unique index on (hook_point) where deleted=false.
|
||||
"""
|
||||
|
||||
__tablename__ = "hook"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
hook_point: Mapped[HookPoint] = mapped_column(
|
||||
Enum(HookPoint, native_enum=False), nullable=False
|
||||
)
|
||||
endpoint_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
is_reachable: Mapped[bool | None] = mapped_column(
|
||||
Boolean, nullable=True, default=None
|
||||
) # null = never validated, true = last check passed, false = last check failed
|
||||
fail_strategy: Mapped[HookFailStrategy] = mapped_column(
|
||||
Enum(HookFailStrategy, native_enum=False),
|
||||
nullable=False,
|
||||
default=HookFailStrategy.HARD,
|
||||
)
|
||||
timeout_seconds: Mapped[float] = mapped_column(Float, nullable=False, default=30.0)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
creator_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
creator: Mapped["User | None"] = relationship("User", foreign_keys=[creator_id])
|
||||
execution_logs: Mapped[list["HookExecutionLog"]] = relationship(
|
||||
"HookExecutionLog", back_populates="hook", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_hook_one_non_deleted_per_point",
|
||||
"hook_point",
|
||||
unique=True,
|
||||
postgresql_where=(deleted == False), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class HookExecutionLog(Base):
|
||||
"""Records hook executions for health monitoring and debugging.
|
||||
|
||||
Currently only failures are logged; the is_success column exists so
|
||||
success logging can be added later without a schema change.
|
||||
Retention: rows older than 30 days are deleted by a nightly Celery task.
|
||||
"""
|
||||
|
||||
__tablename__ = "hook_execution_log"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
hook_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
ForeignKey("hook.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
is_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status_code: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
|
||||
)
|
||||
|
||||
hook: Mapped["Hook"] = relationship("Hook", back_populates="execution_logs")
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -144,6 +145,7 @@ def upload_files_to_user_files_with_indexing(
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -149,6 +150,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
|
||||
Returns None if search settings did not change, or the old search settings if they
|
||||
did change.
|
||||
"""
|
||||
if DISABLE_VECTOR_DB:
|
||||
return None
|
||||
|
||||
# Default CC-pair created for Ingestion API unused here
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
|
||||
|
||||
@@ -10,8 +10,8 @@ How `IndexFilters` fields combine into the final query filter. Applies to both V
|
||||
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
|
||||
| **ACL** | `access_control_list` | OR within, AND with rest |
|
||||
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
|
||||
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
|
||||
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
|
||||
|
||||
## How filters combine
|
||||
|
||||
@@ -31,12 +31,22 @@ AND time >= cutoff -- if set
|
||||
|
||||
The knowledge scope filter controls **what knowledge an assistant can access**.
|
||||
|
||||
### Primary vs additive triggers
|
||||
|
||||
- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit
|
||||
knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is
|
||||
NOT the raw ID of the persona being used — it is only set when the persona's
|
||||
user files overflowed the LLM context window.
|
||||
- **`project_id_filter`** is **additive**. It widens an existing scope to include project
|
||||
files but never restricts on its own — a chat inside a project should still search
|
||||
team knowledge when no other knowledge is attached.
|
||||
|
||||
### No explicit knowledge attached
|
||||
|
||||
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
|
||||
When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:
|
||||
|
||||
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
|
||||
- `project_id` and `persona_id` are ignored — they never restrict on their own.
|
||||
- `project_id_filter` is ignored — it never restricts on its own.
|
||||
|
||||
### One explicit knowledge type
|
||||
|
||||
@@ -44,39 +54,40 @@ When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_no
|
||||
-- Only document sets
|
||||
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
|
||||
|
||||
-- Only user files
|
||||
AND (document_id = "uuid-1" OR document_id = "uuid-2")
|
||||
-- Only persona user files (overflowed context)
|
||||
AND (personas contains 42)
|
||||
```
|
||||
|
||||
### Multiple explicit knowledge types (OR'd)
|
||||
|
||||
```
|
||||
-- Document sets + user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR document_id = "uuid-1"
|
||||
)
|
||||
```
|
||||
|
||||
### Explicit knowledge + overflowing user files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
|
||||
|
||||
```
|
||||
-- Document sets + persona user files overflowed
|
||||
-- Document sets + persona user files
|
||||
AND (
|
||||
document_sets contains "Engineering"
|
||||
OR personas contains 42
|
||||
)
|
||||
```
|
||||
|
||||
-- User files + project files overflowed
|
||||
### Explicit knowledge + overflowing project files
|
||||
|
||||
When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:
|
||||
|
||||
```
|
||||
-- Document sets + project files overflowed
|
||||
AND (
|
||||
document_id = "uuid-1"
|
||||
document_sets contains "Engineering"
|
||||
OR user_project contains 7
|
||||
)
|
||||
|
||||
-- Persona user files + project files (won't happen in practice;
|
||||
-- custom personas ignore project files per the precedence rule)
|
||||
AND (
|
||||
personas contains 42
|
||||
OR user_project contains 7
|
||||
)
|
||||
```
|
||||
|
||||
### Only project_id or persona_id (no explicit knowledge)
|
||||
### Only project_id_filter (no explicit knowledge)
|
||||
|
||||
No knowledge scope filter. The assistant searches everything.
|
||||
|
||||
@@ -91,11 +102,10 @@ AND (acl contains ...)
|
||||
| Filter field | Vespa field | Vespa type | Purpose |
|
||||
|---|---|---|---|
|
||||
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
|
||||
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
|
||||
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
|
||||
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
|
||||
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
|
||||
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
|
||||
| `persona_id_filter` | `personas` | `array<int>` | Persona tag for overflowing user files (**primary** trigger) |
|
||||
| `project_id_filter` | `user_project` | `array<int>` | Project tag for overflowing project files (**additive** only) |
|
||||
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
|
||||
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
|
||||
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from contextlib import AbstractContextManager
|
||||
@@ -18,6 +19,7 @@ from onyx.configs.app_configs import OPENSEARCH_HOST
|
||||
from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -56,8 +58,8 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
# Maps schema property name to a list of highlighted snippets with match
|
||||
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
|
||||
match_highlights: dict[str, list[str]] = {}
|
||||
# Score explanation from OpenSearch when "explain": true is set in the query.
|
||||
# Contains detailed breakdown of how the score was calculated.
|
||||
# Score explanation from OpenSearch when "explain": true is set in the
|
||||
# query. Contains detailed breakdown of how the score was calculated.
|
||||
explanation: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@@ -833,9 +835,13 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def search(
|
||||
self, body: dict[str, Any], search_pipeline_id: str | None
|
||||
) -> list[SearchHit[DocumentChunk]]:
|
||||
) -> list[SearchHit[DocumentChunkWithoutVectors]]:
|
||||
"""Searches the index.
|
||||
|
||||
NOTE: Does not return vector fields. In order to take advantage of
|
||||
performance benefits, the search body should exclude the schema's vector
|
||||
fields.
|
||||
|
||||
TODO(andrei): Ideally we could check that every field in the body is
|
||||
present in the index, to avoid a class of runtime bugs that could easily
|
||||
be caught during development. Or change the function signature to accept
|
||||
@@ -883,7 +889,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
raise_on_timeout=True,
|
||||
)
|
||||
|
||||
search_hits: list[SearchHit[DocumentChunk]] = []
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
|
||||
for hit in hits:
|
||||
document_chunk_source: dict[str, Any] | None = hit.get("_source")
|
||||
if not document_chunk_source:
|
||||
@@ -893,8 +899,10 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
document_chunk_score = hit.get("_score", None)
|
||||
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
|
||||
explanation: dict[str, Any] | None = hit.get("_explanation", None)
|
||||
search_hit = SearchHit[DocumentChunk](
|
||||
document_chunk=DocumentChunk.model_validate(document_chunk_source),
|
||||
search_hit = SearchHit[DocumentChunkWithoutVectors](
|
||||
document_chunk=DocumentChunkWithoutVectors.model_validate(
|
||||
document_chunk_source
|
||||
),
|
||||
score=document_chunk_score,
|
||||
match_highlights=match_highlights,
|
||||
explanation=explanation,
|
||||
@@ -1055,7 +1063,7 @@ class OpenSearchIndexClient(OpenSearchClient):
|
||||
f"Body: {get_new_body_without_vectors(body)}\n"
|
||||
f"Search pipeline ID: {search_pipeline_id}\n"
|
||||
f"Phase took: {phase_took}\n"
|
||||
f"Profile: {profile}\n"
|
||||
f"Profile: {json.dumps(profile, indent=2)}\n"
|
||||
)
|
||||
if timed_out:
|
||||
error_str = f"OpenSearch client error: Search timed out for index {self._index_name}."
|
||||
|
||||
@@ -1,12 +1,23 @@
|
||||
# Default value for the maximum number of tokens a chunk can hold, if none is
|
||||
# specified when creating an index.
|
||||
from onyx.configs.app_configs import (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
)
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
|
||||
DEFAULT_MAX_CHUNK_SIZE = 512
|
||||
|
||||
|
||||
# By default OpenSearch will only return a maximum of this many results in a
|
||||
# given search. This value is configurable in the index settings.
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
|
||||
|
||||
|
||||
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
|
||||
# that the document was last updated this many days ago for the purpose of time
|
||||
# cutoff filtering during retrieval.
|
||||
ASSUMED_DOCUMENT_AGE_DAYS = 90
|
||||
|
||||
|
||||
# Size of the dynamic list used to consider elements during kNN graph creation.
|
||||
# Higher values improve search quality but increase indexing time. Values
|
||||
# typically range between 100 - 512.
|
||||
@@ -26,10 +37,10 @@ M = 32 # Set relatively high for better accuracy.
|
||||
# we have a much higher chance of all 10 of the final desired docs showing up
|
||||
# and getting scored. In worse situations, the final 10 docs don't even show up
|
||||
# as the final 10 (worse than just a miss at the reranking step).
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
|
||||
OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
if OPENSEARCH_OVERRIDE_DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES > 0
|
||||
else 750
|
||||
# Defaults to 100 for now. Initially this defaulted to 750 but we were seeing
|
||||
# poor search performance.
|
||||
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int(
|
||||
os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 100)
|
||||
)
|
||||
|
||||
# Number of vectors to examine to decide the top k neighbors for the HNSW
|
||||
@@ -39,23 +50,43 @@ DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = (
|
||||
# larger than k, you can provide the size parameter to limit the final number of
|
||||
# results to k." from
|
||||
# https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
# Since the titles are included in the contents, the embedding matches are
|
||||
# heavily downweighted as they act as a boost rather than an independent scoring
|
||||
# component.
|
||||
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
|
||||
# Single keyword weight for both title and content (merged from former title
|
||||
# keyword + content keyword).
|
||||
SEARCH_KEYWORD_WEIGHT = 0.45
|
||||
|
||||
# NOTE: It is critical that the order of these weights matches the order of the
|
||||
# sub-queries in the hybrid search.
|
||||
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
|
||||
SEARCH_TITLE_VECTOR_WEIGHT,
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT,
|
||||
SEARCH_KEYWORD_WEIGHT,
|
||||
]
|
||||
class HybridSearchSubqueryConfiguration(Enum):
|
||||
TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1
|
||||
# Current default.
|
||||
CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 2
|
||||
|
||||
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0
|
||||
|
||||
# Will raise and block application start if HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
# is set but not a valid value. If not set, defaults to
|
||||
# CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD.
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION: HybridSearchSubqueryConfiguration = (
|
||||
HybridSearchSubqueryConfiguration(
|
||||
int(os.environ["HYBRID_SEARCH_SUBQUERY_CONFIGURATION"])
|
||||
)
|
||||
if os.environ.get("HYBRID_SEARCH_SUBQUERY_CONFIGURATION", None) is not None
|
||||
else HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
|
||||
)
|
||||
|
||||
|
||||
class HybridSearchNormalizationPipeline(Enum):
|
||||
# Current default.
|
||||
MIN_MAX = 1
|
||||
# NOTE: Using z-score normalization is better for hybrid search from a
|
||||
# theoretical standpoint. Empirically on a small dataset of up to 10K docs,
|
||||
# it's not very different. Likely more impactful at scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
ZSCORE = 2
|
||||
|
||||
|
||||
# Will raise and block application start if HYBRID_SEARCH_NORMALIZATION_PIPELINE
|
||||
# is set but not a valid value. If not set, defaults to MIN_MAX.
|
||||
HYBRID_SEARCH_NORMALIZATION_PIPELINE: HybridSearchNormalizationPipeline = (
|
||||
HybridSearchNormalizationPipeline(
|
||||
int(os.environ["HYBRID_SEARCH_NORMALIZATION_PIPELINE"])
|
||||
)
|
||||
if os.environ.get("HYBRID_SEARCH_NORMALIZATION_PIPELINE", None) is not None
|
||||
else HybridSearchNormalizationPipeline.MIN_MAX
|
||||
)
|
||||
|
||||
@@ -47,6 +47,7 @@ from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DocumentChunk
|
||||
from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors
|
||||
from onyx.document_index.opensearch.schema import DocumentSchema
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
|
||||
@@ -55,16 +56,13 @@ from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.search import DocumentQuery
|
||||
from onyx.document_index.opensearch.search import (
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
get_min_max_normalization_pipeline_name_and_config,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import (
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
get_normalization_pipeline_name_and_config,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import (
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
)
|
||||
from onyx.document_index.opensearch.search import (
|
||||
ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
get_zscore_normalization_pipeline_name_and_config,
|
||||
)
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import Document
|
||||
@@ -103,18 +101,24 @@ def set_cluster_state(client: OpenSearchClient) -> None:
|
||||
"is not the first time running Onyx against this instance of OpenSearch, these "
|
||||
"settings have likely already been set. Not taking any further action..."
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
|
||||
min_max_normalization_pipeline_name, min_max_normalization_pipeline_config = (
|
||||
get_min_max_normalization_pipeline_name_and_config()
|
||||
)
|
||||
zscore_normalization_pipeline_name, zscore_normalization_pipeline_config = (
|
||||
get_zscore_normalization_pipeline_name_and_config()
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
|
||||
pipeline_id=min_max_normalization_pipeline_name,
|
||||
pipeline_body=min_max_normalization_pipeline_config,
|
||||
)
|
||||
client.create_search_pipeline(
|
||||
pipeline_id=zscore_normalization_pipeline_name,
|
||||
pipeline_body=zscore_normalization_pipeline_config,
|
||||
)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
chunk: DocumentChunkWithoutVectors,
|
||||
score: float | None,
|
||||
highlights: dict[str, list[str]],
|
||||
) -> InferenceChunkUncleaned:
|
||||
@@ -877,7 +881,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
)
|
||||
results: list[InferenceChunk] = []
|
||||
for chunk_request in chunk_requests:
|
||||
search_hits: list[SearchHit[DocumentChunk]] = []
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = []
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
@@ -940,17 +944,92 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
# NOTE: Using z-score normalization here because it's better for hybrid
|
||||
# search from a theoretical standpoint. Empirically on a small dataset
|
||||
# of up to 10K docs, it's not very different. Likely more impactful at
|
||||
# scale.
|
||||
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
normalization_pipeline_name, _ = get_normalization_pipeline_name_and_config()
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
search_pipeline_id=normalization_pipeline_name,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have
|
||||
# "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_keyword_search_query(
|
||||
query_text=query,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
)
|
||||
for search_hit in search_hits
|
||||
]
|
||||
inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks(
|
||||
inference_chunks_uncleaned
|
||||
)
|
||||
|
||||
return inference_chunks
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query_embedding: Embedding,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}."
|
||||
)
|
||||
query_body = DocumentQuery.get_semantic_search_query(
|
||||
query_embedding=query_embedding,
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
# NOTE: Index filters includes metadata tags which were filtered
|
||||
# for invalid unicode at indexing time. In theory it would be
|
||||
# ideal to do filtering here as well, in practice we never did
|
||||
# that in the Vespa codepath and have not seen issues in
|
||||
# production, so we deliberately conform to the existing logic
|
||||
# in order to not unknowningly introduce a possible bug.
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
@@ -977,7 +1056,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
index_filters=filters,
|
||||
num_to_retrieve=num_to_retrieve,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
|
||||
search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search(
|
||||
body=query_body,
|
||||
search_pipeline_id=None,
|
||||
)
|
||||
|
||||
@@ -11,6 +11,8 @@ from pydantic import model_serializer
|
||||
from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_REPLICAS
|
||||
from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_SHARDS
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
@@ -100,9 +102,9 @@ def set_or_convert_timezone_to_utc(value: datetime) -> datetime:
|
||||
return value
|
||||
|
||||
|
||||
class DocumentChunk(BaseModel):
|
||||
class DocumentChunkWithoutVectors(BaseModel):
|
||||
"""
|
||||
Represents a chunk of a document in the OpenSearch index.
|
||||
Represents a chunk of a document in the OpenSearch index without vectors.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
@@ -124,9 +126,7 @@ class DocumentChunk(BaseModel):
|
||||
|
||||
# Either both should be None or both should be non-None.
|
||||
title: str | None = None
|
||||
title_vector: list[float] | None = None
|
||||
content: str
|
||||
content_vector: list[float]
|
||||
|
||||
source_type: str
|
||||
# A list of key-value pairs separated by INDEX_SEPARATOR. See
|
||||
@@ -176,19 +176,9 @@ class DocumentChunk(BaseModel):
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
|
||||
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
|
||||
f"tenant_id={self.tenant_id.tenant_id})"
|
||||
f"content length={len(self.content)}, tenant_id={self.tenant_id.tenant_id})."
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
# title and title_vector should both either be None or not.
|
||||
if self.title is not None and self.title_vector is None:
|
||||
raise ValueError("Bug: Title vector must not be None if title is not None.")
|
||||
if self.title_vector is not None and self.title is None:
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize_model(
|
||||
self, handler: SerializerFunctionWrapHandler
|
||||
@@ -305,6 +295,35 @@ class DocumentChunk(BaseModel):
|
||||
return TenantState(tenant_id=value, multitenant=MULTI_TENANT)
|
||||
|
||||
|
||||
class DocumentChunk(DocumentChunkWithoutVectors):
|
||||
"""Represents a chunk of a document in the OpenSearch index.
|
||||
|
||||
The names of these fields are based on the OpenSearch schema. Changes to the
|
||||
schema require changes here. See get_document_schema.
|
||||
"""
|
||||
|
||||
model_config = {"frozen": True}
|
||||
|
||||
title_vector: list[float] | None = None
|
||||
content_vector: list[float]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, "
|
||||
f"content length={len(self.content)}, content vector length={len(self.content_vector)}, "
|
||||
f"tenant_id={self.tenant_id.tenant_id})"
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_title_and_title_vector_are_consistent(self) -> Self:
|
||||
# title and title_vector should both either be None or not.
|
||||
if self.title is not None and self.title_vector is None:
|
||||
raise ValueError("Bug: Title vector must not be None if title is not None.")
|
||||
if self.title_vector is not None and self.title is None:
|
||||
raise ValueError("Bug: Title must not be None if title vector is not None.")
|
||||
return self
|
||||
|
||||
|
||||
class DocumentSchema:
|
||||
"""
|
||||
Represents the schema and indexing strategies of the OpenSearch index.
|
||||
@@ -516,78 +535,35 @@ class DocumentSchema:
|
||||
|
||||
return schema
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings() -> dict[str, Any]:
|
||||
"""
|
||||
Standard settings for reasonable local index and search performance.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 1,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_st_dev() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
Our AWS-managed OpenSearch cluster has 3 data nodes in 3 availability
|
||||
zones.
|
||||
- We use 3 shards to distribute load across all data nodes.
|
||||
- We use 2 replicas to ensure each shard has a copy in each
|
||||
availability zone. This is a hard requirement from AWS. The number
|
||||
of data copies, including the primary (not a replica) copy, must be
|
||||
divisible by the number of AZs.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 3,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_for_aws_managed_opensearch_mt_cloud() -> dict[str, Any]:
|
||||
"""
|
||||
Settings for AWS-managed OpenSearch in multi-tenant cloud.
|
||||
|
||||
324 shards very roughly targets a storage load of ~30Gb per shard, which
|
||||
according to AWS OpenSearch documentation is within a good target range.
|
||||
|
||||
As documented above we need 2 replicas for a total of 3 copies of the
|
||||
data because the cluster is configured with 3-AZ awareness.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 324,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_index_settings_based_on_environment() -> dict[str, Any]:
|
||||
"""
|
||||
Returns the index settings based on the environment.
|
||||
"""
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
# NOTE: The number of data copies, including the primary (not a
|
||||
# replica) copy, must be divisible by the number of AZs.
|
||||
if MULTI_TENANT:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_mt_cloud()
|
||||
)
|
||||
number_of_shards = 324
|
||||
number_of_replicas = 2
|
||||
else:
|
||||
return (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch_st_dev()
|
||||
)
|
||||
number_of_shards = 3
|
||||
number_of_replicas = 2
|
||||
else:
|
||||
return DocumentSchema.get_index_settings()
|
||||
number_of_shards = 1
|
||||
number_of_replicas = 1
|
||||
|
||||
if OPENSEARCH_INDEX_NUM_SHARDS is not None:
|
||||
number_of_shards = OPENSEARCH_INDEX_NUM_SHARDS
|
||||
if OPENSEARCH_INDEX_NUM_REPLICAS is not None:
|
||||
number_of_replicas = OPENSEARCH_INDEX_NUM_REPLICAS
|
||||
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": number_of_shards,
|
||||
"number_of_replicas": number_of_replicas,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,20 +3,31 @@ from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
from onyx.configs.app_configs import OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED
|
||||
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import ASSUMED_DOCUMENT_AGE_DAYS
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
)
|
||||
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
|
||||
)
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
HYBRID_SEARCH_NORMALIZATION_PIPELINE,
|
||||
)
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION,
|
||||
)
|
||||
from onyx.document_index.opensearch.constants import HybridSearchNormalizationPipeline
|
||||
from onyx.document_index.opensearch.constants import HybridSearchSubqueryConfiguration
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
|
||||
@@ -43,49 +54,113 @@ from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# TODO(andrei): Turn all magic dictionaries to pydantic models.
|
||||
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_min_max"
|
||||
MIN_MAX_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using min-max",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
|
||||
"normalization-processor": {
|
||||
"normalization": {"technique": "min_max"},
|
||||
"combination": {
|
||||
"technique": "arithmetic_mean",
|
||||
"parameters": {"weights": HYBRID_SEARCH_NORMALIZATION_WEIGHTS},
|
||||
},
|
||||
|
||||
def _get_hybrid_search_normalization_weights() -> list[float]:
|
||||
if (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
|
||||
):
|
||||
# Since the titles are included in the contents, the embedding matches
|
||||
# are heavily downweighted as they act as a boost rather than an
|
||||
# independent scoring component.
|
||||
search_title_vector_weight = 0.1
|
||||
search_content_vector_weight = 0.45
|
||||
# Single keyword weight for both title and content (merged from former
|
||||
# title keyword + content keyword).
|
||||
search_keyword_weight = 0.45
|
||||
|
||||
# NOTE: It is critical that the order of these weights matches the order
|
||||
# of the sub-queries in the hybrid search.
|
||||
hybrid_search_normalization_weights = [
|
||||
search_title_vector_weight,
|
||||
search_content_vector_weight,
|
||||
search_keyword_weight,
|
||||
]
|
||||
elif (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
|
||||
):
|
||||
search_content_vector_weight = 0.5
|
||||
# Single keyword weight for both title and content (merged from former
|
||||
# title keyword + content keyword).
|
||||
search_keyword_weight = 0.5
|
||||
|
||||
# NOTE: It is critical that the order of these weights matches the order
|
||||
# of the sub-queries in the hybrid search.
|
||||
hybrid_search_normalization_weights = [
|
||||
search_content_vector_weight,
|
||||
search_keyword_weight,
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}."
|
||||
)
|
||||
|
||||
assert (
|
||||
sum(hybrid_search_normalization_weights) == 1.0
|
||||
), "Bug: Hybrid search normalization weights do not sum to 1.0."
|
||||
|
||||
return hybrid_search_normalization_weights
|
||||
|
||||
|
||||
def get_min_max_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
|
||||
min_max_normalization_pipeline_name = "normalization_pipeline_min_max"
|
||||
min_max_normalization_pipeline_config: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using min-max",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
|
||||
"normalization-processor": {
|
||||
"normalization": {"technique": "min_max"},
|
||||
"combination": {
|
||||
"technique": "arithmetic_mean",
|
||||
"parameters": {
|
||||
"weights": _get_hybrid_search_normalization_weights()
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
return min_max_normalization_pipeline_name, min_max_normalization_pipeline_config
|
||||
|
||||
ZSCORE_NORMALIZATION_PIPELINE_NAME = "normalization_pipeline_zscore"
|
||||
ZSCORE_NORMALIZATION_PIPELINE_CONFIG: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using z-score",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
|
||||
"normalization-processor": {
|
||||
"normalization": {"technique": "z_score"},
|
||||
"combination": {
|
||||
"technique": "arithmetic_mean",
|
||||
"parameters": {"weights": HYBRID_SEARCH_NORMALIZATION_WEIGHTS},
|
||||
},
|
||||
|
||||
def get_zscore_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
|
||||
zscore_normalization_pipeline_name = "normalization_pipeline_zscore"
|
||||
zscore_normalization_pipeline_config: dict[str, Any] = {
|
||||
"description": "Normalization for keyword and vector scores using z-score",
|
||||
"phase_results_processors": [
|
||||
{
|
||||
# https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
|
||||
"normalization-processor": {
|
||||
"normalization": {"technique": "z_score"},
|
||||
"combination": {
|
||||
"technique": "arithmetic_mean",
|
||||
"parameters": {
|
||||
"weights": _get_hybrid_search_normalization_weights()
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
return zscore_normalization_pipeline_name, zscore_normalization_pipeline_config
|
||||
|
||||
|
||||
# By default OpenSearch will only return a maximum of this many results in a
|
||||
# given search. This value is configurable in the index settings.
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
|
||||
|
||||
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
|
||||
# that the document was last updated this many days ago for the purpose of time
|
||||
# cutoff filtering during retrieval.
|
||||
ASSUMED_DOCUMENT_AGE_DAYS = 90
|
||||
def get_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]:
|
||||
if (
|
||||
HYBRID_SEARCH_NORMALIZATION_PIPELINE
|
||||
is HybridSearchNormalizationPipeline.MIN_MAX
|
||||
):
|
||||
return get_min_max_normalization_pipeline_name_and_config()
|
||||
elif (
|
||||
HYBRID_SEARCH_NORMALIZATION_PIPELINE is HybridSearchNormalizationPipeline.ZSCORE
|
||||
):
|
||||
return get_zscore_normalization_pipeline_name_and_config()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Bug: Unhandled hybrid search normalization pipeline: {HYBRID_SEARCH_NORMALIZATION_PIPELINE}."
|
||||
)
|
||||
|
||||
|
||||
class DocumentQuery:
|
||||
@@ -143,9 +218,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
@@ -160,9 +234,17 @@ class DocumentQuery:
|
||||
# returning some number of results less than the index max allowed
|
||||
# return size.
|
||||
"size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW,
|
||||
"_source": get_full_document,
|
||||
# By default exclude retrieving the vector fields in order to save
|
||||
# on retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
}
|
||||
if not get_full_document:
|
||||
# If we explicitly do not want the underlying document, we will only
|
||||
# retrieve IDs.
|
||||
final_get_ids_query["_source"] = False
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_get_ids_query["profile"] = True
|
||||
|
||||
@@ -202,9 +284,8 @@ class DocumentQuery:
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
persona_id=None,
|
||||
project_id_filter=None,
|
||||
persona_id_filter=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -257,7 +338,7 @@ class DocumentQuery:
|
||||
|
||||
# TODO(andrei, yuhong): We can tune this more dynamically based on
|
||||
# num_hits.
|
||||
max_results_per_subquery = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
max_results_per_subquery = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES
|
||||
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, vector_candidates=max_results_per_subquery
|
||||
@@ -272,18 +353,14 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
@@ -310,16 +387,181 @@ class DocumentQuery:
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
# Explain is for scoring breakdowns.
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_hybrid_search_body["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
def get_keyword_search_query(
|
||||
query_text: str,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final keyword search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the keyword search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final keyword search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
keyword_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
keyword_search_query = (
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text, search_filters=keyword_search_filters
|
||||
)
|
||||
)
|
||||
|
||||
final_keyword_search_query: dict[str, Any] = {
|
||||
"query": keyword_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED:
|
||||
final_keyword_search_query["highlight"] = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_keyword_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_keyword_search_query["explain"] = True
|
||||
|
||||
return final_keyword_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_semantic_search_query(
|
||||
query_embedding: list[float],
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final semantic search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
|
||||
Args:
|
||||
query_embedding: The vector embedding of the text to query for.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the semantic search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final semantic search query.
|
||||
"""
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})."
|
||||
)
|
||||
|
||||
semantic_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
attached_document_ids=index_filters.attached_document_ids,
|
||||
hierarchy_node_ids=index_filters.hierarchy_node_ids,
|
||||
)
|
||||
|
||||
semantic_search_query = (
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_embedding,
|
||||
vector_candidates=num_hits,
|
||||
search_filters=semantic_search_filters,
|
||||
)
|
||||
)
|
||||
|
||||
final_semantic_search_query: dict[str, Any] = {
|
||||
"query": semantic_search_query,
|
||||
"size": num_hits,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_semantic_search_query["profile"] = True
|
||||
|
||||
# Explain is for scoring breakdowns. Setting this significantly
|
||||
# increases query latency.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_semantic_search_query["explain"] = True
|
||||
|
||||
return final_semantic_search_query
|
||||
|
||||
@staticmethod
|
||||
def get_random_search_query(
|
||||
tenant_state: TenantState,
|
||||
@@ -343,9 +585,8 @@ class DocumentQuery:
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
persona_id=index_filters.persona_id,
|
||||
project_id_filter=index_filters.project_id_filter,
|
||||
persona_id_filter=index_filters.persona_id_filter,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
@@ -371,6 +612,11 @@ class DocumentQuery:
|
||||
},
|
||||
"size": num_to_retrieve,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
# Exclude retrieving the vector fields in order to save on
|
||||
# retrieval cost as we don't need them upstream.
|
||||
"_source": {
|
||||
"excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME]
|
||||
},
|
||||
}
|
||||
if not OPENSEARCH_PROFILING_DISABLED:
|
||||
final_random_search_query["profile"] = True
|
||||
@@ -385,7 +631,7 @@ class DocumentQuery:
|
||||
# search. This is higher than the number of results because the scoring
|
||||
# is hybrid. For a detailed breakdown, see where the default value is
|
||||
# set.
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns subqueries for hybrid search.
|
||||
|
||||
@@ -395,20 +641,18 @@ class DocumentQuery:
|
||||
The return of this function is not sufficient to be directly supplied to
|
||||
the OpenSearch client. See get_hybrid_search_query.
|
||||
|
||||
Matches:
|
||||
- Title vector
|
||||
- Content vector
|
||||
- Keyword (title + content, match and phrase)
|
||||
|
||||
Normalization is not performed here.
|
||||
The weights of each of these subqueries should be configured in a search
|
||||
pipeline.
|
||||
|
||||
The exact subqueries executed depend on the
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION setting.
|
||||
|
||||
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
NOTE: Each query is independent during the search phase, there is no
|
||||
NOTE: Each query is independent during the search phase; there is no
|
||||
backfilling of scores for missing query components. What this means is
|
||||
that if a document was a good vector match but did not show up for
|
||||
keyword, it gets a score of 0 for the keyword component of the hybrid
|
||||
@@ -437,74 +681,133 @@ class DocumentQuery:
|
||||
similarity search.
|
||||
"""
|
||||
# Build sub-queries for hybrid search. Order must match normalization
|
||||
# pipeline weights: title vector, content vector, keyword (title + content).
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
# 1. Title vector search
|
||||
{
|
||||
"knn": {
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": vector_candidates,
|
||||
}
|
||||
}
|
||||
},
|
||||
# 2. Content vector search
|
||||
{
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": vector_candidates,
|
||||
}
|
||||
}
|
||||
},
|
||||
# 3. Keyword (title + content) match and phrase search.
|
||||
{
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
"match": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
# The title fields are strongly discounted as they are included in the content.
|
||||
# It just acts as a minor boost
|
||||
"boost": 0.1,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"match_phrase": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"slop": 1,
|
||||
"boost": 0.2,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"match": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
"boost": 1.0,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"match_phrase": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"slop": 1,
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
]
|
||||
# pipeline weights.
|
||||
if (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
|
||||
):
|
||||
return [
|
||||
DocumentQuery._get_title_vector_similarity_search_query(
|
||||
query_vector, vector_candidates
|
||||
),
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_vector, vector_candidates
|
||||
),
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text
|
||||
),
|
||||
]
|
||||
elif (
|
||||
HYBRID_SEARCH_SUBQUERY_CONFIGURATION
|
||||
is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD
|
||||
):
|
||||
return [
|
||||
DocumentQuery._get_content_vector_similarity_search_query(
|
||||
query_vector, vector_candidates
|
||||
),
|
||||
DocumentQuery._get_title_content_combined_keyword_search_query(
|
||||
query_text
|
||||
),
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}"
|
||||
)
|
||||
|
||||
return hybrid_search_queries
|
||||
@staticmethod
|
||||
def _get_title_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"knn": {
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": vector_candidates,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_content_vector_similarity_search_query(
|
||||
query_vector: list[float],
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
query = {
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": vector_candidates,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = {
|
||||
"bool": {"filter": search_filters}
|
||||
}
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_title_content_combined_keyword_search_query(
|
||||
query_text: str,
|
||||
search_filters: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
query = {
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
"match": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
# The title fields are strongly discounted as they are included in the content.
|
||||
# It just acts as a minor boost
|
||||
"boost": 0.1,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"match_phrase": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"slop": 1,
|
||||
"boost": 0.2,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"match": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"operator": "or",
|
||||
"boost": 1.0,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"match_phrase": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
"slop": 1,
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
},
|
||||
],
|
||||
# Ensure at least one term from the query is present in the
|
||||
# document. This defaults to 1, unless a filter or must clause
|
||||
# is supplied, in which case it defaults to 0.
|
||||
"minimum_should_match": 1,
|
||||
}
|
||||
}
|
||||
|
||||
if search_filters is not None:
|
||||
query["bool"]["filter"] = search_filters
|
||||
|
||||
return query
|
||||
|
||||
@staticmethod
|
||||
def _get_search_filters(
|
||||
@@ -514,9 +817,8 @@ class DocumentQuery:
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
persona_id: int | None,
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -547,12 +849,12 @@ class DocumentQuery:
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
persona_id: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved.
|
||||
project_id_filter: If not None, only documents with this project ID
|
||||
in user projects will be retrieved. Additive — only applied
|
||||
when a knowledge scope already exists.
|
||||
persona_id_filter: If not None, only documents whose personas array
|
||||
contains this persona ID will be retrieved. Primary — creates
|
||||
a knowledge scope on its own.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
@@ -569,10 +871,6 @@ class DocumentQuery:
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
attached_document_ids: Document IDs explicitly attached to the
|
||||
assistant. If provided along with hierarchy_node_ids, documents
|
||||
matching EITHER criteria will be retrieved (OR logic).
|
||||
@@ -633,15 +931,6 @@ class DocumentQuery:
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
@@ -742,14 +1031,17 @@ class DocumentQuery:
|
||||
# assistant can see. When none are set the assistant searches
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user files
|
||||
# findable but must NOT trigger the restriction on their own (an agent
|
||||
# with no explicit knowledge should search everything).
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
has_knowledge_scope = (
|
||||
attached_document_ids
|
||||
or hierarchy_node_ids
|
||||
or user_file_ids
|
||||
or document_sets
|
||||
or persona_id_filter is not None
|
||||
)
|
||||
|
||||
if has_knowledge_scope:
|
||||
@@ -764,23 +1056,17 @@ class DocumentQuery:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_hierarchy_node_filter(hierarchy_node_ids)
|
||||
)
|
||||
if user_file_ids:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_file_id_filter(user_file_ids)
|
||||
)
|
||||
if document_sets:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_document_set_filter(document_sets)
|
||||
)
|
||||
# Additive: widen scope to also cover overflowing user files, but
|
||||
# only when an explicit restriction is already in effect.
|
||||
if project_id is not None:
|
||||
if persona_id_filter is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_user_project_filter(project_id)
|
||||
_get_persona_filter(persona_id_filter)
|
||||
)
|
||||
if persona_id is not None:
|
||||
if project_id_filter is not None:
|
||||
knowledge_filter["bool"]["should"].append(
|
||||
_get_persona_filter(persona_id)
|
||||
_get_user_project_filter(project_id_filter)
|
||||
)
|
||||
filter_clauses.append(knowledge_filter)
|
||||
|
||||
@@ -798,8 +1084,6 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
@@ -501,20 +501,31 @@ def query_vespa(
|
||||
response = http_client.post(SEARCH_ENDPOINT, json=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = "Failed to query Vespa"
|
||||
logger.error(
|
||||
f"{error_base}:\n"
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
f"Exception: {str(e)}"
|
||||
+ (
|
||||
f"\nResponse: {e.response.text}"
|
||||
if isinstance(e, httpx.HTTPStatusError)
|
||||
else ""
|
||||
)
|
||||
response_text = (
|
||||
e.response.text if isinstance(e, httpx.HTTPStatusError) else None
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
status_code = (
|
||||
e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None
|
||||
)
|
||||
yql_value = params.get("yql", "")
|
||||
yql_length = len(str(yql_value))
|
||||
|
||||
# Log each detail on its own line so log collectors capture them
|
||||
# as separate entries rather than truncating a single multiline msg
|
||||
logger.error(
|
||||
f"Failed to query Vespa | "
|
||||
f"status={status_code} | "
|
||||
f"yql_length={yql_length} | "
|
||||
f"exception={str(e)}"
|
||||
)
|
||||
if response_text:
|
||||
logger.error(f"Vespa error response: {response_text[:1000]}")
|
||||
logger.error(f"Vespa request URL: {e.request.url}")
|
||||
|
||||
# Re-raise with diagnostics so callers see what actually went wrong
|
||||
raise httpx.HTTPError(
|
||||
f"Failed to query Vespa (status={status_code}, " f"yql_length={yql_length})"
|
||||
) from e
|
||||
|
||||
response_json: dict[str, Any] = response.json()
|
||||
|
||||
|
||||
@@ -43,6 +43,22 @@ def build_vespa_filters(
|
||||
return ""
|
||||
return f"({' or '.join(eq_elems)})"
|
||||
|
||||
def _build_weighted_set_filter(key: str, vals: list[str] | None) -> str:
|
||||
"""Build a Vespa weightedSet filter for large value lists.
|
||||
|
||||
Uses Vespa's native weightedSet() operator instead of OR-chained
|
||||
'contains' clauses. This is critical for fields like
|
||||
access_control_list where a single user may have tens of thousands
|
||||
of ACL entries — OR clauses at that scale cause Vespa to reject
|
||||
the query with HTTP 400."""
|
||||
if not key or not vals:
|
||||
return ""
|
||||
filtered = [val for val in vals if val]
|
||||
if not filtered:
|
||||
return ""
|
||||
items = ", ".join(f'"{val}":1' for val in filtered)
|
||||
return f"weightedSet({key}, {{{items}}})"
|
||||
|
||||
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
|
||||
"""For an integer field filter.
|
||||
Returns a bare clause or ""."""
|
||||
@@ -157,11 +173,16 @@ def build_vespa_filters(
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
|
||||
|
||||
# ACL filters
|
||||
# ACL filters — use weightedSet for efficient matching against the
|
||||
# access_control_list weightedset<string> field. OR-chaining thousands
|
||||
# of 'contains' clauses causes Vespa to reject the query (HTTP 400)
|
||||
# for users with large numbers of external permission groups.
|
||||
if filters.access_control_list is not None:
|
||||
_append(
|
||||
filter_parts,
|
||||
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
|
||||
_build_weighted_set_filter(
|
||||
ACCESS_CONTROL_LIST, filters.access_control_list
|
||||
),
|
||||
)
|
||||
|
||||
# Source type filters
|
||||
@@ -178,31 +199,29 @@ def build_vespa_filters(
|
||||
]
|
||||
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
|
||||
|
||||
# Knowledge scope: explicit knowledge attachments (document_sets,
|
||||
# user_file_ids) restrict what an assistant can see. When none are
|
||||
# set, the assistant can see everything.
|
||||
# Knowledge scope: explicit knowledge attachments restrict what an
|
||||
# assistant can see. When none are set, the assistant can see
|
||||
# everything.
|
||||
#
|
||||
# project_id / persona_id are additive: they make overflowing user
|
||||
# files findable in Vespa but must NOT trigger the restriction on
|
||||
# their own (an agent with no explicit knowledge should search
|
||||
# everything).
|
||||
# persona_id_filter is a primary trigger — a persona with user files IS
|
||||
# explicit knowledge, so it can start a knowledge scope on its own.
|
||||
#
|
||||
# project_id_filter is additive — it widens the scope to also cover
|
||||
# overflowing project files but never restricts on its own (a chat
|
||||
# inside a project should still search team knowledge).
|
||||
knowledge_scope_parts: list[str] = []
|
||||
|
||||
_append(
|
||||
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))
|
||||
|
||||
user_file_ids_str = (
|
||||
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
|
||||
)
|
||||
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
|
||||
|
||||
# Only include project/persona scopes when an explicit knowledge
|
||||
# restriction is already in effect — they widen the scope to also
|
||||
# cover overflowing user files but never restrict on their own.
|
||||
# project_id_filter only widens an existing scope.
|
||||
if knowledge_scope_parts:
|
||||
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
|
||||
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
|
||||
_append(
|
||||
knowledge_scope_parts,
|
||||
_build_user_project_filter(filters.project_id_filter),
|
||||
)
|
||||
|
||||
if len(knowledge_scope_parts) > 1:
|
||||
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
|
||||
|
||||
@@ -35,6 +35,8 @@ class OnyxErrorCode(Enum):
|
||||
INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403)
|
||||
ADMIN_ONLY = ("ADMIN_ONLY", 403)
|
||||
EE_REQUIRED = ("EE_REQUIRED", 403)
|
||||
SINGLE_TENANT_ONLY = ("SINGLE_TENANT_ONLY", 403)
|
||||
ENV_VAR_GATED = ("ENV_VAR_GATED", 403)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validation / Bad Request (400)
|
||||
@@ -86,6 +88,7 @@ class OnyxErrorCode(Enum):
|
||||
SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503)
|
||||
BAD_GATEWAY = ("BAD_GATEWAY", 502)
|
||||
LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502)
|
||||
HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502)
|
||||
GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504)
|
||||
|
||||
def __init__(self, code: str, status_code: int) -> None:
|
||||
|
||||
@@ -38,17 +38,7 @@ def get_federated_retrieval_functions(
|
||||
source_types: list[DocumentSource] | None,
|
||||
document_set_names: list[str] | None,
|
||||
slack_context: SlackContext | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
) -> list[FederatedRetrievalInfo]:
|
||||
# When User Knowledge (user files) is the only knowledge source enabled,
|
||||
# skip federated connectors entirely. User Knowledge mode means the agent
|
||||
# should ONLY use uploaded files, not team connectors like Slack.
|
||||
if user_file_ids and not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping all federated connectors: User Knowledge mode enabled "
|
||||
f"with {len(user_file_ids)} user files and no document sets"
|
||||
)
|
||||
return []
|
||||
|
||||
# Check for Slack bot context first (regardless of user_id)
|
||||
if slack_context:
|
||||
|
||||
@@ -88,9 +88,13 @@ def summarize_image_with_error_handling(
|
||||
try:
|
||||
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
||||
except UnsupportedImageFormatError:
|
||||
magic_hex = image_data[:8].hex() if image_data else "empty"
|
||||
logger.info(
|
||||
"Skipping image summarization due to unsupported MIME type for %s",
|
||||
"Skipping image summarization due to unsupported MIME type "
|
||||
"for %s (magic_bytes=%s, size=%d bytes)",
|
||||
context_name,
|
||||
magic_hex,
|
||||
len(image_data),
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -134,9 +138,23 @@ def _summarize_image(
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Summarization failed. Messages: {messages}"
|
||||
error_msg = error_msg[:1024]
|
||||
raise ValueError(error_msg) from e
|
||||
# Extract structured details from LiteLLM exceptions when available,
|
||||
# rather than dumping the full messages payload (which contains base64
|
||||
# image data and produces enormous, unreadable error logs).
|
||||
str_e = str(e)
|
||||
if len(str_e) > 512:
|
||||
str_e = str_e[:512] + "... (truncated)"
|
||||
parts = [f"Summarization failed: {type(e).__name__}: {str_e}"]
|
||||
status_code = getattr(e, "status_code", None)
|
||||
llm_provider = getattr(e, "llm_provider", None)
|
||||
model = getattr(e, "model", None)
|
||||
if status_code is not None:
|
||||
parts.append(f"status_code={status_code}")
|
||||
if llm_provider is not None:
|
||||
parts.append(f"llm_provider={llm_provider}")
|
||||
if model is not None:
|
||||
parts.append(f"model={model}")
|
||||
raise ValueError(" | ".join(parts)) from e
|
||||
|
||||
|
||||
def _encode_image_for_llm_prompt(image_data: bytes) -> str:
|
||||
|
||||
@@ -23,45 +23,55 @@ from onyx.utils.timing import log_function_time
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return f"plaintext_{user_file_id}"
|
||||
def plaintext_file_name_for_id(file_id: str) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a file."""
|
||||
return f"plaintext_{file_id}"
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
|
||||
"""
|
||||
Store plaintext content for a user file in the file store.
|
||||
Store plaintext content for a file in the file store.
|
||||
|
||||
Args:
|
||||
user_file_id: The ID of the user file
|
||||
file_id: The ID of the file (user_file or artifact_file)
|
||||
plaintext_content: The plaintext content to store
|
||||
|
||||
Returns:
|
||||
bool: True if storage was successful, False otherwise
|
||||
"""
|
||||
# Skip empty content
|
||||
if not plaintext_content:
|
||||
return False
|
||||
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
plaintext_file_name = plaintext_file_name_for_id(file_id)
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
file_store.save_file(
|
||||
content=file_content,
|
||||
display_name=f"Plaintext for user file {user_file_id}",
|
||||
display_name=f"Plaintext for {file_id}",
|
||||
file_origin=FileOrigin.PLAINTEXT_CACHE,
|
||||
file_type="text/plain",
|
||||
file_id=plaintext_file_name,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
|
||||
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# --- Convenience wrappers for callers that use user-file UUIDs ---
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
return plaintext_file_name_for_id(str(user_file_id))
|
||||
|
||||
|
||||
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
|
||||
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
|
||||
return store_plaintext(str(user_file_id), plaintext_content)
|
||||
|
||||
|
||||
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
|
||||
"""Load a file directly from the file store using its file_record ID.
|
||||
|
||||
|
||||
0
backend/onyx/hooks/__init__.py
Normal file
0
backend/onyx/hooks/__init__.py
Normal file
26
backend/onyx/hooks/api_dependencies.py
Normal file
26
backend/onyx/hooks/api_dependencies.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def require_hook_enabled() -> None:
|
||||
"""FastAPI dependency that gates all hook management endpoints.
|
||||
|
||||
Hooks are only available in single-tenant / self-hosted deployments with
|
||||
HOOK_ENABLED=true explicitly set. Two layers of protection:
|
||||
1. MULTI_TENANT check — rejects even if HOOK_ENABLED is accidentally set true
|
||||
2. HOOK_ENABLED flag — explicit opt-in by the operator
|
||||
|
||||
Use as: Depends(require_hook_enabled)
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.SINGLE_TENANT_ONLY,
|
||||
"Hooks are not available in multi-tenant deployments",
|
||||
)
|
||||
if not HOOK_ENABLED:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.ENV_VAR_GATED,
|
||||
"Hooks are not enabled. Set HOOK_ENABLED=true to enable.",
|
||||
)
|
||||
330
backend/onyx/hooks/executor.py
Normal file
330
backend/onyx/hooks/executor.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Hook executor — calls a customer's external HTTP endpoint for a given hook point.
|
||||
|
||||
Usage (Celery tasks and FastAPI handlers):
|
||||
result = execute_hook(
|
||||
db_session=db_session,
|
||||
hook_point=HookPoint.QUERY_PROCESSING,
|
||||
payload={"query": "...", "user_email": "...", "chat_session_id": "..."},
|
||||
)
|
||||
|
||||
if isinstance(result, HookSkipped):
|
||||
# no active hook configured — continue with original behavior
|
||||
...
|
||||
elif isinstance(result, HookSoftFailed):
|
||||
# hook failed but fail strategy is SOFT — continue with original behavior
|
||||
...
|
||||
else:
|
||||
# result is the response payload dict from the customer's endpoint
|
||||
...
|
||||
|
||||
is_reachable update policy
|
||||
--------------------------
|
||||
``is_reachable`` on the Hook row is updated selectively — only when the outcome
|
||||
carries meaningful signal about physical reachability:
|
||||
|
||||
NetworkError (DNS, connection refused) → False (cannot reach the server)
|
||||
HTTP 401 / 403 → False (api_key revoked or invalid)
|
||||
TimeoutException → None (server may be slow, skip write)
|
||||
Other HTTP errors (4xx / 5xx) → None (server responded, skip write)
|
||||
Unknown exception → None (no signal, skip write)
|
||||
Non-JSON / non-dict response → None (server responded, skip write)
|
||||
Success (2xx, valid dict) → True (confirmed reachable)
|
||||
|
||||
None means "leave the current value unchanged" — no DB round-trip is made.
|
||||
|
||||
DB session design
|
||||
-----------------
|
||||
The executor uses three sessions:
|
||||
|
||||
1. Caller's session (db_session) — used only for the hook lookup read. All
|
||||
needed fields are extracted from the Hook object before the HTTP call, so
|
||||
the caller's session is not held open during the external HTTP request.
|
||||
|
||||
2. Log session — a separate short-lived session opened after the HTTP call
|
||||
completes to write the HookExecutionLog row on failure. Success runs are
|
||||
not recorded. Committed independently of everything else.
|
||||
|
||||
3. Reachable session — a second short-lived session to update is_reachable on
|
||||
the Hook. Kept separate from the log session so a concurrent hook deletion
|
||||
(which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
|
||||
prevent the execution log from being written. This update is best-effort.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.hook import create_hook_execution_log__no_commit
|
||||
from onyx.db.hook import get_non_deleted_hook_by_hook_point
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.utils import HOOKS_AVAILABLE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class HookSkipped:
|
||||
"""No active hook configured for this hook point."""
|
||||
|
||||
|
||||
class HookSoftFailed:
|
||||
"""Hook was called but failed with SOFT fail strategy — continuing."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _HttpOutcome(BaseModel):
|
||||
"""Structured result of an HTTP hook call, returned by _process_response."""
|
||||
|
||||
is_success: bool
|
||||
updated_is_reachable: (
|
||||
bool | None
|
||||
) # True/False = write to DB, None = unchanged (skip write)
|
||||
status_code: int | None
|
||||
error_message: str | None
|
||||
response_payload: dict[str, Any] | None
|
||||
|
||||
|
||||
def _lookup_hook(
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
) -> Hook | HookSkipped:
|
||||
"""Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.
|
||||
|
||||
No HTTP call is made and no DB writes are performed for any HookSkipped path.
|
||||
There is nothing to log and no reachability information to update.
|
||||
"""
|
||||
if not HOOKS_AVAILABLE:
|
||||
return HookSkipped()
|
||||
hook = get_non_deleted_hook_by_hook_point(
|
||||
db_session=db_session, hook_point=hook_point
|
||||
)
|
||||
if hook is None or not hook.is_active:
|
||||
return HookSkipped()
|
||||
if not hook.endpoint_url:
|
||||
return HookSkipped()
|
||||
return hook
|
||||
|
||||
|
||||
def _process_response(
|
||||
*,
|
||||
response: httpx.Response | None,
|
||||
exc: Exception | None,
|
||||
timeout: float,
|
||||
) -> _HttpOutcome:
|
||||
"""Process the result of an HTTP call and return a structured outcome.
|
||||
|
||||
Called after the client.post() try/except. If post() raised, exc is set and
|
||||
response is None. Otherwise response is set and exc is None. Handles
|
||||
raise_for_status(), JSON decoding, and the dict shape check.
|
||||
"""
|
||||
if exc is not None:
|
||||
if isinstance(exc, httpx.NetworkError):
|
||||
msg = f"Hook network error (endpoint unreachable): {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False,
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
if isinstance(exc, httpx.TimeoutException):
|
||||
msg = f"Hook timed out after {timeout}s: {exc}"
|
||||
logger.warning(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # timeout doesn't indicate unreachability
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
msg = f"Hook call failed: {exc}"
|
||||
logger.exception(msg, exc_info=exc)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # unknown error — don't make assumptions
|
||||
status_code=None,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError(
|
||||
"exactly one of response or exc must be non-None; both are None"
|
||||
)
|
||||
status_code = response.status_code
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
# 401/403 means the api_key has been revoked or is invalid — mark unreachable
|
||||
# so the operator knows to update it. All other HTTP errors keep is_reachable
|
||||
# as-is (server is up, the request just failed for application reasons).
|
||||
auth_failed = e.response.status_code in (401, 403)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=False if auth_failed else None,
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
try:
|
||||
response_payload = response.json()
|
||||
except (json.JSONDecodeError, httpx.DecodingError) as e:
|
||||
msg = f"Hook returned non-JSON response: {e}"
|
||||
logger.warning(msg, exc_info=e)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
|
||||
logger.warning(msg)
|
||||
return _HttpOutcome(
|
||||
is_success=False,
|
||||
updated_is_reachable=None, # server responded — reachability unchanged
|
||||
status_code=status_code,
|
||||
error_message=msg,
|
||||
response_payload=None,
|
||||
)
|
||||
|
||||
return _HttpOutcome(
|
||||
is_success=True,
|
||||
updated_is_reachable=True,
|
||||
status_code=status_code,
|
||||
error_message=None,
|
||||
response_payload=response_payload,
|
||||
)
|
||||
|
||||
|
||||
def _persist_result(
|
||||
*,
|
||||
hook_id: int,
|
||||
outcome: _HttpOutcome,
|
||||
duration_ms: int,
|
||||
) -> None:
|
||||
"""Write the execution log on failure and optionally update is_reachable, each
|
||||
in its own session so a failure in one does not affect the other."""
|
||||
# Only write the execution log on failure — success runs are not recorded.
|
||||
# Must not be skipped if the is_reachable update fails (e.g. hook concurrently
|
||||
# deleted between the initial lookup and here).
|
||||
if not outcome.is_success:
|
||||
try:
|
||||
with get_session_with_current_tenant() as log_session:
|
||||
create_hook_execution_log__no_commit(
|
||||
db_session=log_session,
|
||||
hook_id=hook_id,
|
||||
is_success=False,
|
||||
error_message=outcome.error_message,
|
||||
status_code=outcome.status_code,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
log_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to persist hook execution log for hook_id={hook_id}"
|
||||
)
|
||||
|
||||
# Update is_reachable separately — best-effort, non-critical.
|
||||
# None means the value is unchanged (set by the caller to skip the no-op write).
|
||||
# update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
|
||||
# concurrently deleted, so keep this isolated from the log write above.
|
||||
if outcome.updated_is_reachable is not None:
|
||||
try:
|
||||
with get_session_with_current_tenant() as reachable_session:
|
||||
update_hook__no_commit(
|
||||
db_session=reachable_session,
|
||||
hook_id=hook_id,
|
||||
is_reachable=outcome.updated_is_reachable,
|
||||
)
|
||||
reachable_session.commit()
|
||||
except Exception:
|
||||
logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def execute_hook(
|
||||
*,
|
||||
db_session: Session,
|
||||
hook_point: HookPoint,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
|
||||
"""Execute the hook for the given hook point synchronously."""
|
||||
hook = _lookup_hook(db_session, hook_point)
|
||||
if isinstance(hook, HookSkipped):
|
||||
return hook
|
||||
|
||||
timeout = hook.timeout_seconds
|
||||
hook_id = hook.id
|
||||
fail_strategy = hook.fail_strategy
|
||||
endpoint_url = hook.endpoint_url
|
||||
current_is_reachable: bool | None = hook.is_reachable
|
||||
if not endpoint_url:
|
||||
raise ValueError(
|
||||
f"hook_id={hook_id} is active but has no endpoint_url — "
|
||||
"active hooks without an endpoint_url must be rejected by _lookup_hook"
|
||||
)
|
||||
|
||||
start = time.monotonic()
|
||||
response: httpx.Response | None = None
|
||||
exc: Exception | None = None
|
||||
try:
|
||||
api_key: str | None = (
|
||||
hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
)
|
||||
headers: dict[str, str] = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
with httpx.Client(timeout=timeout) as client:
|
||||
response = client.post(endpoint_url, json=payload, headers=headers)
|
||||
except Exception as e:
|
||||
exc = e
|
||||
duration_ms = int((time.monotonic() - start) * 1000)
|
||||
|
||||
outcome = _process_response(response=response, exc=exc, timeout=timeout)
|
||||
# Skip the is_reachable write when the value would not change — avoids a
|
||||
# no-op DB round-trip on every call when the hook is already in the expected state.
|
||||
if outcome.updated_is_reachable == current_is_reachable:
|
||||
outcome = outcome.model_copy(update={"updated_is_reachable": None})
|
||||
_persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)
|
||||
|
||||
if not outcome.is_success:
|
||||
if fail_strategy == HookFailStrategy.HARD:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.HOOK_EXECUTION_FAILED,
|
||||
outcome.error_message or "Hook execution failed.",
|
||||
)
|
||||
logger.warning(
|
||||
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
|
||||
)
|
||||
return HookSoftFailed()
|
||||
if outcome.response_payload is None:
|
||||
raise ValueError(
|
||||
f"response_payload is None for successful hook call (hook_id={hook_id})"
|
||||
)
|
||||
return outcome.response_payload
|
||||
121
backend/onyx/hooks/models.py
Normal file
121
backend/onyx/hooks/models.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
from pydantic import SecretStr
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
|
||||
NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookCreateRequest(BaseModel):
|
||||
name: str = Field(min_length=1)
|
||||
hook_point: HookPoint
|
||||
endpoint_url: str = Field(min_length=1)
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default
|
||||
timeout_seconds: float | None = Field(
|
||||
default=None, gt=0
|
||||
) # if None, uses HookPointSpec default
|
||||
|
||||
@field_validator("name", "endpoint_url")
|
||||
@classmethod
|
||||
def no_whitespace_only(cls, v: str) -> str:
|
||||
if not v.strip():
|
||||
raise ValueError("cannot be whitespace-only.")
|
||||
return v
|
||||
|
||||
|
||||
class HookUpdateRequest(BaseModel):
|
||||
name: str | None = None
|
||||
endpoint_url: str | None = None
|
||||
api_key: NonEmptySecretStr | None = None
|
||||
fail_strategy: HookFailStrategy | None = None
|
||||
timeout_seconds: float | None = Field(default=None, gt=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def require_at_least_one_field(self) -> "HookUpdateRequest":
|
||||
if not self.model_fields_set:
|
||||
raise ValueError("At least one field must be provided for an update.")
|
||||
if "name" in self.model_fields_set and not (self.name or "").strip():
|
||||
raise ValueError("name cannot be cleared.")
|
||||
if (
|
||||
"endpoint_url" in self.model_fields_set
|
||||
and not (self.endpoint_url or "").strip()
|
||||
):
|
||||
raise ValueError("endpoint_url cannot be cleared.")
|
||||
if "fail_strategy" in self.model_fields_set and self.fail_strategy is None:
|
||||
raise ValueError(
|
||||
"fail_strategy cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None:
|
||||
raise ValueError(
|
||||
"timeout_seconds cannot be null; omit the field to leave it unchanged."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HookPointMetaResponse(BaseModel):
|
||||
hook_point: HookPoint
|
||||
display_name: str
|
||||
description: str
|
||||
docs_url: str | None
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
default_timeout_seconds: float
|
||||
default_fail_strategy: HookFailStrategy
|
||||
fail_hard_description: str
|
||||
|
||||
|
||||
class HookResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
hook_point: HookPoint
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
is_reachable: bool | None
|
||||
creator_email: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class HookValidateStatus(str, Enum):
|
||||
passed = "passed" # server responded (any status except 401/403)
|
||||
auth_failed = "auth_failed" # server responded with 401 or 403
|
||||
timeout = (
|
||||
"timeout" # TCP connected, but read/write timed out (server exists but slow)
|
||||
)
|
||||
cannot_connect = "cannot_connect" # could not connect to the server
|
||||
|
||||
|
||||
class HookValidateResponse(BaseModel):
|
||||
status: HookValidateStatus
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class HookExecutionRecord(BaseModel):
|
||||
error_message: str | None = None
|
||||
status_code: int | None = None
|
||||
duration_ms: int | None = None
|
||||
created_at: datetime
|
||||
0
backend/onyx/hooks/points/__init__.py
Normal file
0
backend/onyx/hooks/points/__init__.py
Normal file
75
backend/onyx/hooks/points/base.py
Normal file
75
backend/onyx/hooks/points/base.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Any
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
|
||||
|
||||
_REQUIRED_ATTRS = (
|
||||
"hook_point",
|
||||
"display_name",
|
||||
"description",
|
||||
"default_timeout_seconds",
|
||||
"fail_hard_description",
|
||||
"default_fail_strategy",
|
||||
"payload_model",
|
||||
"response_model",
|
||||
)
|
||||
|
||||
|
||||
class HookPointSpec:
|
||||
"""Static metadata and contract for a pipeline hook point.
|
||||
|
||||
Each concrete subclass represents exactly one hook point and is instantiated
|
||||
once at startup, registered in onyx.hooks.registry._REGISTRY. Prefer
|
||||
get_hook_point_spec() or get_all_specs() from the registry over direct
|
||||
instantiation.
|
||||
|
||||
Each hook point is a concrete subclass of this class. Onyx engineers
|
||||
own these definitions — customers never touch this code.
|
||||
|
||||
Subclasses must define all attributes as class-level constants.
|
||||
payload_model and response_model must be Pydantic BaseModel subclasses;
|
||||
input_schema and output_schema are derived from them automatically.
|
||||
"""
|
||||
|
||||
hook_point: HookPoint
|
||||
display_name: str
|
||||
description: str
|
||||
default_timeout_seconds: float
|
||||
fail_hard_description: str
|
||||
default_fail_strategy: HookFailStrategy
|
||||
docs_url: str | None = None
|
||||
|
||||
payload_model: ClassVar[type[BaseModel]]
|
||||
response_model: ClassVar[type[BaseModel]]
|
||||
|
||||
# Computed once at class definition time from payload_model / response_model.
|
||||
input_schema: ClassVar[dict[str, Any]]
|
||||
output_schema: ClassVar[dict[str, Any]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: object) -> None:
|
||||
"""Enforce that every concrete subclass declares all required class attributes.
|
||||
|
||||
Called automatically by Python whenever a class inherits from HookPointSpec.
|
||||
Abstract subclasses (those still carrying unimplemented abstract methods) are
|
||||
skipped — they are intermediate base classes and may not yet define everything.
|
||||
Only fully concrete subclasses are validated, ensuring a clear TypeError at
|
||||
import time rather than a confusing AttributeError at runtime.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
|
||||
if missing:
|
||||
raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
|
||||
for attr in ("payload_model", "response_model"):
|
||||
val = getattr(cls, attr, None)
|
||||
if val is None or not (
|
||||
isinstance(val, type) and issubclass(val, BaseModel)
|
||||
):
|
||||
raise TypeError(
|
||||
f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
|
||||
)
|
||||
cls.input_schema = cls.payload_model.model_json_schema()
|
||||
cls.output_schema = cls.response_model.model_json_schema()
|
||||
31
backend/onyx/hooks/points/document_ingestion.py
Normal file
31
backend/onyx/hooks/points/document_ingestion.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
# TODO(@Bo-Onyx): define payload and response fields
|
||||
class DocumentIngestionPayload(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionResponse(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIngestionSpec(HookPointSpec):
|
||||
"""Hook point that runs during document ingestion.
|
||||
|
||||
# TODO(@Bo-Onyx): define call site, input/output schema, and timeout budget.
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.DOCUMENT_INGESTION
|
||||
display_name = "Document Ingestion"
|
||||
description = "Runs during document ingestion. Allows filtering or transforming documents before indexing."
|
||||
default_timeout_seconds = 30.0
|
||||
fail_hard_description = "The document will not be indexed."
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
payload_model = DocumentIngestionPayload
|
||||
response_model = DocumentIngestionResponse
|
||||
70
backend/onyx/hooks/points/query_processing.py
Normal file
70
backend/onyx/hooks/points/query_processing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.db.enums import HookFailStrategy
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
|
||||
|
||||
class QueryProcessingPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: str = Field(description="The raw query string exactly as the user typed it.")
|
||||
user_email: str | None = Field(
|
||||
description="Email of the user submitting the query, or null if unauthenticated."
|
||||
)
|
||||
chat_session_id: str = Field(
|
||||
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingResponse(BaseModel):
|
||||
# Intentionally permissive — customer endpoints may return extra fields.
|
||||
query: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"The query to use in the pipeline. "
|
||||
"Null, empty string, or absent = reject the query."
|
||||
),
|
||||
)
|
||||
rejection_message: str | None = Field(
|
||||
default=None,
|
||||
description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
|
||||
)
|
||||
|
||||
|
||||
class QueryProcessingSpec(HookPointSpec):
|
||||
"""Hook point that runs on every user query before it enters the pipeline.
|
||||
|
||||
Call site: inside handle_stream_message_objects() in
|
||||
backend/onyx/chat/process_message.py, immediately after message_text is
|
||||
assigned from the request and before create_new_chat_message() saves it.
|
||||
|
||||
This is the earliest possible point in the query pipeline:
|
||||
- Raw query — unmodified, exactly as the user typed it
|
||||
- No side effects yet — message has not been saved to DB
|
||||
- User identity is available for user-specific logic
|
||||
|
||||
Supported use cases:
|
||||
- Query rejection: block queries based on content or user context
|
||||
- Query rewriting: normalize, expand, or modify the query
|
||||
- PII removal: scrub sensitive data before the LLM sees it
|
||||
- Access control: reject queries from certain users or groups
|
||||
- Query auditing: log or track queries based on business rules
|
||||
"""
|
||||
|
||||
hook_point = HookPoint.QUERY_PROCESSING
|
||||
display_name = "Query Processing"
|
||||
description = (
|
||||
"Runs on every user query before it enters the pipeline. "
|
||||
"Allows rewriting, filtering, or rejecting queries."
|
||||
)
|
||||
default_timeout_seconds = 5.0 # user is actively waiting — keep tight
|
||||
fail_hard_description = (
|
||||
"The query will be blocked and the user will see an error message."
|
||||
)
|
||||
default_fail_strategy = HookFailStrategy.HARD
|
||||
|
||||
payload_model = QueryProcessingPayload
|
||||
response_model = QueryProcessingResponse
|
||||
45
backend/onyx/hooks/registry.py
Normal file
45
backend/onyx/hooks/registry.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.hooks.points.base import HookPointSpec
|
||||
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
|
||||
from onyx.hooks.points.query_processing import QueryProcessingSpec
|
||||
|
||||
# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
|
||||
_REGISTRY: dict[HookPoint, HookPointSpec] = {
|
||||
HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
|
||||
HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
|
||||
}
|
||||
|
||||
|
||||
def validate_registry() -> None:
|
||||
"""Assert that every HookPoint enum value has a registered spec.
|
||||
|
||||
Call once at application startup (e.g. from the FastAPI lifespan hook).
|
||||
Raises RuntimeError if any hook point is missing a spec.
|
||||
"""
|
||||
missing = set(HookPoint) - set(_REGISTRY)
|
||||
if missing:
|
||||
raise RuntimeError(
|
||||
f"Hook point(s) have no registered spec: {missing}. "
|
||||
"Add an entry to onyx.hooks.registry._REGISTRY."
|
||||
)
|
||||
|
||||
|
||||
def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
|
||||
"""Returns the spec for a given hook point.
|
||||
|
||||
Raises ValueError if the hook point has no registered spec — this is a
|
||||
programmer error; every HookPoint enum value must have a corresponding spec
|
||||
in _REGISTRY.
|
||||
"""
|
||||
try:
|
||||
return _REGISTRY[hook_point]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"No spec registered for hook point {hook_point!r}. "
|
||||
"Add an entry to onyx.hooks.registry._REGISTRY."
|
||||
)
|
||||
|
||||
|
||||
def get_all_specs() -> list[HookPointSpec]:
|
||||
"""Returns the specs for all registered hook points."""
|
||||
return list(_REGISTRY.values())
|
||||
5
backend/onyx/hooks/utils.py
Normal file
5
backend/onyx/hooks/utils.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from onyx.configs.app_configs import HOOK_ENABLED
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
# True only when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
|
||||
HOOKS_AVAILABLE: bool = HOOK_ENABLED and not MULTI_TENANT
|
||||
@@ -395,6 +395,12 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
llm = get_default_llm_with_vision()
|
||||
|
||||
if not llm:
|
||||
if get_image_extraction_and_analysis_enabled():
|
||||
logger.warning(
|
||||
"Image analysis is enabled but no vision-capable LLM is "
|
||||
"available — images will not be summarized. Configure a "
|
||||
"vision model in the admin LLM settings."
|
||||
)
|
||||
# Even without LLM, we still convert to IndexingDocument with base Sections
|
||||
return [
|
||||
IndexingDocument(
|
||||
|
||||
@@ -168,10 +168,23 @@ def get_default_llm_with_vision(
|
||||
if model_supports_image_input(
|
||||
default_model.name, default_model.llm_provider.provider
|
||||
):
|
||||
logger.info(
|
||||
"Using default vision model: %s (provider=%s)",
|
||||
default_model.name,
|
||||
default_model.llm_provider.provider,
|
||||
)
|
||||
return create_vision_llm(
|
||||
LLMProviderView.from_model(default_model.llm_provider),
|
||||
default_model.name,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Default vision model %s (provider=%s) does not support "
|
||||
"image input — falling back to searching all providers",
|
||||
default_model.name,
|
||||
default_model.llm_provider.provider,
|
||||
)
|
||||
|
||||
# Fall back to searching all providers
|
||||
models = fetch_existing_models(
|
||||
db_session=db_session,
|
||||
@@ -179,6 +192,10 @@ def get_default_llm_with_vision(
|
||||
)
|
||||
|
||||
if not models:
|
||||
logger.warning(
|
||||
"No LLM models with VISION or CHAT flow type found — "
|
||||
"image summarization will be disabled"
|
||||
)
|
||||
return None
|
||||
|
||||
for model in models:
|
||||
@@ -200,11 +217,25 @@ def get_default_llm_with_vision(
|
||||
|
||||
for model in sorted_models:
|
||||
if model_supports_image_input(model.name, model.llm_provider.provider):
|
||||
logger.info(
|
||||
"Using fallback vision model: %s (provider=%s)",
|
||||
model.name,
|
||||
model.llm_provider.provider,
|
||||
)
|
||||
return create_vision_llm(
|
||||
provider_map[model.llm_provider_id],
|
||||
model.name,
|
||||
)
|
||||
|
||||
checked_models = [
|
||||
f"{m.name} (provider={m.llm_provider.provider})" for m in sorted_models
|
||||
]
|
||||
logger.warning(
|
||||
"No vision-capable model found among %d candidates: %s — "
|
||||
"image summarization will be disabled",
|
||||
len(sorted_models),
|
||||
", ".join(checked_models),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -530,6 +530,11 @@ class LitellmLLM(LLM):
|
||||
):
|
||||
messages = _strip_tool_content_from_messages(messages)
|
||||
|
||||
# Only pass tool_choice when tools are present — some providers (e.g. Fireworks)
|
||||
# reject requests where tool_choice is explicitly null.
|
||||
if tools and tool_choice is not None:
|
||||
optional_kwargs["tool_choice"] = tool_choice
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
@@ -538,7 +543,6 @@ class LitellmLLM(LLM):
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
|
||||
@@ -219,13 +219,26 @@ def litellm_exception_to_error_msg(
|
||||
"ratelimiterror"
|
||||
):
|
||||
upstream_detail = upstream_detail.split(":", 1)[1].strip()
|
||||
error_msg = (
|
||||
f"{provider_name} rate limit: {upstream_detail}"
|
||||
if upstream_detail
|
||||
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
|
||||
)
|
||||
error_code = "RATE_LIMIT"
|
||||
is_retryable = True
|
||||
upstream_detail_lower = upstream_detail.lower()
|
||||
if (
|
||||
"insufficient_quota" in upstream_detail_lower
|
||||
or "exceeded your current quota" in upstream_detail_lower
|
||||
):
|
||||
error_msg = (
|
||||
f"{provider_name} quota exceeded: {upstream_detail}"
|
||||
if upstream_detail
|
||||
else f"{provider_name} quota exceeded: Verify billing and quota for this API key."
|
||||
)
|
||||
error_code = "BUDGET_EXCEEDED"
|
||||
is_retryable = False
|
||||
else:
|
||||
error_msg = (
|
||||
f"{provider_name} rate limit: {upstream_detail}"
|
||||
if upstream_detail
|
||||
else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later."
|
||||
)
|
||||
error_code = "RATE_LIMIT"
|
||||
is_retryable = True
|
||||
elif isinstance(core_exception, ServiceUnavailableError):
|
||||
provider_name = (
|
||||
llm.config.model_provider
|
||||
|
||||
@@ -62,6 +62,7 @@ from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.error_handling.exceptions import register_onyx_exception_handlers
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.hooks.registry import validate_registry
|
||||
from onyx.server.api_key.api import router as api_key_router
|
||||
from onyx.server.auth_check import check_router_auth
|
||||
from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
@@ -76,6 +77,7 @@ from onyx.server.features.default_assistant.api import (
|
||||
)
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.hierarchy.api import router as hierarchy_router
|
||||
from onyx.server.features.hooks.api import router as hook_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
admin_router as admin_input_prompt_router,
|
||||
)
|
||||
@@ -308,6 +310,7 @@ def validate_no_vector_db_settings() -> None:
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
validate_cache_backend_settings()
|
||||
validate_registry()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
@@ -451,6 +454,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
register_onyx_exception_handlers(application)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, hook_router)
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
@@ -479,7 +479,9 @@ def is_zip_file(file: UploadFile) -> bool:
|
||||
|
||||
|
||||
def upload_files(
|
||||
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
|
||||
files: list[UploadFile],
|
||||
file_origin: FileOrigin = FileOrigin.CONNECTOR,
|
||||
unzip: bool = True,
|
||||
) -> FileUploadResponse:
|
||||
|
||||
# Skip directories and known macOS metadata entries
|
||||
@@ -502,31 +504,46 @@ def upload_files(
|
||||
if seen_zip:
|
||||
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
|
||||
seen_zip = True
|
||||
|
||||
# Validate the zip by opening it (catches corrupt/non-zip files)
|
||||
with zipfile.ZipFile(file.file, "r") as zf:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
zf, file_store
|
||||
)
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=file_origin,
|
||||
file_type=mime_type,
|
||||
if unzip:
|
||||
zip_metadata_file_id = save_zip_metadata_to_file_store(
|
||||
zf, file_store
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_id = file_store.save_file(
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=file_origin,
|
||||
file_type=mime_type,
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
continue
|
||||
|
||||
# Store the zip as-is (unzip=False)
|
||||
file.file.seek(0)
|
||||
file_id = file_store.save_file(
|
||||
content=file.file,
|
||||
display_name=file.filename,
|
||||
file_origin=file_origin,
|
||||
file_type=file.content_type or "application/zip",
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(file.filename)
|
||||
continue
|
||||
|
||||
# Since we can't render docx files in the UI,
|
||||
@@ -613,9 +630,10 @@ def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
unzip: bool = True,
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files, FileOrigin.OTHER)
|
||||
return upload_files(files, FileOrigin.OTHER, unzip=unzip)
|
||||
|
||||
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
@@ -1319,7 +1337,7 @@ def get_connector_indexing_status(
|
||||
# Track admin page visit for analytics
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.VISITED_ADMIN_PAGE,
|
||||
)
|
||||
|
||||
@@ -1533,7 +1551,7 @@ def create_connector_from_model(
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.CREATED_CONNECTOR,
|
||||
)
|
||||
|
||||
@@ -1611,7 +1629,7 @@ def create_connector_with_mock_credential(
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.CREATED_CONNECTOR,
|
||||
)
|
||||
return response
|
||||
@@ -1915,9 +1933,7 @@ def submit_connector_request(
|
||||
if not connector_name:
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email
|
||||
distinct_id = user_email or tenant_id
|
||||
|
||||
# Track connector request via PostHog telemetry (Cloud only)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -1925,11 +1941,11 @@ def submit_connector_request(
|
||||
if MULTI_TENANT:
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=distinct_id,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.REQUESTED_CONNECTOR,
|
||||
properties={
|
||||
"connector_name": connector_name,
|
||||
"user_email": user_email,
|
||||
"user_email": user.email,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -408,7 +408,7 @@ class FailedConnectorIndexingStatus(BaseModel):
|
||||
"""Simplified version of ConnectorIndexingStatus for failed indexing attempts"""
|
||||
|
||||
cc_pair_id: int
|
||||
name: str | None
|
||||
name: str
|
||||
error_msg: str | None
|
||||
is_deletable: bool
|
||||
connector_id: int
|
||||
@@ -422,7 +422,7 @@ class ConnectorStatus(BaseModel):
|
||||
"""
|
||||
|
||||
cc_pair_id: int
|
||||
name: str | None
|
||||
name: str
|
||||
connector: ConnectorSnapshot
|
||||
credential: CredentialSnapshot
|
||||
access_type: AccessType
|
||||
@@ -453,7 +453,7 @@ class DocsCountOperator(str, Enum):
|
||||
|
||||
class ConnectorIndexingStatusLite(BaseModel):
|
||||
cc_pair_id: int
|
||||
name: str | None
|
||||
name: str
|
||||
source: DocumentSource
|
||||
access_type: AccessType
|
||||
cc_pair_status: ConnectorCredentialPairStatus
|
||||
@@ -488,7 +488,7 @@ class ConnectorCredentialPairIdentifier(BaseModel):
|
||||
|
||||
|
||||
class ConnectorCredentialPairMetadata(BaseModel):
|
||||
name: str | None = None
|
||||
name: str
|
||||
access_type: AccessType
|
||||
auto_sync_options: dict[str, Any] | None = None
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
@@ -501,7 +501,7 @@ class CCStatusUpdateRequest(BaseModel):
|
||||
|
||||
class ConnectorCredentialPairDescriptor(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
name: str
|
||||
connector: ConnectorSnapshot
|
||||
credential: CredentialSnapshot
|
||||
access_type: AccessType
|
||||
@@ -511,7 +511,7 @@ class CCPairSummary(BaseModel):
|
||||
"""Simplified connector-credential pair information with just essential data"""
|
||||
|
||||
id: int
|
||||
name: str | None
|
||||
name: str
|
||||
source: DocumentSource
|
||||
access_type: AccessType
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.1.5",
|
||||
"next": "16.1.7",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
@@ -1711,9 +1711,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/env": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.5.tgz",
|
||||
"integrity": "sha512-CRSCPJiSZoi4Pn69RYBDI9R7YK2g59vLexPQFXY0eyw+ILevIenCywzg+DqmlBik9zszEnw2HLFOUlLAcJbL7g==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/env/-/env-16.1.7.tgz",
|
||||
"integrity": "sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
@@ -1727,9 +1727,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-arm64": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.5.tgz",
|
||||
"integrity": "sha512-eK7Wdm3Hjy/SCL7TevlH0C9chrpeOYWx2iR7guJDaz4zEQKWcS1IMVfMb9UKBFMg1XgzcPTYPIp1Vcpukkjg6Q==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-16.1.7.tgz",
|
||||
"integrity": "sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1743,9 +1743,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-darwin-x64": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.5.tgz",
|
||||
"integrity": "sha512-foQscSHD1dCuxBmGkbIr6ScAUF6pRoDZP6czajyvmXPAOFNnQUJu2Os1SGELODjKp/ULa4fulnBWoHV3XdPLfA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-darwin-x64/-/swc-darwin-x64-16.1.7.tgz",
|
||||
"integrity": "sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1759,9 +1759,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-gnu": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.5.tgz",
|
||||
"integrity": "sha512-qNIb42o3C02ccIeSeKjacF3HXotGsxh/FMk/rSRmCzOVMtoWH88odn2uZqF8RLsSUWHcAqTgYmPD3pZ03L9ZAA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1775,9 +1775,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-arm64-musl": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.5.tgz",
|
||||
"integrity": "sha512-U+kBxGUY1xMAzDTXmuVMfhaWUZQAwzRaHJ/I6ihtR5SbTVUEaDRiEU9YMjy1obBWpdOBuk1bcm+tsmifYSygfw==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1791,9 +1791,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-gnu": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.5.tgz",
|
||||
"integrity": "sha512-gq2UtoCpN7Ke/7tKaU7i/1L7eFLfhMbXjNghSv0MVGF1dmuoaPeEVDvkDuO/9LVa44h5gqpWeJ4mRRznjDv7LA==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-16.1.7.tgz",
|
||||
"integrity": "sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1807,9 +1807,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-linux-x64-musl": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.5.tgz",
|
||||
"integrity": "sha512-bQWSE729PbXT6mMklWLf8dotislPle2L70E9q6iwETYEOt092GDn0c+TTNj26AjmeceSsC4ndyGsK5nKqHYXjQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-16.1.7.tgz",
|
||||
"integrity": "sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -1823,9 +1823,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-arm64-msvc": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.5.tgz",
|
||||
"integrity": "sha512-LZli0anutkIllMtTAWZlDqdfvjWX/ch8AFK5WgkNTvaqwlouiD1oHM+WW8RXMiL0+vAkAJyAGEzPPjO+hnrSNQ==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==",
|
||||
"cpu": [
|
||||
"arm64"
|
||||
],
|
||||
@@ -1839,9 +1839,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-x64-msvc": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.5.tgz",
|
||||
"integrity": "sha512-7is37HJTNQGhjPpQbkKjKEboHYQnCgpVt/4rBrrln0D9nderNxZ8ZWs8w1fAtzUx7wEyYjQ+/13myFgFj6K2Ng==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-16.1.7.tgz",
|
||||
"integrity": "sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==",
|
||||
"cpu": [
|
||||
"x64"
|
||||
],
|
||||
@@ -4971,12 +4971,15 @@
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/baseline-browser-mapping": {
|
||||
"version": "2.9.17",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.9.17.tgz",
|
||||
"integrity": "sha512-agD0MgJFUP/4nvjqzIB29zRPUuCF7Ge6mEv9s8dHrtYD7QWXRcx75rOADE/d5ah1NI+0vkDl0yorDd5U852IQQ==",
|
||||
"version": "2.10.8",
|
||||
"resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.10.8.tgz",
|
||||
"integrity": "sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==",
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"baseline-browser-mapping": "dist/cli.js"
|
||||
"baseline-browser-mapping": "dist/cli.cjs"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=6.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/body-parser": {
|
||||
@@ -8975,14 +8978,14 @@
|
||||
}
|
||||
},
|
||||
"node_modules/next": {
|
||||
"version": "16.1.5",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.5.tgz",
|
||||
"integrity": "sha512-f+wE+NSbiQgh3DSAlTaw2FwY5yGdVViAtp8TotNQj4kk4Q8Bh1sC/aL9aH+Rg1YAVn18OYXsRDT7U/079jgP7w==",
|
||||
"version": "16.1.7",
|
||||
"resolved": "https://registry.npmjs.org/next/-/next-16.1.7.tgz",
|
||||
"integrity": "sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/env": "16.1.5",
|
||||
"@next/env": "16.1.7",
|
||||
"@swc/helpers": "0.5.15",
|
||||
"baseline-browser-mapping": "^2.8.3",
|
||||
"baseline-browser-mapping": "^2.9.19",
|
||||
"caniuse-lite": "^1.0.30001579",
|
||||
"postcss": "8.4.31",
|
||||
"styled-jsx": "5.1.6"
|
||||
@@ -8994,14 +8997,14 @@
|
||||
"node": ">=20.9.0"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@next/swc-darwin-arm64": "16.1.5",
|
||||
"@next/swc-darwin-x64": "16.1.5",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.5",
|
||||
"@next/swc-linux-arm64-musl": "16.1.5",
|
||||
"@next/swc-linux-x64-gnu": "16.1.5",
|
||||
"@next/swc-linux-x64-musl": "16.1.5",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.5",
|
||||
"@next/swc-win32-x64-msvc": "16.1.5",
|
||||
"@next/swc-darwin-arm64": "16.1.7",
|
||||
"@next/swc-darwin-x64": "16.1.7",
|
||||
"@next/swc-linux-arm64-gnu": "16.1.7",
|
||||
"@next/swc-linux-arm64-musl": "16.1.7",
|
||||
"@next/swc-linux-x64-gnu": "16.1.7",
|
||||
"@next/swc-linux-x64-musl": "16.1.7",
|
||||
"@next/swc-win32-arm64-msvc": "16.1.7",
|
||||
"@next/swc-win32-x64-msvc": "16.1.7",
|
||||
"sharp": "^0.34.4"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
"date-fns": "^4.1.0",
|
||||
"embla-carousel-react": "^8.6.0",
|
||||
"lucide-react": "^0.562.0",
|
||||
"next": "16.1.5",
|
||||
"next": "16.1.7",
|
||||
"next-themes": "^0.4.6",
|
||||
"radix-ui": "^1.4.3",
|
||||
"react": "19.2.3",
|
||||
|
||||
0
backend/onyx/server/features/hooks/__init__.py
Normal file
0
backend/onyx/server/features/hooks/__init__.py
Normal file
453
backend/onyx/server/features/hooks/api.py
Normal file
453
backend/onyx/server/features/hooks/api.py
Normal file
@@ -0,0 +1,453 @@
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.constants import UNSET
|
||||
from onyx.db.constants import UnsetType
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.hook import create_hook__no_commit
|
||||
from onyx.db.hook import delete_hook__no_commit
|
||||
from onyx.db.hook import get_hook_by_id
|
||||
from onyx.db.hook import get_hook_execution_logs
|
||||
from onyx.db.hook import get_hooks
|
||||
from onyx.db.hook import update_hook__no_commit
|
||||
from onyx.db.models import Hook
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.hooks.api_dependencies import require_hook_enabled
|
||||
from onyx.hooks.models import HookCreateRequest
|
||||
from onyx.hooks.models import HookExecutionRecord
|
||||
from onyx.hooks.models import HookPointMetaResponse
|
||||
from onyx.hooks.models import HookResponse
|
||||
from onyx.hooks.models import HookUpdateRequest
|
||||
from onyx.hooks.models import HookValidateResponse
|
||||
from onyx.hooks.models import HookValidateStatus
|
||||
from onyx.hooks.registry import get_all_specs
|
||||
from onyx.hooks.registry import get_hook_point_spec
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.url import SSRFException
|
||||
from onyx.utils.url import validate_outbound_http_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSRF protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _check_ssrf_safety(endpoint_url: str) -> None:
|
||||
"""Raise OnyxError if endpoint_url could be used for SSRF.
|
||||
|
||||
Delegates to validate_outbound_http_url with https_only=True.
|
||||
"""
|
||||
try:
|
||||
validate_outbound_http_url(endpoint_url, https_only=True)
|
||||
except (SSRFException, ValueError) as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse:
|
||||
return HookResponse(
|
||||
id=hook.id,
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
is_reachable=hook.is_reachable,
|
||||
creator_email=(
|
||||
creator_email
|
||||
if creator_email is not None
|
||||
else (hook.creator.email if hook.creator else None)
|
||||
),
|
||||
created_at=hook.created_at,
|
||||
updated_at=hook.updated_at,
|
||||
)
|
||||
|
||||
|
||||
def _get_hook_or_404(
|
||||
db_session: Session,
|
||||
hook_id: int,
|
||||
include_creator: bool = False,
|
||||
) -> Hook:
|
||||
hook = get_hook_by_id(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
include_creator=include_creator,
|
||||
)
|
||||
if hook is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.")
|
||||
return hook
|
||||
|
||||
|
||||
def _raise_for_validation_failure(validation: HookValidateResponse) -> None:
|
||||
"""Raise an appropriate OnyxError for a non-passed validation result."""
|
||||
if validation.status == HookValidateStatus.auth_failed:
|
||||
raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message)
|
||||
if validation.status == HookValidateStatus.timeout:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.GATEWAY_TIMEOUT,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
f"Endpoint validation failed: {validation.error_message}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_endpoint(
|
||||
endpoint_url: str,
|
||||
api_key: str | None,
|
||||
timeout_seconds: float,
|
||||
) -> HookValidateResponse:
|
||||
"""Check whether endpoint_url is reachable by sending an empty POST request.
|
||||
|
||||
We use POST since hook endpoints expect POST requests. The server will typically
|
||||
respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means
|
||||
the server is up and routable. A 401/403 response returns auth_failed
|
||||
(not reachable — indicates the api_key is invalid).
|
||||
|
||||
Timeout handling:
|
||||
- ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
- ReadTimeout / WriteTimeout: TCP was established, server responded slowly → timeout
|
||||
(operator should consider increasing timeout_seconds).
|
||||
- All other exceptions → cannot_connect.
|
||||
"""
|
||||
_check_ssrf_safety(endpoint_url)
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
try:
|
||||
with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client:
|
||||
response = client.post(endpoint_url, headers=headers)
|
||||
if response.status_code in (401, 403):
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.auth_failed,
|
||||
error_message=f"Authentication failed (HTTP {response.status_code})",
|
||||
)
|
||||
return HookValidateResponse(status=HookValidateStatus.passed)
|
||||
except httpx.TimeoutException as exc:
|
||||
# ConnectTimeout: TCP handshake never completed → cannot_connect.
|
||||
# ReadTimeout / WriteTimeout: TCP was established, server just responded slowly → timeout.
|
||||
if isinstance(exc, httpx.ConnectTimeout):
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connect timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
logger.warning(
|
||||
"Hook endpoint validation: read/write timeout for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.timeout,
|
||||
error_message="Endpoint timed out — consider increasing timeout_seconds.",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook endpoint validation: connection error for %s",
|
||||
endpoint_url,
|
||||
exc_info=exc,
|
||||
)
|
||||
return HookValidateResponse(
|
||||
status=HookValidateStatus.cannot_connect, error_message=str(exc)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
router = APIRouter(prefix="/admin/hooks")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/specs")
|
||||
def get_hook_point_specs(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
) -> list[HookPointMetaResponse]:
|
||||
return [
|
||||
HookPointMetaResponse(
|
||||
hook_point=spec.hook_point,
|
||||
display_name=spec.display_name,
|
||||
description=spec.description,
|
||||
docs_url=spec.docs_url,
|
||||
input_schema=spec.input_schema,
|
||||
output_schema=spec.output_schema,
|
||||
default_timeout_seconds=spec.default_timeout_seconds,
|
||||
default_fail_strategy=spec.default_fail_strategy,
|
||||
fail_hard_description=spec.fail_hard_description,
|
||||
)
|
||||
for spec in get_all_specs()
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_hooks(
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookResponse]:
|
||||
hooks = get_hooks(db_session=db_session, include_creator=True)
|
||||
return [_hook_to_response(h) for h in hooks]
|
||||
|
||||
|
||||
@router.post("")
|
||||
def create_hook(
|
||||
req: HookCreateRequest,
|
||||
user: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Create a new hook. The endpoint is validated before persisting — creation fails if
|
||||
the endpoint cannot be reached or the api_key is invalid. Hooks are created inactive;
|
||||
use POST /{hook_id}/activate once ready to receive traffic."""
|
||||
spec = get_hook_point_spec(req.hook_point)
|
||||
api_key = req.api_key.get_secret_value() if req.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = create_hook__no_commit(
|
||||
db_session=db_session,
|
||||
name=req.name,
|
||||
hook_point=req.hook_point,
|
||||
endpoint_url=req.endpoint_url,
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy or spec.default_fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds,
|
||||
creator_id=user.id,
|
||||
)
|
||||
hook.is_reachable = True
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook, creator_email=user.email)
|
||||
|
||||
|
||||
@router.get("/{hook_id}")
|
||||
def get_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id, include_creator=True)
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.patch("/{hook_id}")
|
||||
def update_hook(
|
||||
hook_id: int,
|
||||
req: HookUpdateRequest,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
"""Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the
|
||||
endpoint is re-validated using the effective values. For active hooks the update is
|
||||
rejected on validation failure, keeping live traffic unaffected. For inactive hooks
|
||||
the update goes through regardless and is_reachable is updated to reflect the result.
|
||||
|
||||
Note: if an active hook's endpoint is currently down, even a timeout_seconds-only
|
||||
increase will be rejected. The recovery flow is: deactivate → update → reactivate.
|
||||
"""
|
||||
# api_key: UNSET = no change, None = clear, value = update
|
||||
api_key: str | None | UnsetType
|
||||
if "api_key" not in req.model_fields_set:
|
||||
api_key = UNSET
|
||||
elif req.api_key is None:
|
||||
api_key = None
|
||||
else:
|
||||
api_key = req.api_key.get_secret_value()
|
||||
|
||||
endpoint_url_changing = "endpoint_url" in req.model_fields_set
|
||||
api_key_changing = not isinstance(api_key, UnsetType)
|
||||
timeout_changing = "timeout_seconds" in req.model_fields_set
|
||||
|
||||
validated_is_reachable: bool | None = None
|
||||
if endpoint_url_changing or api_key_changing or timeout_changing:
|
||||
existing = _get_hook_or_404(db_session, hook_id)
|
||||
effective_url: str = (
|
||||
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
|
||||
)
|
||||
effective_api_key: str | None = (
|
||||
(api_key if not isinstance(api_key, UnsetType) else None)
|
||||
if api_key_changing
|
||||
else (
|
||||
existing.api_key.get_value(apply_mask=False)
|
||||
if existing.api_key
|
||||
else None
|
||||
)
|
||||
)
|
||||
effective_timeout: float = (
|
||||
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
)
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=effective_url,
|
||||
api_key=effective_api_key,
|
||||
timeout_seconds=effective_timeout,
|
||||
)
|
||||
if existing.is_active and validation.status != HookValidateStatus.passed:
|
||||
_raise_for_validation_failure(validation)
|
||||
validated_is_reachable = validation.status == HookValidateStatus.passed
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
name=req.name,
|
||||
endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET),
|
||||
api_key=api_key,
|
||||
fail_strategy=req.fail_strategy,
|
||||
timeout_seconds=req.timeout_seconds,
|
||||
is_reachable=validated_is_reachable,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.delete("/{hook_id}")
|
||||
def delete_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
delete_hook__no_commit(db_session=db_session, hook_id=hook_id)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/{hook_id}/activate")
|
||||
def activate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
if validation.status != HookValidateStatus.passed:
|
||||
# Persist is_reachable=False in a separate session so the request
|
||||
# session has no commits on the failure path and the transaction
|
||||
# boundary stays clean.
|
||||
if hook.is_reachable is not False:
|
||||
with get_session_with_current_tenant() as side_session:
|
||||
update_hook__no_commit(
|
||||
db_session=side_session, hook_id=hook_id, is_reachable=False
|
||||
)
|
||||
side_session.commit()
|
||||
_raise_for_validation_failure(validation)
|
||||
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=True,
|
||||
is_reachable=True,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
@router.post("/{hook_id}/validate")
|
||||
def validate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookValidateResponse:
|
||||
hook = _get_hook_or_404(db_session, hook_id)
|
||||
if not hook.endpoint_url:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured."
|
||||
)
|
||||
|
||||
api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key=api_key,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
)
|
||||
validation_passed = validation.status == HookValidateStatus.passed
|
||||
if hook.is_reachable != validation_passed:
|
||||
update_hook__no_commit(
|
||||
db_session=db_session, hook_id=hook_id, is_reachable=validation_passed
|
||||
)
|
||||
db_session.commit()
|
||||
return validation
|
||||
|
||||
|
||||
@router.post("/{hook_id}/deactivate")
|
||||
def deactivate_hook(
|
||||
hook_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HookResponse:
|
||||
hook = update_hook__no_commit(
|
||||
db_session=db_session,
|
||||
hook_id=hook_id,
|
||||
is_active=False,
|
||||
include_creator=True,
|
||||
)
|
||||
db_session.commit()
|
||||
return _hook_to_response(hook)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Execution log endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{hook_id}/execution-logs")
|
||||
def list_hook_execution_logs(
|
||||
hook_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
_: User = Depends(current_admin_user),
|
||||
_hook_enabled: None = Depends(require_hook_enabled),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[HookExecutionRecord]:
|
||||
_get_hook_or_404(db_session, hook_id)
|
||||
logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit)
|
||||
return [
|
||||
HookExecutionRecord(
|
||||
error_message=log.error_message,
|
||||
status_code=log.status_code,
|
||||
duration_ms=log.duration_ms,
|
||||
created_at=log.created_at,
|
||||
)
|
||||
for log in logs
|
||||
]
|
||||
@@ -314,7 +314,7 @@ def create_persona(
|
||||
)
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email,
|
||||
distinct_id=str(user.id),
|
||||
event=MilestoneRecordType.CREATED_ASSISTANT,
|
||||
)
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.server.manage.llm.utils import generate_bedrock_display_name
|
||||
from onyx.server.manage.llm.utils import generate_ollama_display_name
|
||||
from onyx.server.manage.llm.utils import infer_vision_support
|
||||
from onyx.server.manage.llm.utils import is_embedding_model
|
||||
from onyx.server.manage.llm.utils import is_reasoning_model
|
||||
from onyx.server.manage.llm.utils import is_valid_bedrock_model
|
||||
from onyx.server.manage.llm.utils import ModelMetadata
|
||||
@@ -1374,6 +1375,10 @@ def get_litellm_available_models(
|
||||
try:
|
||||
model_details = LitellmModelDetails.model_validate(model)
|
||||
|
||||
# Skip embedding models
|
||||
if is_embedding_model(model_details.id):
|
||||
continue
|
||||
|
||||
results.append(
|
||||
LitellmFinalModelResponse(
|
||||
provider_name=model_details.owned_by,
|
||||
|
||||
@@ -366,3 +366,18 @@ def extract_vendor_from_model_name(model_name: str, provider: str) -> str | None
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_embedding_model(model_name: str) -> bool:
|
||||
"""Checks for if a model is an embedding model"""
|
||||
from litellm import get_model_info
|
||||
|
||||
try:
|
||||
# get_model_info raises on unknown models
|
||||
# default to False
|
||||
model_info = get_model_info(model_name)
|
||||
except Exception:
|
||||
return False
|
||||
is_embedding_mode = model_info.get("mode") == "embedding"
|
||||
|
||||
return is_embedding_mode
|
||||
|
||||
@@ -561,7 +561,7 @@ def handle_send_chat_message(
|
||||
tenant_id = get_current_tenant_id()
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=tenant_id if user.is_anonymous else user.email,
|
||||
distinct_id=tenant_id if user.is_anonymous else str(user.id),
|
||||
event=MilestoneRecordType.RAN_QUERY,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlDocuments
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolDocumentsDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolQueriesDelta
|
||||
from onyx.server.query_and_chat.streaming_models import SearchToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
|
||||
|
||||
_CANNOT_SHOW_STEP_RESULTS_STR = "[Cannot display step results]"
|
||||
|
||||
|
||||
def _adjust_message_text_for_agent_search_results(
|
||||
adjusted_message_text: str,
|
||||
final_documents: list[SavedSearchDoc], # noqa: ARG001
|
||||
) -> str:
|
||||
# Remove all [Q<integer>] patterns (sub-question citations)
|
||||
return re.sub(r"\[Q\d+\]", "", adjusted_message_text)
|
||||
|
||||
|
||||
def _replace_d_citations_with_links(
|
||||
message_text: str, final_documents: list[SavedSearchDoc]
|
||||
) -> str:
|
||||
def replace_citation(match: re.Match[str]) -> str:
|
||||
d_number = match.group(1)
|
||||
try:
|
||||
doc_index = int(d_number) - 1
|
||||
if 0 <= doc_index < len(final_documents):
|
||||
doc = final_documents[doc_index]
|
||||
link = doc.link if doc.link else ""
|
||||
return f"[[{d_number}]]({link})"
|
||||
return match.group(0)
|
||||
except (ValueError, IndexError):
|
||||
return match.group(0)
|
||||
|
||||
return re.sub(r"\[D(\d+)\]", replace_citation, message_text)
|
||||
|
||||
|
||||
def create_message_packets(
|
||||
message_text: str,
|
||||
final_documents: list[SavedSearchDoc] | None,
|
||||
turn_index: int,
|
||||
is_legacy_agentic: bool = False,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=SearchDoc.from_saved_search_docs(final_documents or []),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
adjusted_message_text = message_text
|
||||
if is_legacy_agentic:
|
||||
if final_documents is not None:
|
||||
adjusted_message_text = _adjust_message_text_for_agent_search_results(
|
||||
message_text, final_documents
|
||||
)
|
||||
adjusted_message_text = _replace_d_citations_with_links(
|
||||
adjusted_message_text, final_documents
|
||||
)
|
||||
else:
|
||||
adjusted_message_text = re.sub(r"\[Q\d+\]", "", message_text)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=AgentResponseDelta(
|
||||
content=adjusted_message_text,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=SectionEnd(),
|
||||
)
|
||||
)
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_citation_packets(
|
||||
citation_info_list: list[CitationInfo], turn_index: int
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
# Emit each citation as a separate CitationInfo packet
|
||||
for citation_info in citation_info_list:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=citation_info,
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_reasoning_packets(reasoning_text: str, turn_index: int) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(placement=Placement(turn_index=turn_index), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=ReasoningDelta(
|
||||
reasoning=reasoning_text,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_image_generation_packets(
|
||||
images: list[GeneratedImage], turn_index: int
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=ImageGenerationToolStart(),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=ImageGenerationFinal(images=images),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
|
||||
|
||||
def create_fetch_packets(
|
||||
fetch_docs: list[SavedSearchDoc],
|
||||
urls: list[str],
|
||||
turn_index: int,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
# Emit start packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=OpenUrlStart(),
|
||||
)
|
||||
)
|
||||
# Emit URLs packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=OpenUrlUrls(urls=urls),
|
||||
)
|
||||
)
|
||||
# Emit documents packet
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=OpenUrlDocuments(
|
||||
documents=SearchDoc.from_saved_search_docs(fetch_docs)
|
||||
),
|
||||
)
|
||||
)
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
return packets
|
||||
|
||||
|
||||
def create_search_packets(
|
||||
search_queries: list[str],
|
||||
saved_search_docs: list[SavedSearchDoc],
|
||||
is_internet_search: bool,
|
||||
turn_index: int,
|
||||
) -> list[Packet]:
|
||||
packets: list[Packet] = []
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=SearchToolStart(
|
||||
is_internet_search=is_internet_search,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Emit queries if present
|
||||
if search_queries:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=SearchToolQueriesDelta(queries=search_queries),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit documents if present
|
||||
if saved_search_docs:
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=Placement(turn_index=turn_index),
|
||||
obj=SearchToolDocumentsDelta(
|
||||
documents=SearchDoc.from_saved_search_docs(saved_search_docs)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=Placement(turn_index=turn_index), obj=SectionEnd()))
|
||||
|
||||
return packets
|
||||
@@ -53,8 +53,12 @@ logger = setup_logger()
|
||||
|
||||
class SearchToolConfig(BaseModel):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
project_id: int | None = None
|
||||
persona_id: int | None = None
|
||||
# Vespa metadata filters for overflowing user files. These are NOT the
|
||||
# IDs of the current project/persona — they are only set when the
|
||||
# project's/persona's user files didn't fit in the LLM context window and
|
||||
# must be found via vector DB search instead.
|
||||
project_id_filter: int | None = None
|
||||
persona_id_filter: int | None = None
|
||||
bypass_acl: bool = False
|
||||
additional_context: str | None = None
|
||||
slack_context: SlackContext | None = None
|
||||
@@ -180,8 +184,8 @@ def construct_tools(
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
@@ -428,8 +432,8 @@ def construct_tools(
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id=search_tool_config.project_id,
|
||||
persona_id=search_tool_config.persona_id,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
|
||||
@@ -764,8 +764,7 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
tags=None,
|
||||
access_control_list=access_control_list,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
user_file_ids=None,
|
||||
project_id=None,
|
||||
project_id_filter=None,
|
||||
)
|
||||
|
||||
def _merge_indexed_and_crawled_results(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import mimetypes
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
@@ -83,6 +84,14 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
def __init__(self, tool_id: int, emitter: Emitter) -> None:
|
||||
super().__init__(emitter=emitter)
|
||||
self._id = tool_id
|
||||
# Cache of (filename, content_hash) -> ci_file_id to avoid re-uploading
|
||||
# the same file on every tool call iteration within the same agent session.
|
||||
# Filename is included in the key so two files with identical bytes but
|
||||
# different names each get their own upload slot.
|
||||
# TTL assumption: code-interpreter file TTLs (typically hours) greatly
|
||||
# exceed the lifetime of a single agent session (at most MAX_LLM_CYCLES
|
||||
# iterations, typically a few minutes), so stale-ID eviction is not needed.
|
||||
self._uploaded_file_cache: dict[tuple[str, str], str] = {}
|
||||
|
||||
@property
|
||||
def id(self) -> int:
|
||||
@@ -182,8 +191,13 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
for ind, chat_file in enumerate(chat_files):
|
||||
file_name = chat_file.filename or f"file_{ind}"
|
||||
try:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
content_hash = hashlib.sha256(chat_file.content).hexdigest()
|
||||
cache_key = (file_name, content_hash)
|
||||
ci_file_id = self._uploaded_file_cache.get(cache_key)
|
||||
if ci_file_id is None:
|
||||
# Upload to Code Interpreter
|
||||
ci_file_id = client.upload_file(chat_file.content, file_name)
|
||||
self._uploaded_file_cache[cache_key] = ci_file_id
|
||||
|
||||
# Stage for execution
|
||||
files_to_stage.append({"path": file_name, "file_id": ci_file_id})
|
||||
@@ -299,14 +313,10 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
f"Failed to delete Code Interpreter generated file {ci_file_id}: {e}"
|
||||
)
|
||||
|
||||
# Cleanup staged input files
|
||||
for file_mapping in files_to_stage:
|
||||
try:
|
||||
client.delete_file(file_mapping["file_id"])
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to delete Code Interpreter staged file {file_mapping['file_id']}: {e}"
|
||||
)
|
||||
# Note: staged input files are intentionally not deleted here because
|
||||
# _uploaded_file_cache reuses their file_ids across iterations. They are
|
||||
# orphaned when the session ends, but the code interpreter cleans up
|
||||
# stale files on its own TTL.
|
||||
|
||||
# Emit file_ids once files are processed
|
||||
if generated_file_ids:
|
||||
|
||||
@@ -244,10 +244,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
user_selected_filters: BaseFilters | None,
|
||||
# If the chat is part of a project
|
||||
project_id: int | None,
|
||||
# If set, search scopes to files attached to this persona
|
||||
persona_id: int | None = None,
|
||||
# Vespa metadata filters for overflowing user files. NOT the raw IDs
|
||||
# of the current project/persona — only set when user files couldn't
|
||||
# fit in the LLM context and need to be searched via vector DB.
|
||||
project_id_filter: int | None,
|
||||
persona_id_filter: int | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Slack context for federated Slack search (tokens fetched internally)
|
||||
slack_context: SlackContext | None = None,
|
||||
@@ -261,8 +262,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
self.project_id = project_id
|
||||
self.persona_id = persona_id
|
||||
self.project_id_filter = project_id_filter
|
||||
self.persona_id_filter = persona_id_filter
|
||||
self.bypass_acl = bypass_acl
|
||||
self.slack_context = slack_context
|
||||
self.enable_slack_search = enable_slack_search
|
||||
@@ -451,13 +452,15 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
# For projects, the search scope is the project and has no other limits
|
||||
user_selected_filters=(
|
||||
self.user_selected_filters if self.project_id is None else None
|
||||
self.user_selected_filters
|
||||
if self.project_id_filter is None
|
||||
else None
|
||||
),
|
||||
bypass_acl=self.bypass_acl,
|
||||
limit=num_hits,
|
||||
),
|
||||
project_id=self.project_id,
|
||||
persona_id=self.persona_id,
|
||||
project_id_filter=self.project_id_filter,
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
@@ -574,7 +577,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
# Federated retrieval functions (non-Slack; Slack is separate)
|
||||
if self.project_id is not None:
|
||||
if self.project_id_filter is not None:
|
||||
# Project mode ignores user filters → no federated sources
|
||||
prefetch_source_types = None
|
||||
else:
|
||||
@@ -587,16 +590,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
user_file_ids = (
|
||||
[uf.id for uf in self.persona.user_files] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=persona_document_sets,
|
||||
user_file_ids=user_file_ids,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
|
||||
|
||||
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
|
||||
"""helper function to return an id given a string input"""
|
||||
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
|
||||
hash_obj = hashlib.md5(hash_input.encode("utf-8"), usedforsecurity=False)
|
||||
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
|
||||
|
||||
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case
|
||||
|
||||
@@ -2,6 +2,7 @@ import contextvars
|
||||
import threading
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
@@ -152,7 +153,7 @@ def mt_cloud_telemetry(
|
||||
tenant_id: str,
|
||||
distinct_id: str,
|
||||
event: MilestoneRecordType,
|
||||
properties: dict | None = None,
|
||||
properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if not MULTI_TENANT:
|
||||
return
|
||||
@@ -173,3 +174,18 @@ def mt_cloud_telemetry(
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(distinct_id, event, all_properties)
|
||||
|
||||
|
||||
def mt_cloud_identify(
|
||||
distinct_id: str,
|
||||
properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Create/update a PostHog person profile (Cloud only)."""
|
||||
if not MULTI_TENANT:
|
||||
return
|
||||
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="identify_user",
|
||||
fallback=noop_fallback,
|
||||
)(distinct_id, properties)
|
||||
|
||||
@@ -140,10 +140,20 @@ def _validate_and_resolve_url(url: str) -> tuple[str, str, int]:
|
||||
return validated_ip, hostname, port
|
||||
|
||||
|
||||
def validate_outbound_http_url(url: str, *, allow_private_network: bool = False) -> str:
|
||||
def validate_outbound_http_url(
|
||||
url: str,
|
||||
*,
|
||||
allow_private_network: bool = False,
|
||||
https_only: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Validate a URL that will be used by backend outbound HTTP calls.
|
||||
|
||||
Args:
|
||||
url: The URL to validate.
|
||||
allow_private_network: If True, skip private/reserved IP checks.
|
||||
https_only: If True, reject http:// URLs (only https:// is allowed).
|
||||
|
||||
Returns:
|
||||
A normalized URL string with surrounding whitespace removed.
|
||||
|
||||
@@ -157,7 +167,12 @@ def validate_outbound_http_url(url: str, *, allow_private_network: bool = False)
|
||||
|
||||
parsed = urlparse(normalized_url)
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
if https_only:
|
||||
if parsed.scheme != "https":
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only https is allowed."
|
||||
)
|
||||
elif parsed.scheme not in ("http", "https"):
|
||||
raise SSRFException(
|
||||
f"Invalid URL scheme '{parsed.scheme}'. Only http and https are allowed."
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ attrs==25.4.0
|
||||
# jsonschema
|
||||
# referencing
|
||||
# zeep
|
||||
authlib==1.6.7
|
||||
authlib==1.6.9
|
||||
# via fastmcp
|
||||
azure-cognitiveservices-speech==1.38.0
|
||||
# via onyx
|
||||
@@ -698,7 +698,7 @@ py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
pyairtable==3.0.1
|
||||
# via onyx
|
||||
pyasn1==0.6.2
|
||||
pyasn1==0.6.3
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
@@ -737,7 +737,7 @@ pygithub==2.5.0
|
||||
# via onyx
|
||||
pygments==2.19.2
|
||||
# via rich
|
||||
pyjwt==2.11.0
|
||||
pyjwt==2.12.0
|
||||
# via
|
||||
# fastapi-users
|
||||
# mcp
|
||||
@@ -752,7 +752,7 @@ pypandoc-binary==1.16.2
|
||||
# via onyx
|
||||
pyparsing==3.2.5
|
||||
# via httplib2
|
||||
pypdf==6.8.0
|
||||
pypdf==6.9.1
|
||||
# via
|
||||
# onyx
|
||||
# unstructured-client
|
||||
|
||||
@@ -263,7 +263,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.7.0
|
||||
onyx-devtools==0.7.1
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -326,7 +326,7 @@ pure-eval==0.2.3
|
||||
# via stack-data
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.2
|
||||
pyasn1==0.6.3
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
@@ -353,7 +353,7 @@ pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
pyjwt==2.11.0
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
|
||||
@@ -195,7 +195,7 @@ propcache==0.4.1
|
||||
# yarl
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.2
|
||||
pyasn1==0.6.3
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
@@ -218,7 +218,7 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pyjwt==2.11.0
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
|
||||
@@ -285,7 +285,7 @@ psutil==7.1.3
|
||||
# via accelerate
|
||||
py==1.11.0
|
||||
# via retry
|
||||
pyasn1==0.6.2
|
||||
pyasn1==0.6.3
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
@@ -308,7 +308,7 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pyjwt==2.11.0
|
||||
pyjwt==2.12.0
|
||||
# via mcp
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
|
||||
471
backend/scripts/run_industryrag_bench_questions.py
Normal file
471
backend/scripts/run_industryrag_bench_questions.py
Normal file
@@ -0,0 +1,471 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import asdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TypedDict
|
||||
from typing import TypeGuard
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DEFAULT_API_BASE = "http://localhost:3000"
|
||||
INTERNAL_SEARCH_TOOL_NAME = "internal_search"
|
||||
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
|
||||
MAX_REQUEST_ATTEMPTS = 5
|
||||
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QuestionRecord:
|
||||
question_id: str
|
||||
question: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnswerRecord:
|
||||
question_id: str
|
||||
answer: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FailedQuestionRecord:
|
||||
question_id: str
|
||||
error: str
|
||||
|
||||
|
||||
class Citation(TypedDict, total=False):
|
||||
citation_number: int
|
||||
document_id: str
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Submit questions to Onyx chat with internal search forced and write "
|
||||
"answers to a JSONL file."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--questions-file",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the input questions JSONL file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to the output answers JSONL file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
type=str,
|
||||
required=True,
|
||||
help="API key used to authenticate against Onyx.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-base",
|
||||
type=str,
|
||||
default=DEFAULT_API_BASE,
|
||||
help=(
|
||||
"Frontend base URL for Onyx. If `/api` is omitted, it will be added "
|
||||
f"automatically. Default: {DEFAULT_API_BASE}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--parallelism",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of questions to process in parallel. Default: 1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-questions",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Optional cap on how many questions to process. Defaults to all.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def normalize_api_base(api_base: str) -> str:
|
||||
normalized = api_base.rstrip("/")
|
||||
if normalized.endswith("/api"):
|
||||
return normalized
|
||||
return f"{normalized}/api"
|
||||
|
||||
|
||||
def load_questions(questions_file: Path) -> list[QuestionRecord]:
|
||||
if not questions_file.exists():
|
||||
raise FileNotFoundError(f"Questions file not found: {questions_file}")
|
||||
|
||||
questions: list[QuestionRecord] = []
|
||||
with questions_file.open("r", encoding="utf-8") as file:
|
||||
for line_number, line in enumerate(file, start=1):
|
||||
stripped_line = line.strip()
|
||||
if not stripped_line:
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = json.loads(stripped_line)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(
|
||||
f"Invalid JSON on line {line_number} of {questions_file}"
|
||||
) from exc
|
||||
|
||||
question_id = payload.get("question_id")
|
||||
question = payload.get("question")
|
||||
|
||||
if not isinstance(question_id, str) or not question_id:
|
||||
raise ValueError(
|
||||
f"Line {line_number} is missing a non-empty `question_id`."
|
||||
)
|
||||
if not isinstance(question, str) or not question:
|
||||
raise ValueError(
|
||||
f"Line {line_number} is missing a non-empty `question`."
|
||||
)
|
||||
|
||||
questions.append(QuestionRecord(question_id=question_id, question=question))
|
||||
|
||||
return questions
|
||||
|
||||
|
||||
async def read_json_response(
|
||||
response: aiohttp.ClientResponse,
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
response_text = await response.text()
|
||||
if response.status >= 400:
|
||||
raise RuntimeError(
|
||||
f"Request to {response.url} failed with {response.status}: {response_text}"
|
||||
)
|
||||
|
||||
try:
|
||||
payload = json.loads(response_text)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(
|
||||
f"Request to {response.url} returned non-JSON content: {response_text}"
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, (dict, list)):
|
||||
raise RuntimeError(
|
||||
f"Unexpected response payload type from {response.url}: {type(payload)}"
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def request_json_with_retries(
|
||||
session: aiohttp.ClientSession,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
json_payload: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any] | list[dict[str, Any]]:
|
||||
backoff_seconds = 1.0
|
||||
|
||||
for attempt in range(1, MAX_REQUEST_ATTEMPTS + 1):
|
||||
try:
|
||||
async with session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
json=json_payload,
|
||||
) as response:
|
||||
if (
|
||||
response.status in RETRIABLE_STATUS_CODES
|
||||
and attempt < MAX_REQUEST_ATTEMPTS
|
||||
):
|
||||
response_text = await response.text()
|
||||
logger.warning(
|
||||
"Retryable response from %s on attempt %s/%s: %s %s",
|
||||
url,
|
||||
attempt,
|
||||
MAX_REQUEST_ATTEMPTS,
|
||||
response.status,
|
||||
response_text,
|
||||
)
|
||||
await asyncio.sleep(backoff_seconds)
|
||||
backoff_seconds *= 2
|
||||
continue
|
||||
|
||||
return await read_json_response(response)
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
if attempt == MAX_REQUEST_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Request to {url} failed after {MAX_REQUEST_ATTEMPTS} attempts."
|
||||
) from exc
|
||||
|
||||
logger.warning(
|
||||
"Request to %s failed on attempt %s/%s: %s",
|
||||
url,
|
||||
attempt,
|
||||
MAX_REQUEST_ATTEMPTS,
|
||||
exc,
|
||||
)
|
||||
await asyncio.sleep(backoff_seconds)
|
||||
backoff_seconds *= 2
|
||||
|
||||
raise RuntimeError(f"Request to {url} failed unexpectedly.")
|
||||
|
||||
|
||||
def extract_document_ids(citation_info: object) -> list[str]:
|
||||
if not isinstance(citation_info, list):
|
||||
return []
|
||||
|
||||
sorted_citations = sorted(
|
||||
(citation for citation in citation_info if _is_valid_citation(citation)),
|
||||
key=_citation_sort_key,
|
||||
)
|
||||
|
||||
document_ids: list[str] = []
|
||||
seen_document_ids: set[str] = set()
|
||||
for citation in sorted_citations:
|
||||
document_id = citation["document_id"]
|
||||
if document_id not in seen_document_ids:
|
||||
seen_document_ids.add(document_id)
|
||||
document_ids.append(document_id)
|
||||
|
||||
return document_ids
|
||||
|
||||
|
||||
def _is_valid_citation(citation: object) -> TypeGuard[Citation]:
|
||||
return (
|
||||
isinstance(citation, dict)
|
||||
and isinstance(citation.get("document_id"), str)
|
||||
and bool(citation["document_id"])
|
||||
)
|
||||
|
||||
|
||||
def _citation_sort_key(citation: Citation) -> int:
|
||||
citation_number = citation.get("citation_number")
|
||||
if isinstance(citation_number, int):
|
||||
return citation_number
|
||||
return sys.maxsize
|
||||
|
||||
|
||||
async def fetch_internal_search_tool_id(
|
||||
session: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
headers: dict[str, str],
|
||||
) -> int:
|
||||
payload = await request_json_with_retries(
|
||||
session=session,
|
||||
method="GET",
|
||||
url=f"{api_base}/tool",
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
if not isinstance(payload, list):
|
||||
raise RuntimeError("Expected `/tool` to return a list.")
|
||||
|
||||
for tool in payload:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
|
||||
if tool.get("in_code_tool_id") == INTERNAL_SEARCH_IN_CODE_TOOL_ID:
|
||||
tool_id = tool.get("id")
|
||||
if isinstance(tool_id, int):
|
||||
return tool_id
|
||||
|
||||
for tool in payload:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
|
||||
if tool.get("name") == INTERNAL_SEARCH_TOOL_NAME:
|
||||
tool_id = tool.get("id")
|
||||
if isinstance(tool_id, int):
|
||||
return tool_id
|
||||
|
||||
raise RuntimeError(
|
||||
"Could not find the internal search tool in `/tool`. "
|
||||
"Make sure SearchTool is available for this environment."
|
||||
)
|
||||
|
||||
|
||||
async def submit_question(
|
||||
session: aiohttp.ClientSession,
|
||||
api_base: str,
|
||||
headers: dict[str, str],
|
||||
internal_search_tool_id: int,
|
||||
question_record: QuestionRecord,
|
||||
) -> AnswerRecord:
|
||||
payload = {
|
||||
"message": question_record.question,
|
||||
"chat_session_info": {"persona_id": 0},
|
||||
"parent_message_id": None,
|
||||
"file_descriptors": [],
|
||||
"allowed_tool_ids": [internal_search_tool_id],
|
||||
"forced_tool_id": internal_search_tool_id,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
response_payload = await request_json_with_retries(
|
||||
session=session,
|
||||
method="POST",
|
||||
url=f"{api_base}/chat/send-chat-message",
|
||||
headers=headers,
|
||||
json_payload=payload,
|
||||
)
|
||||
|
||||
if not isinstance(response_payload, dict):
|
||||
raise RuntimeError(
|
||||
"Expected `/chat/send-chat-message` to return an object when `stream=false`."
|
||||
)
|
||||
|
||||
answer = response_payload.get("answer_citationless")
|
||||
if not isinstance(answer, str):
|
||||
answer = response_payload.get("answer")
|
||||
|
||||
if not isinstance(answer, str):
|
||||
raise RuntimeError(
|
||||
f"Response for question {question_record.question_id} is missing `answer`."
|
||||
)
|
||||
|
||||
return AnswerRecord(
|
||||
question_id=question_record.question_id,
|
||||
answer=answer,
|
||||
document_ids=extract_document_ids(response_payload.get("citation_info")),
|
||||
)
|
||||
|
||||
|
||||
async def generate_answers(
|
||||
questions: list[QuestionRecord],
|
||||
output_file: Path,
|
||||
api_base: str,
|
||||
api_key: str,
|
||||
parallelism: int,
|
||||
) -> None:
|
||||
if parallelism < 1:
|
||||
raise ValueError("`--parallelism` must be at least 1.")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=None,
|
||||
connect=30,
|
||||
sock_connect=30,
|
||||
sock_read=600,
|
||||
)
|
||||
connector = aiohttp.TCPConnector(limit=parallelism)
|
||||
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_file.open("a", encoding="utf-8") as file:
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=timeout, connector=connector
|
||||
) as session:
|
||||
internal_search_tool_id = await fetch_internal_search_tool_id(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
)
|
||||
logger.info("Using internal search tool id %s", internal_search_tool_id)
|
||||
|
||||
semaphore = asyncio.Semaphore(parallelism)
|
||||
progress_lock = asyncio.Lock()
|
||||
write_lock = asyncio.Lock()
|
||||
completed = 0
|
||||
successful = 0
|
||||
failed_questions: list[FailedQuestionRecord] = []
|
||||
total = len(questions)
|
||||
|
||||
async def process_question(question_record: QuestionRecord) -> None:
|
||||
nonlocal completed
|
||||
nonlocal successful
|
||||
|
||||
try:
|
||||
async with semaphore:
|
||||
result = await submit_question(
|
||||
session=session,
|
||||
api_base=api_base,
|
||||
headers=headers,
|
||||
internal_search_tool_id=internal_search_tool_id,
|
||||
question_record=question_record,
|
||||
)
|
||||
except Exception as exc:
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
failed_questions.append(
|
||||
FailedQuestionRecord(
|
||||
question_id=question_record.question_id,
|
||||
error=str(exc),
|
||||
)
|
||||
)
|
||||
logger.exception(
|
||||
"Failed question %s (%s/%s)",
|
||||
question_record.question_id,
|
||||
completed,
|
||||
total,
|
||||
)
|
||||
return
|
||||
|
||||
async with write_lock:
|
||||
file.write(json.dumps(asdict(result), ensure_ascii=False))
|
||||
file.write("\n")
|
||||
file.flush()
|
||||
|
||||
async with progress_lock:
|
||||
completed += 1
|
||||
successful += 1
|
||||
logger.info("Processed %s/%s questions", completed, total)
|
||||
|
||||
await asyncio.gather(
|
||||
*(process_question(question_record) for question_record in questions)
|
||||
)
|
||||
|
||||
if failed_questions:
|
||||
logger.warning(
|
||||
"Completed with %s failed questions and %s successful questions.",
|
||||
len(failed_questions),
|
||||
successful,
|
||||
)
|
||||
for failed_question in failed_questions:
|
||||
logger.warning(
|
||||
"Failed question %s: %s",
|
||||
failed_question.question_id,
|
||||
failed_question.error,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
questions = load_questions(args.questions_file)
|
||||
api_base = normalize_api_base(args.api_base)
|
||||
|
||||
if args.max_questions is not None:
|
||||
if args.max_questions < 1:
|
||||
raise ValueError("`--max-questions` must be at least 1 when provided.")
|
||||
questions = questions[: args.max_questions]
|
||||
|
||||
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
|
||||
logger.info("Writing answers to %s", args.output_file)
|
||||
|
||||
asyncio.run(
|
||||
generate_answers(
|
||||
questions=questions,
|
||||
output_file=args.output_file,
|
||||
api_base=api_base,
|
||||
api_key=args.api_key,
|
||||
parallelism=args.parallelism,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
291
backend/scripts/upload_files_as_connectors.py
Normal file
291
backend/scripts/upload_files_as_connectors.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Script to upload files from a directory as individual file connectors in Onyx.
|
||||
Each file gets its own connector named after the file.
|
||||
|
||||
Usage:
|
||||
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY
|
||||
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY --api-base http://onyxserver:3000
|
||||
python upload_files_as_connectors.py --data-dir /path/to/files --api-key YOUR_KEY --file-glob '*.zip'
|
||||
|
||||
Requires:
|
||||
pip install requests
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import fnmatch
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
REQUEST_TIMEOUT = 900 # 15 minutes
|
||||
|
||||
|
||||
def _elapsed_printer(label: str, stop_event: threading.Event) -> None:
|
||||
"""Print a live elapsed-time counter until stop_event is set."""
|
||||
start = time.monotonic()
|
||||
while not stop_event.wait(timeout=1):
|
||||
elapsed = int(time.monotonic() - start)
|
||||
m, s = divmod(elapsed, 60)
|
||||
print(f"\r {label} ... {m:02d}:{s:02d}", end="", flush=True)
|
||||
elapsed = int(time.monotonic() - start)
|
||||
m, s = divmod(elapsed, 60)
|
||||
print(f"\r {label} ... {m:02d}:{s:02d} done")
|
||||
|
||||
|
||||
def _timed_request(label: str, fn: object) -> requests.Response:
|
||||
"""Run a request function while displaying a live elapsed timer."""
|
||||
stop = threading.Event()
|
||||
t = threading.Thread(target=_elapsed_printer, args=(label, stop), daemon=True)
|
||||
t.start()
|
||||
try:
|
||||
resp = fn() # type: ignore[operator]
|
||||
finally:
|
||||
stop.set()
|
||||
t.join()
|
||||
return resp
|
||||
|
||||
|
||||
def upload_file(
|
||||
session: requests.Session, base_url: str, file_path: str
|
||||
) -> dict | None:
|
||||
"""Upload a single file and return the response with file_paths and file_names."""
|
||||
with open(file_path, "rb") as f:
|
||||
resp = _timed_request(
|
||||
"Uploading",
|
||||
lambda: session.post(
|
||||
f"{base_url}/api/manage/admin/connector/file/upload",
|
||||
files={"files": (os.path.basename(file_path), f)},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
),
|
||||
)
|
||||
if not resp.ok:
|
||||
print(f" ERROR uploading: {resp.text}")
|
||||
return None
|
||||
return resp.json()
|
||||
|
||||
|
||||
def create_connector(
|
||||
session: requests.Session,
|
||||
base_url: str,
|
||||
name: str,
|
||||
file_paths: list[str],
|
||||
file_names: list[str],
|
||||
zip_metadata_file_id: str | None,
|
||||
) -> int | None:
|
||||
"""Create a file connector and return its ID."""
|
||||
resp = _timed_request(
|
||||
"Creating connector",
|
||||
lambda: session.post(
|
||||
f"{base_url}/api/manage/admin/connector",
|
||||
json={
|
||||
"name": name,
|
||||
"source": "file",
|
||||
"input_type": "load_state",
|
||||
"connector_specific_config": {
|
||||
"file_locations": file_paths,
|
||||
"file_names": file_names,
|
||||
"zip_metadata_file_id": zip_metadata_file_id,
|
||||
},
|
||||
"refresh_freq": None,
|
||||
"prune_freq": None,
|
||||
"indexing_start": None,
|
||||
"access_type": "public",
|
||||
"groups": [],
|
||||
},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
),
|
||||
)
|
||||
if not resp.ok:
|
||||
print(f" ERROR creating connector: {resp.text}")
|
||||
return None
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
def create_credential(
|
||||
session: requests.Session, base_url: str, name: str
|
||||
) -> int | None:
|
||||
"""Create a dummy credential for the file connector."""
|
||||
resp = session.post(
|
||||
f"{base_url}/api/manage/credential",
|
||||
json={
|
||||
"credential_json": {},
|
||||
"admin_public": True,
|
||||
"source": "file",
|
||||
"curator_public": True,
|
||||
"groups": [],
|
||||
"name": name,
|
||||
},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if not resp.ok:
|
||||
print(f" ERROR creating credential: {resp.text}")
|
||||
return None
|
||||
return resp.json()["id"]
|
||||
|
||||
|
||||
def link_credential(
|
||||
session: requests.Session,
|
||||
base_url: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
name: str,
|
||||
) -> bool:
|
||||
"""Link the connector to the credential (create CC pair)."""
|
||||
resp = session.put(
|
||||
f"{base_url}/api/manage/connector/{connector_id}/credential/{credential_id}",
|
||||
json={
|
||||
"name": name,
|
||||
"access_type": "public",
|
||||
"groups": [],
|
||||
"auto_sync_options": None,
|
||||
"processing_mode": "REGULAR",
|
||||
},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if not resp.ok:
|
||||
print(f" ERROR linking credential: {resp.text}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_connector(
|
||||
session: requests.Session,
|
||||
base_url: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
"""Trigger the connector to start indexing."""
|
||||
resp = session.post(
|
||||
f"{base_url}/api/manage/admin/connector/run-once",
|
||||
json={
|
||||
"connector_id": connector_id,
|
||||
"credentialIds": [credential_id],
|
||||
"from_beginning": False,
|
||||
},
|
||||
timeout=REQUEST_TIMEOUT,
|
||||
)
|
||||
if not resp.ok:
|
||||
print(f" ERROR running connector: {resp.text}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def process_file(session: requests.Session, base_url: str, file_path: str) -> bool:
|
||||
"""Process a single file through the full connector creation flow."""
|
||||
file_name = os.path.basename(file_path)
|
||||
connector_name = file_name
|
||||
print(f"Processing: {file_name}")
|
||||
|
||||
# Step 1: Upload
|
||||
upload_resp = upload_file(session, base_url, file_path)
|
||||
if not upload_resp:
|
||||
return False
|
||||
|
||||
# Step 2: Create connector
|
||||
connector_id = create_connector(
|
||||
session,
|
||||
base_url,
|
||||
name=f"FileConnector-{connector_name}",
|
||||
file_paths=upload_resp["file_paths"],
|
||||
file_names=upload_resp["file_names"],
|
||||
zip_metadata_file_id=upload_resp.get("zip_metadata_file_id"),
|
||||
)
|
||||
if connector_id is None:
|
||||
return False
|
||||
|
||||
# Step 3: Create credential
|
||||
credential_id = create_credential(session, base_url, name=connector_name)
|
||||
if credential_id is None:
|
||||
return False
|
||||
|
||||
# Step 4: Link connector to credential
|
||||
if not link_credential(
|
||||
session, base_url, connector_id, credential_id, connector_name
|
||||
):
|
||||
return False
|
||||
|
||||
# Step 5: Trigger indexing
|
||||
if not run_connector(session, base_url, connector_id, credential_id):
|
||||
return False
|
||||
|
||||
print(f" OK (connector_id={connector_id})")
|
||||
return True
|
||||
|
||||
|
||||
def get_authenticated_session(api_key: str) -> requests.Session:
|
||||
"""Create a session authenticated with an API key."""
|
||||
session = requests.Session()
|
||||
session.headers.update({"Authorization": f"Bearer {api_key}"})
|
||||
return session
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Upload files as individual Onyx file connectors."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data-dir",
|
||||
required=True,
|
||||
help="Directory containing files to upload.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-base",
|
||||
default="http://localhost:3000",
|
||||
help="Base URL for the Onyx API (default: http://localhost:3000).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api-key",
|
||||
required=True,
|
||||
help="API key for authentication.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--file-glob",
|
||||
default=None,
|
||||
help="Glob pattern to filter files (e.g. '*.json', '*.zip').",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
data_dir = args.data_dir
|
||||
base_url = args.api_base.rstrip("/")
|
||||
api_key = args.api_key
|
||||
file_glob = args.file_glob
|
||||
|
||||
if not os.path.isdir(data_dir):
|
||||
print(f"Error: {data_dir} is not a directory")
|
||||
sys.exit(1)
|
||||
|
||||
script_path = os.path.realpath(__file__)
|
||||
files = sorted(
|
||||
os.path.join(data_dir, f)
|
||||
for f in os.listdir(data_dir)
|
||||
if os.path.isfile(os.path.join(data_dir, f))
|
||||
and os.path.realpath(os.path.join(data_dir, f)) != script_path
|
||||
and (file_glob is None or fnmatch.fnmatch(f, file_glob))
|
||||
)
|
||||
|
||||
if not files:
|
||||
print(f"No files found in {data_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(files)} file(s) in {data_dir}\n")
|
||||
|
||||
session = get_authenticated_session(api_key)
|
||||
|
||||
success = 0
|
||||
failed = 0
|
||||
for file_path in files:
|
||||
if process_file(session, base_url, file_path):
|
||||
success += 1
|
||||
else:
|
||||
failed += 1
|
||||
# Small delay to avoid overwhelming the server
|
||||
time.sleep(0.5)
|
||||
|
||||
print(f"\nDone: {success} succeeded, {failed} failed out of {len(files)} files.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -45,6 +45,21 @@ npx playwright test <TEST_NAME>
|
||||
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
|
||||
their own `conftest.py` for directory-scoped fixtures.
|
||||
|
||||
## Running Tests Repeatedly (`pytest-repeat`)
|
||||
|
||||
Use `pytest-repeat` to catch flaky tests by running them multiple times:
|
||||
|
||||
```bash
|
||||
# Run a specific test 50 times
|
||||
pytest --count=50 backend/tests/unit/path/to/test.py::test_name
|
||||
|
||||
# Stop on first failure with -x
|
||||
pytest --count=50 -x backend/tests/unit/path/to/test.py::test_name
|
||||
|
||||
# Repeat an entire test file
|
||||
pytest --count=10 backend/tests/unit/path/to/test_file.py
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Use `enable_ee` fixture instead of inlining
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user