mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-16 23:35:46 +00:00
Compare commits
281 Commits
v2.8.0-bet
...
hotfix/89d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e19880dfb9 | ||
|
|
e6ef2b5074 | ||
|
|
74132175a8 | ||
|
|
29f707ee2d | ||
|
|
f0eb86fb9f | ||
|
|
b422496a4c | ||
|
|
31d6a45b23 | ||
|
|
36f3ac1ec5 | ||
|
|
74f5b3025a | ||
|
|
c18545d74c | ||
|
|
48171e3700 | ||
|
|
f5a5709876 | ||
|
|
85868b1b83 | ||
|
|
8dc14c23e6 | ||
|
|
23821cc0e8 | ||
|
|
b359e13281 | ||
|
|
717f410a4a | ||
|
|
ada0946a62 | ||
|
|
eb2ac8f5a3 | ||
|
|
fbeb57c592 | ||
|
|
d6da9c9b85 | ||
|
|
5aea2e223e | ||
|
|
1ff91de07e | ||
|
|
b3dbc69faf | ||
|
|
431597b0f9 | ||
|
|
51b4e5f2fb | ||
|
|
9afa04a26b | ||
|
|
70a3a9c0cd | ||
|
|
080165356c | ||
|
|
3ae974bdf6 | ||
|
|
1471658151 | ||
|
|
3e85e9c1a3 | ||
|
|
851033be5f | ||
|
|
91e974a6cc | ||
|
|
38ba4f8a1c | ||
|
|
6f02473064 | ||
|
|
f89432009f | ||
|
|
8ab2bab34e | ||
|
|
59e0d62512 | ||
|
|
a1471b16a5 | ||
|
|
9d3811cb58 | ||
|
|
3cd9505383 | ||
|
|
d11829b393 | ||
|
|
f6e068e914 | ||
|
|
0c84edd980 | ||
|
|
2b274a7683 | ||
|
|
ddd91f2d71 | ||
|
|
a7c7da0dfc | ||
|
|
b00a3e8b5d | ||
|
|
d77d1a48f1 | ||
|
|
7b4fc6729c | ||
|
|
1f113c86ef | ||
|
|
8e38ba3e21 | ||
|
|
bb9708a64f | ||
|
|
8cae97e145 | ||
|
|
7e4abca224 | ||
|
|
233a91ea65 | ||
|
|
b30737b6b2 | ||
|
|
caf8b85ec2 | ||
|
|
1d13580b63 | ||
|
|
00390c53e0 | ||
|
|
66656df9e6 | ||
|
|
51d26d7e4c | ||
|
|
198ac8ccbc | ||
|
|
ee6d33f484 | ||
|
|
7bcb72d055 | ||
|
|
876894e097 | ||
|
|
7215f56b25 | ||
|
|
0fd1c34014 | ||
|
|
9e24b41b7b | ||
|
|
ab3853578b | ||
|
|
7db969d36a | ||
|
|
6cdeb71656 | ||
|
|
2c4b2c68b4 | ||
|
|
5301ee7cef | ||
|
|
f8e6716875 | ||
|
|
755c65fd8a | ||
|
|
90cf5f49e3 | ||
|
|
d4068c2b07 | ||
|
|
fd6fa43fe1 | ||
|
|
8d5013bf01 | ||
|
|
dabd7c6263 | ||
|
|
c8c0389675 | ||
|
|
9cfcfb12e1 | ||
|
|
786a0c2bd0 | ||
|
|
0cd8d3402b | ||
|
|
3fa397b24d | ||
|
|
e0a97230b8 | ||
|
|
7f1272117a | ||
|
|
79302f19be | ||
|
|
4a91e644d4 | ||
|
|
ca0318f16e | ||
|
|
be8e0b3a98 | ||
|
|
49c4814c70 | ||
|
|
2f945613a2 | ||
|
|
e9242ca3a8 | ||
|
|
a150de761a | ||
|
|
0e792ca6c9 | ||
|
|
6be467a4ac | ||
|
|
dd91bfcfe6 | ||
|
|
8a72291781 | ||
|
|
b2d71da4eb | ||
|
|
6e2f851c62 | ||
|
|
be078edcb4 | ||
|
|
194c54aca3 | ||
|
|
9fa7221e24 | ||
|
|
3a5c7ef8ee | ||
|
|
84458aa0bf | ||
|
|
de57bfa35f | ||
|
|
386f8f31ed | ||
|
|
376f04caea | ||
|
|
4b0a3c2b04 | ||
|
|
1bd9f9d9a6 | ||
|
|
4ac10abaea | ||
|
|
a66a283af4 | ||
|
|
bf5da04166 | ||
|
|
693487f855 | ||
|
|
d02a76d7d1 | ||
|
|
28e05c6e90 | ||
|
|
a18f546921 | ||
|
|
e98dea149e | ||
|
|
027c165794 | ||
|
|
14ebe912c8 | ||
|
|
a63b906789 | ||
|
|
92a68a3c22 | ||
|
|
95db4ed9c7 | ||
|
|
5134d60d48 | ||
|
|
651a54470d | ||
|
|
269d243b67 | ||
|
|
0286dd7da9 | ||
|
|
f3a0710d69 | ||
|
|
037c2aee3a | ||
|
|
9b2f3d234d | ||
|
|
7646399cd4 | ||
|
|
d913b93d10 | ||
|
|
8a0ce4c294 | ||
|
|
862c140763 | ||
|
|
47487f1940 | ||
|
|
e3471df940 | ||
|
|
fb33c815b3 | ||
|
|
5c6594be73 | ||
|
|
8d30a03d7f | ||
|
|
277428f579 | ||
|
|
9f8c0d4237 | ||
|
|
9ccbb6a04b | ||
|
|
58a943f782 | ||
|
|
9021c607f2 | ||
|
|
c03b0d80fd | ||
|
|
fcf0b316a4 | ||
|
|
157f672b4b | ||
|
|
51b9484b96 | ||
|
|
0c8f55c049 | ||
|
|
c7be2571d1 | ||
|
|
4948b6cca9 | ||
|
|
638ea5f316 | ||
|
|
6e3268ca75 | ||
|
|
d8921df60c | ||
|
|
693d9f5f69 | ||
|
|
02e17871cc | ||
|
|
209cfd00b0 | ||
|
|
cd36baa484 | ||
|
|
c78fe275af | ||
|
|
c935c4808f | ||
|
|
4ebcfef541 | ||
|
|
e320ef9d9c | ||
|
|
9e02438af5 | ||
|
|
177e097ddb | ||
|
|
9ecd47ec31 | ||
|
|
83f3d29b10 | ||
|
|
12e668cc0f | ||
|
|
afe8376d5e | ||
|
|
082ef3e096 | ||
|
|
cb2951a1c0 | ||
|
|
eda5598af5 | ||
|
|
0bbb4b6988 | ||
|
|
4768aadb20 | ||
|
|
e05e85e782 | ||
|
|
6408f61307 | ||
|
|
5a5cd51e4f | ||
|
|
7c047c47a0 | ||
|
|
22138bbb33 | ||
|
|
7cff1064a8 | ||
|
|
deeb6fdcd2 | ||
|
|
3e7f4e0aa5 | ||
|
|
ac73671e35 | ||
|
|
3c20d132e0 | ||
|
|
0e3e7eb4a2 | ||
|
|
c85aebe8ab | ||
|
|
a47e6a3146 | ||
|
|
1e61737e03 | ||
|
|
c7fc1cd5ae | ||
|
|
e2b60bf67c | ||
|
|
f4d4d14286 | ||
|
|
1c24bc6ea2 | ||
|
|
cacbd18dcd | ||
|
|
8527b83b15 | ||
|
|
33e37a1846 | ||
|
|
d454d8a878 | ||
|
|
00ad65a6a8 | ||
|
|
dac60d403c | ||
|
|
6256b2854d | ||
|
|
8acb8e191d | ||
|
|
8c4cbddc43 | ||
|
|
f6cd006bd6 | ||
|
|
0033934319 | ||
|
|
ff87b79d14 | ||
|
|
ebf18af7c9 | ||
|
|
cf67ae962c | ||
|
|
7a9a132739 | ||
|
|
33bad8c37b | ||
|
|
9241ff7a75 | ||
|
|
0a25bc30ec | ||
|
|
e359732f4c | ||
|
|
be47866a4d | ||
|
|
8a20540559 | ||
|
|
e6e1f2860a | ||
|
|
fc3f433df7 | ||
|
|
016caf453b | ||
|
|
a9de25053f | ||
|
|
8ef8dfdeb7 | ||
|
|
0643b626d9 | ||
|
|
64a0eb52e0 | ||
|
|
b82ffc82cf | ||
|
|
b3014b9911 | ||
|
|
439707c395 | ||
|
|
65351aa8bd | ||
|
|
b44ee07eaf | ||
|
|
065d391c08 | ||
|
|
14fe3b375f | ||
|
|
bb1b96dded | ||
|
|
9f949ae2d9 | ||
|
|
975c0e8009 | ||
|
|
3dfb38c460 | ||
|
|
a1512a0485 | ||
|
|
8ea3bacd38 | ||
|
|
6b560b8162 | ||
|
|
3b750939ed | ||
|
|
bd4cb17a48 | ||
|
|
485cd9a311 | ||
|
|
2108c72353 | ||
|
|
98f43fb6ab | ||
|
|
e112ebb371 | ||
|
|
f88cbcfe27 | ||
|
|
0df0b10d3a | ||
|
|
ed0d12452a | ||
|
|
dc7cb80594 | ||
|
|
4312b24945 | ||
|
|
afd920bb33 | ||
|
|
d009b12aa7 | ||
|
|
596b3d9f3e | ||
|
|
1981c912b7 | ||
|
|
68b1bb8448 | ||
|
|
4676b5017f | ||
|
|
eb7b6a5ce1 | ||
|
|
87d6df2621 | ||
|
|
13b4108b53 | ||
|
|
13e806b625 | ||
|
|
f4f7839d84 | ||
|
|
2dbf1c3b1f | ||
|
|
288d4147c3 | ||
|
|
fee27b2274 | ||
|
|
340e938627 | ||
|
|
6faa47e0f7 | ||
|
|
ba6801f5af | ||
|
|
d7447eb8af | ||
|
|
196f890a68 | ||
|
|
3ac96572c3 | ||
|
|
3d8ae22b3a | ||
|
|
233d06ec0e | ||
|
|
9ff82ac740 | ||
|
|
b15f01fd78 | ||
|
|
6480cf6738 | ||
|
|
c521a4397a | ||
|
|
41a8d86df3 | ||
|
|
735cf926e4 | ||
|
|
035e73655f | ||
|
|
f317420f58 | ||
|
|
d50a84f2e4 | ||
|
|
9b441e3686 | ||
|
|
c4c1e16f19 | ||
|
|
9044e0f5fa |
559
.github/workflows/deployment.yml
vendored
559
.github/workflows/deployment.yml
vendored
File diff suppressed because it is too large
Load Diff
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- uses: actions/stale@5f858e3efba33a5ca4407a664cc011ad407f2008 # ratchet:actions/stale@v10
|
||||
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
|
||||
|
||||
@@ -38,6 +38,8 @@ env:
|
||||
# LLMs
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
|
||||
VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
|
||||
VERTEX_LOCATION: ${{ vars.VERTEX_LOCATION }}
|
||||
|
||||
# Code Interpreter
|
||||
# TODO: debug why this is failing and enable
|
||||
@@ -170,7 +172,7 @@ jobs:
|
||||
|
||||
- name: Upload Docker logs
|
||||
if: failure()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.test-dir }}
|
||||
path: docker-logs/
|
||||
|
||||
412
.github/workflows/pr-helm-chart-testing.yml
vendored
412
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -6,11 +6,11 @@ concurrency:
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
push:
|
||||
tags:
|
||||
- "v*.*.*"
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -18,225 +18,233 @@ permissions:
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}-helm-chart-check"]
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=8cpu-linux-x64,
|
||||
hdd=256,
|
||||
"run-id=${{ github.run_id }}-helm-chart-check",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1
|
||||
with:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@6ec842c01de15ebb84c8627d2744a0c2f2755c9f # ratchet:helm/chart-testing-action@v2.8.0
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "Failed to pull $image"
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
env:
|
||||
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
|
||||
run: |
|
||||
echo "default_branch: ${DEFAULT_BRANCH}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
# uncomment to force run chart-testing
|
||||
# - name: Force run chart-testing (list-changed)
|
||||
# id: list-changed
|
||||
# run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
- name: Pre-install cluster status check
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-install Cluster Status ==="
|
||||
kubectl get nodes -o wide
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get storageclass
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
- name: Add Helm repositories and update
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Adding Helm repositories ==="
|
||||
helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
|
||||
helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
|
||||
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
|
||||
helm repo add minio https://charts.min.io/
|
||||
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
|
||||
helm repo update
|
||||
|
||||
- name: Install Redis operator
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
shell: bash
|
||||
run: |
|
||||
echo "=== Installing redis-operator CRDs ==="
|
||||
helm upgrade --install redis-operator ot-container-kit/redis-operator \
|
||||
--namespace redis-operator --create-namespace --wait --timeout 300s
|
||||
|
||||
- name: Pre-pull required images
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Pre-pulling required images to avoid timeout ==="
|
||||
KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//')
|
||||
echo "Kind cluster: $KIND_CLUSTER"
|
||||
|
||||
IMAGES=(
|
||||
"ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0"
|
||||
"quay.io/opstree/redis:v7.0.15"
|
||||
"docker.io/onyxdotapp/onyx-web-server:latest"
|
||||
)
|
||||
|
||||
for image in "${IMAGES[@]}"; do
|
||||
echo "Pre-pulling $image"
|
||||
if docker pull "$image"; then
|
||||
kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind"
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
echo "Failed to pull $image"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
echo "=== Images loaded into Kind cluster ==="
|
||||
docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..."
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
- name: Validate chart dependencies
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Validating chart dependencies ==="
|
||||
cd deployment/helm/charts/onyx
|
||||
helm dependency update
|
||||
helm lint .
|
||||
|
||||
- name: Run chart-testing (install) with enhanced monitoring
|
||||
timeout-minutes: 25
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Starting chart installation with monitoring ==="
|
||||
|
||||
# Function to monitor cluster state
|
||||
monitor_cluster() {
|
||||
while true; do
|
||||
echo "=== Cluster Status Check at $(date) ==="
|
||||
# Only show non-running pods to reduce noise
|
||||
NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l)
|
||||
if [ "$NON_RUNNING_PODS" -gt 0 ]; then
|
||||
echo "Non-running pods:"
|
||||
kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded
|
||||
else
|
||||
echo "All pods running successfully"
|
||||
fi
|
||||
# Only show recent events if there are issues
|
||||
RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5)
|
||||
if [ -n "$RECENT_EVENTS" ]; then
|
||||
echo "Recent warnings/errors:"
|
||||
echo "$RECENT_EVENTS"
|
||||
fi
|
||||
sleep 60
|
||||
done
|
||||
}
|
||||
|
||||
# Start monitoring in background
|
||||
monitor_cluster &
|
||||
MONITOR_PID=$!
|
||||
|
||||
# Set up cleanup
|
||||
cleanup() {
|
||||
echo "=== Cleaning up monitoring process ==="
|
||||
kill $MONITOR_PID 2>/dev/null || true
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20
|
||||
}
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
# Trap cleanup on exit
|
||||
trap cleanup EXIT
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
# Run the actual installation with detailed logging
|
||||
echo "=== Starting ct install ==="
|
||||
set +e
|
||||
ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=postgresql.enabled=true \
|
||||
--set=postgresql.nameOverride=cloudnative-pg \
|
||||
--set=postgresql.cluster.storage.storageClass=standard \
|
||||
--set=redis.enabled=true \
|
||||
--set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \
|
||||
--set=webserver.replicaCount=1 \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docfetching.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_file_processing.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--helm-extra-args="--timeout 900s --debug" \
|
||||
--debug --config ct.yaml
|
||||
CT_EXIT=$?
|
||||
set -e
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
if [[ $CT_EXIT -ne 0 ]]; then
|
||||
echo "ct install failed with exit code $CT_EXIT"
|
||||
exit $CT_EXIT
|
||||
else
|
||||
echo "=== Installation completed successfully ==="
|
||||
fi
|
||||
|
||||
kubectl get pods --all-namespaces
|
||||
|
||||
- name: Post-install verification
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Post-install verification ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get services --all-namespaces
|
||||
# Only show issues if they exist
|
||||
kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found"
|
||||
|
||||
- name: Cleanup on failure
|
||||
if: failure() && steps.list-changed.outputs.changed == 'true'
|
||||
run: |
|
||||
echo "=== Cleanup on failure ==="
|
||||
echo "=== Final cluster state ==="
|
||||
kubectl get pods --all-namespaces
|
||||
kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10
|
||||
|
||||
echo "=== Pod descriptions for debugging ==="
|
||||
kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found"
|
||||
|
||||
echo "=== Recent logs for debugging ==="
|
||||
kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found"
|
||||
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
echo "=== Helm releases ==="
|
||||
helm list --all-namespaces
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
8
.github/workflows/pr-integration-tests.yml
vendored
8
.github/workflows/pr-integration-tests.yml
vendored
@@ -310,8 +310,9 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
MCP_SERVER_ENABLED=true
|
||||
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -438,7 +439,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
@@ -480,6 +481,7 @@ jobs:
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
OPENAI_DEFAULT_API_KEY=${OPENAI_API_KEY} \
|
||||
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
|
||||
DEV_MODE=true \
|
||||
@@ -566,7 +568,7 @@ jobs:
|
||||
|
||||
- name: Upload logs (multi-tenant)
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-multitenant
|
||||
path: ${{ github.workspace }}/docker-compose-multitenant.log
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
|
||||
@@ -301,7 +301,7 @@ jobs:
|
||||
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
|
||||
INTEGRATION_TESTS_MODE=true
|
||||
MCP_SERVER_ENABLED=true
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=1
|
||||
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
|
||||
EOF
|
||||
|
||||
- name: Start Docker containers
|
||||
@@ -424,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs-${{ matrix.test-dir.name }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
4
.github/workflows/pr-playwright-tests.yml
vendored
4
.github/workflows/pr-playwright-tests.yml
vendored
@@ -435,7 +435,7 @@ jobs:
|
||||
fi
|
||||
npx playwright test --project ${PROJECT}
|
||||
|
||||
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
if: always()
|
||||
with:
|
||||
# Includes test results and trace.zip files
|
||||
@@ -455,7 +455,7 @@ jobs:
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -50,8 +50,9 @@ jobs:
|
||||
uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: backend/.mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
|
||||
140
.github/workflows/pr-python-model-tests.yml
vendored
140
.github/workflows/pr-python-model-tests.yml
vendored
@@ -5,11 +5,6 @@ on:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to run the workflow on'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -31,7 +26,11 @@ env:
|
||||
jobs:
|
||||
model-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}-model-check"]
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=4cpu-linux-arm64
|
||||
- "run-id=${{ github.run_id }}-model-check"
|
||||
- "extras=ecr-cache"
|
||||
timeout-minutes: 45
|
||||
|
||||
env:
|
||||
@@ -43,108 +42,87 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup Python and Install Dependencies
|
||||
uses: ./.github/actions/setup-python-and-install-dependencies
|
||||
with:
|
||||
requirements: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Format branch name for cache
|
||||
id: format-branch
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REF_NAME: ${{ github.ref_name }}
|
||||
run: |
|
||||
if [ -n "${PR_NUMBER}" ]; then
|
||||
CACHE_SUFFIX="${PR_NUMBER}"
|
||||
else
|
||||
# shellcheck disable=SC2001
|
||||
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||
fi
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Model Server Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-model-server:latest
|
||||
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
- name: Build and load
|
||||
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
|
||||
env:
|
||||
TAG: model-server-${{ github.run_id }}
|
||||
with:
|
||||
python-version: "3.11"
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
load: true
|
||||
targets: model-server
|
||||
set: |
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
|
||||
model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
|
||||
model-server.cache-from=type=registry,ref=onyxdotapp/onyx-model-server:latest
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
|
||||
model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
|
||||
|
||||
- name: Start Docker containers
|
||||
id: start_docker
|
||||
env:
|
||||
IMAGE_TAG: model-server-${{ github.run_id }}
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.model-server-test.yml up -d indexing_model_server
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
docker compose \
|
||||
-f docker-compose.yml \
|
||||
-f docker-compose.dev.yml \
|
||||
up -d --wait \
|
||||
inference_model_server
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
|
||||
py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding
|
||||
|
||||
- name: Alert on Failure
|
||||
if: failure() && github.event_name == 'schedule'
|
||||
env:
|
||||
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
|
||||
REPO: ${{ github.repository }}
|
||||
RUN_ID: ${{ github.run_id }}
|
||||
run: |
|
||||
curl -X POST \
|
||||
-H 'Content-type: application/json' \
|
||||
--data "{\"text\":\"Scheduled Model Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
|
||||
$SLACK_WEBHOOK
|
||||
uses: ./.github/actions/slack-notify
|
||||
with:
|
||||
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
|
||||
failed-jobs: model-check
|
||||
title: "🚨 Scheduled Model Tests failed!"
|
||||
ref-name: ${{ github.ref_name }}
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # ratchet:actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
11
.github/workflows/zizmor.yml
vendored
11
.github/workflows/zizmor.yml
vendored
@@ -21,18 +21,29 @@ jobs:
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Detect changes
|
||||
id: filter
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
|
||||
with:
|
||||
filters: |
|
||||
zizmor:
|
||||
- '.github/**'
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Run zizmor
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Upload SARIF file
|
||||
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
|
||||
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
|
||||
with:
|
||||
sarif_file: results.sarif
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,6 +21,7 @@ backend/tests/regression/search_quality/*.json
|
||||
backend/onyx/evals/data/
|
||||
backend/onyx/evals/one_off/*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# secret files
|
||||
.env
|
||||
|
||||
@@ -11,7 +11,6 @@ repos:
|
||||
- id: uv-sync
|
||||
args: ["--locked", "--all-extras"]
|
||||
- id: uv-lock
|
||||
files: ^pyproject\.toml$
|
||||
- id: uv-export
|
||||
name: uv-export default.txt
|
||||
args:
|
||||
@@ -147,6 +146,22 @@ repos:
|
||||
pass_filenames: false
|
||||
files: \.tf$
|
||||
|
||||
- id: npm-install
|
||||
name: npm install
|
||||
description: "Automatically run 'npm install' after a checkout, pull or rebase"
|
||||
language: system
|
||||
entry: bash -c 'cd web && npm install --no-save'
|
||||
pass_filenames: false
|
||||
files: ^web/package(-lock)?\.json$
|
||||
stages: [post-checkout, post-merge, post-rewrite]
|
||||
- id: npm-install-check
|
||||
name: npm install --package-lock-only
|
||||
description: "Check the 'web/package-lock.json' is updated"
|
||||
language: system
|
||||
entry: bash -c 'cd web && npm install --package-lock-only'
|
||||
pass_filenames: false
|
||||
files: ^web/package(-lock)?\.json$
|
||||
|
||||
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
|
||||
# This is a preview package - if it breaks:
|
||||
# 1. Try updating: cd web && npm update @typescript/native-preview
|
||||
|
||||
6
.vscode/env_template.txt
vendored
6
.vscode/env_template.txt
vendored
@@ -17,12 +17,6 @@ LOG_ONYX_MODEL_INTERACTIONS=True
|
||||
LOG_LEVEL=debug
|
||||
|
||||
|
||||
# This passes top N results to LLM an additional time for reranking prior to
|
||||
# answer generation.
|
||||
# This step is quite heavy on token usage so we disable it for dev generally.
|
||||
DISABLE_LLM_DOC_RELEVANCE=False
|
||||
|
||||
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically).
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
|
||||
18
.vscode/launch.template.jsonc
vendored
18
.vscode/launch.template.jsonc
vendored
@@ -151,6 +151,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
|
||||
@@ -7,7 +7,7 @@ This file provides guidance to AI agents when working with code in this reposito
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
|
||||
@@ -7,7 +7,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
|
||||
- If you run into any missing python dependency errors, try running your command with `source .venv/bin/activate` \
|
||||
to assume the python venv.
|
||||
- To make tests work, check the `.env` file at the root of the project to find an OpenAI key.
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@test.com` and password
|
||||
- If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password
|
||||
`a`. The app can be accessed at `http://localhost:3000`.
|
||||
- You should assume that all Onyx services are running. To verify, you can check the `backend/log` directory to
|
||||
make sure we see logs coming out from the relevant service.
|
||||
|
||||
259
CONTRIBUTING.md
259
CONTRIBUTING.md
@@ -1,262 +1,31 @@
|
||||
<!-- ONYX_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md"} -->
|
||||
|
||||
# Contributing to Onyx
|
||||
|
||||
Hey there! We are so excited that you're interested in Onyx.
|
||||
|
||||
As an open source project in a rapidly changing space, we welcome all contributions.
|
||||
|
||||
## 💃 Guidelines
|
||||
## Contribution Opportunities
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas.
|
||||
|
||||
### Contribution Opportunities
|
||||
If you have your own feature that you would like to build please create an issue and community members can provide feedback and
|
||||
thumb it up if they feel a common need.
|
||||
|
||||
The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to start for contribution ideas.
|
||||
|
||||
To ensure that your contribution is aligned with the project's direction, please reach out to any maintainer on the Onyx team
|
||||
via [Discord](https://discord.gg/4NA5SbzrWb) or [email](mailto:hello@onyx.app).
|
||||
## Contributing Code
|
||||
Please reference the documents in contributing_guides folder to ensure that the code base is kept to a high standard.
|
||||
1. dev_setup.md (start here): gives you a guide to setting up a local development environment.
|
||||
2. contribution_process.md: how to ensure you are building valuable features that will get reviewed and merged.
|
||||
3. best_practices.md: before asking for reviews, ensure your changes meet the repo code quality standards.
|
||||
|
||||
Issues that have been explicitly approved by the maintainers (aligned with the direction of the project)
|
||||
will be marked with the `approved by maintainers` label.
|
||||
Issues marked `good first issue` are an especially great place to start.
|
||||
|
||||
**Connectors** to other tools are another great place to contribute. For details on how, refer to this
|
||||
[README.md](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/README.md).
|
||||
|
||||
If you have a new/different contribution in mind, we'd love to hear about it!
|
||||
Your input is vital to making sure that Onyx moves in the right direction.
|
||||
Before starting on implementation, please raise a GitHub issue.
|
||||
|
||||
Also, always feel free to message the founders (Chris Weaver / Yuhong Sun) on
|
||||
[Discord](https://discord.gg/4NA5SbzrWb) directly about anything at all.
|
||||
|
||||
### Contributing Code
|
||||
|
||||
To contribute to this project, please follow the
|
||||
To contribute, please follow the
|
||||
["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow.
|
||||
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
|
||||
|
||||
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
|
||||
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
|
||||
|
||||
### Getting Help 🙋
|
||||
## Getting Help 🙋
|
||||
We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
Our goal is to make contributing as easy as possible. If you run into any issues please don't hesitate to reach out.
|
||||
That way we can help future contributors and users can avoid the same issue.
|
||||
See you there!
|
||||
|
||||
We also have support channels and generally interesting discussions on our
|
||||
[Discord](https://discord.gg/4NA5SbzrWb).
|
||||
|
||||
We would love to see you there!
|
||||
|
||||
## Get Started 🚀
|
||||
|
||||
Onyx being a fully functional app, relies on some external software, specifically:
|
||||
|
||||
- [Postgres](https://www.postgresql.org/) (Relational DB)
|
||||
- [Vespa](https://vespa.ai/) (Vector DB/Search Engine)
|
||||
- [Redis](https://redis.io/) (Cache)
|
||||
- [MinIO](https://min.io/) (File Store)
|
||||
- [Nginx](https://nginx.org/) (Not needed for development flows generally)
|
||||
|
||||
> **Note:**
|
||||
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software. We believe this combination is easier for
|
||||
> development purposes. If you prefer to use pre-built container images, we provide instructions on running the full Onyx stack within Docker below.
|
||||
|
||||
### Local Set Up
|
||||
|
||||
Be sure to use Python version 3.11. For instructions on installing Python 3.11 on macOS, refer to the [CONTRIBUTING_MACOS.md](./CONTRIBUTING_MACOS.md) readme.
|
||||
|
||||
If using a lower version, modifications will have to be made to the code.
|
||||
If using a higher version, sometimes some libraries will not be available (i.e. we had problems with Tensorflow in the past with higher versions of python).
|
||||
|
||||
#### Backend: Python requirements
|
||||
|
||||
Currently, we use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
|
||||
|
||||
For convenience here's a command for it:
|
||||
|
||||
```bash
|
||||
uv venv .venv --python 3.11
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
_For Windows, activate the virtual environment using Command Prompt:_
|
||||
|
||||
```bash
|
||||
.venv\Scripts\activate
|
||||
```
|
||||
|
||||
If using PowerShell, the command slightly differs:
|
||||
|
||||
```powershell
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector):
|
||||
|
||||
```bash
|
||||
uv run playwright install
|
||||
```
|
||||
|
||||
#### Frontend: Node dependencies
|
||||
|
||||
Onyx uses Node v22.20.0. We highly recommend you use [Node Version Manager (nvm)](https://github.com/nvm-sh/nvm)
|
||||
to manage your Node installations. Once installed, you can run
|
||||
|
||||
```bash
|
||||
nvm install 22 && nvm use 22
|
||||
node -v # verify your active version
|
||||
```
|
||||
|
||||
Navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm i
|
||||
```
|
||||
|
||||
## Formatting and Linting
|
||||
|
||||
### Backend
|
||||
|
||||
For the backend, you'll need to setup pre-commit hooks (black / reorder-python-imports).
|
||||
|
||||
Then run:
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
Additionally, we use `mypy` for static type checking.
|
||||
Onyx is fully type-annotated, and we want to keep it that way!
|
||||
To run the mypy checks manually, run `uv run mypy .` from the `onyx/backend` directory.
|
||||
|
||||
### Web
|
||||
|
||||
We use `prettier` for formatting. The desired version will be installed via a `npm i` from the `onyx/web` directory.
|
||||
To run the formatter, use `npx prettier --write .` from the `onyx/web` directory.
|
||||
|
||||
Pre-commit will also run prettier automatically on files you've recently touched. If re-formatted, your commit will fail.
|
||||
Re-stage your changes and commit again.
|
||||
|
||||
# Running the application for development
|
||||
|
||||
## Developing using VSCode Debugger (recommended)
|
||||
|
||||
**We highly recommend using VSCode debugger for development.**
|
||||
See [CONTRIBUTING_VSCODE.md](./CONTRIBUTING_VSCODE.md) for more details.
|
||||
|
||||
Otherwise, you can follow the instructions below to run the application for development.
|
||||
|
||||
## Manually running the application for development
|
||||
### Docker containers for external software
|
||||
|
||||
You will need Docker installed to run these containers.
|
||||
|
||||
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
|
||||
|
||||
```bash
|
||||
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
|
||||
```
|
||||
|
||||
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
|
||||
|
||||
### Running Onyx locally
|
||||
|
||||
To start the frontend, navigate to `onyx/web` and run:
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
Next, start the model server which runs the local NLP models.
|
||||
Navigate to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
uvicorn model_server.main:app --reload --port 9000
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
|
||||
```
|
||||
|
||||
The first time running Onyx, you will need to run the DB migrations for Postgres.
|
||||
After the first time, this is no longer required unless the DB models change.
|
||||
|
||||
Navigate to `onyx/backend` and with the venv active, run:
|
||||
|
||||
```bash
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
Next, start the task queue which orchestrates the background jobs.
|
||||
Jobs that take more time are run async from the API server.
|
||||
|
||||
Still in `onyx/backend`, run:
|
||||
|
||||
```bash
|
||||
python ./scripts/dev_run_background_jobs.py
|
||||
```
|
||||
|
||||
To run the backend API server, navigate back to `onyx/backend` and run:
|
||||
|
||||
```bash
|
||||
AUTH_TYPE=disabled uvicorn onyx.main:app --reload --port 8080
|
||||
```
|
||||
|
||||
_For Windows (for compatibility with both PowerShell and Command Prompt):_
|
||||
|
||||
```bash
|
||||
powershell -Command "
|
||||
$env:AUTH_TYPE='disabled'
|
||||
uvicorn onyx.main:app --reload --port 8080
|
||||
"
|
||||
```
|
||||
|
||||
> **Note:**
|
||||
> If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.
|
||||
|
||||
#### Wrapping up
|
||||
|
||||
You should now have 4 servers running:
|
||||
|
||||
- Web server
|
||||
- Backend API
|
||||
- Model server
|
||||
- Background jobs
|
||||
|
||||
Now, visit `http://localhost:3000` in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.
|
||||
|
||||
You've successfully set up a local Onyx instance! 🏁
|
||||
|
||||
#### Running the Onyx application in a container
|
||||
|
||||
You can run the full Onyx application stack from pre-built images including all external software dependencies.
|
||||
|
||||
Navigate to `onyx/deployment/docker_compose` and run:
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
After Docker pulls and starts these containers, navigate to `http://localhost:3000` to use Onyx.
|
||||
|
||||
If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes like so:
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
```
|
||||
|
||||
|
||||
### Release Process
|
||||
|
||||
## Release Process
|
||||
Onyx loosely follows the SemVer versioning standard.
|
||||
Major changes are released with a "minor" version bump. Currently we use patch release versions to indicate small feature changes.
|
||||
A set of Docker containers will be pushed automatically to DockerHub with every tag.
|
||||
|
||||
@@ -225,7 +225,6 @@ def do_run_migrations(
|
||||
) -> None:
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
@@ -309,6 +308,7 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
@@ -346,6 +346,7 @@ async def run_async_migrations() -> None:
|
||||
schema_name=schema,
|
||||
create_schema=create_schema,
|
||||
)
|
||||
await connection.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Error migrating schema {schema}: {e}")
|
||||
if not continue_on_error:
|
||||
|
||||
@@ -85,103 +85,122 @@ class UserRow(NamedTuple):
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
try:
|
||||
# Step 1: Create or update the unified assistant (ID 0)
|
||||
search_assistant = conn.execute(
|
||||
sa.text("SELECT * FROM persona WHERE id = 0")
|
||||
).fetchone()
|
||||
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
if search_assistant:
|
||||
# Update existing Search assistant to be the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
SET name = :name,
|
||||
description = :description,
|
||||
system_prompt = :system_prompt,
|
||||
num_chunks = :num_chunks,
|
||||
is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false,
|
||||
display_priority = :display_priority,
|
||||
llm_filter_extraction = :llm_filter_extraction,
|
||||
llm_relevance_filter = :llm_relevance_filter,
|
||||
recency_bias = :recency_bias,
|
||||
chunks_above = :chunks_above,
|
||||
chunks_below = :chunks_below,
|
||||
datetime_aware = :datetime_aware,
|
||||
starter_messages = null
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
else:
|
||||
# Create new unified assistant with ID 0
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona (
|
||||
id, name, description, system_prompt, num_chunks,
|
||||
is_default_persona, is_visible, deleted, display_priority,
|
||||
llm_filter_extraction, llm_relevance_filter, recency_bias,
|
||||
chunks_above, chunks_below, datetime_aware, starter_messages,
|
||||
builtin_persona
|
||||
) VALUES (
|
||||
0, :name, :description, :system_prompt, :num_chunks,
|
||||
true, true, false, :display_priority, :llm_filter_extraction,
|
||||
:llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below,
|
||||
:datetime_aware, null, true
|
||||
)
|
||||
"""
|
||||
),
|
||||
INSERT_DICT,
|
||||
)
|
||||
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
# Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = true, is_visible = false, is_default_persona = false
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
# Step 3: Add all built-in tools to the unified assistant
|
||||
# First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool
|
||||
search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'")
|
||||
).fetchone()
|
||||
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
if not search_tool:
|
||||
raise ValueError(
|
||||
"SearchTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
image_gen_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'")
|
||||
).fetchone()
|
||||
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
if not image_gen_tool:
|
||||
raise ValueError(
|
||||
"ImageGenerationTool not found in database. Ensure tools migration has run first."
|
||||
)
|
||||
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
# WebSearchTool is optional - may not be configured
|
||||
web_search_tool = conn.execute(
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
|
||||
).fetchone()
|
||||
|
||||
# Add tools to the unified assistant
|
||||
# Clear existing tool associations for persona 0
|
||||
conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))
|
||||
|
||||
# Add tools to the unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
@@ -190,191 +209,148 @@ def upgrade() -> None:
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": search_tool[0]},
|
||||
{"tool_id": web_search_tool[0]},
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"tool_id": image_gen_tool[0]},
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
if web_search_tool:
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO persona__tool (persona_id, tool_id)
|
||||
VALUES (0, :tool_id)
|
||||
ON CONFLICT DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"tool_id": web_search_tool[0]},
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE chat_session
|
||||
SET persona_id = 0
|
||||
WHERE persona_id IN (
|
||||
SELECT id FROM persona WHERE builtin_persona = true AND id != 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 5: Migrate user preferences - remove references to all builtin assistants
|
||||
# First, get all builtin assistant IDs (except 0)
|
||||
builtin_assistants_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id FROM persona
|
||||
WHERE builtin_persona = true AND id != 0
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
builtin_assistant_ids = [row[0] for row in builtin_assistants_result]
|
||||
|
||||
# Get all users with preferences
|
||||
users_result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, chosen_assistants, visible_assistants,
|
||||
hidden_assistants, pinned_assistants
|
||||
FROM "user"
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for user_row in users_result:
|
||||
user = UserRow(*user_row)
|
||||
user_id: UUID = user.id
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
# Remove all builtin assistants from chosen_assistants
|
||||
if user.chosen_assistants:
|
||||
new_chosen: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.chosen_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_chosen != user.chosen_assistants:
|
||||
updates["chosen_assistants"] = json.dumps(new_chosen)
|
||||
|
||||
# Remove all builtin assistants from visible_assistants
|
||||
if user.visible_assistants:
|
||||
new_visible: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.visible_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_visible != user.visible_assistants:
|
||||
updates["visible_assistants"] = json.dumps(new_visible)
|
||||
|
||||
# Add all builtin assistants to hidden_assistants
|
||||
if user.hidden_assistants:
|
||||
new_hidden: list[int] = list(user.hidden_assistants)
|
||||
for old_id in builtin_assistant_ids:
|
||||
if old_id not in new_hidden:
|
||||
new_hidden.append(old_id)
|
||||
if new_hidden != user.hidden_assistants:
|
||||
updates["hidden_assistants"] = json.dumps(new_hidden)
|
||||
else:
|
||||
updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)
|
||||
|
||||
# Remove all builtin assistants from pinned_assistants
|
||||
if user.pinned_assistants:
|
||||
new_pinned: list[int] = [
|
||||
assistant_id
|
||||
for assistant_id in user.pinned_assistants
|
||||
if assistant_id not in builtin_assistant_ids
|
||||
]
|
||||
if new_pinned != user.pinned_assistants:
|
||||
updates["pinned_assistants"] = json.dumps(new_pinned)
|
||||
|
||||
# Apply updates if any
|
||||
if updates:
|
||||
set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
|
||||
updates["user_id"] = str(user_id) # Convert UUID to string for SQL
|
||||
conn.execute(
|
||||
sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
|
||||
updates,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
# Only restore General (ID -1) and Art (ID -3) assistants
|
||||
# Step 1: Keep Search assistant (ID 0) as default but restore original state
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
)
|
||||
UPDATE persona
|
||||
SET is_default_persona = true,
|
||||
is_visible = true,
|
||||
deleted = false
|
||||
WHERE id = 0
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
# Step 2: Restore General assistant (ID -1)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :general_assistant_id
|
||||
"""
|
||||
),
|
||||
{"general_assistant_id": GENERAL_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
# Step 3: Restore Art assistant (ID -3)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
UPDATE persona
|
||||
SET deleted = false,
|
||||
is_visible = true,
|
||||
is_default_persona = true
|
||||
WHERE id = :art_assistant_id
|
||||
"""
|
||||
),
|
||||
{"art_assistant_id": ART_ASSISTANT_ID},
|
||||
)
|
||||
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
# Note: We don't restore the original tool associations, names, or descriptions
|
||||
# as those would require more complex logic to determine original state.
|
||||
# We also cannot restore original chat session persona_ids as we don't
|
||||
# have the original mappings.
|
||||
# Other builtin assistants remain deleted as per the requirement.
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
"""backend driven notification details
|
||||
|
||||
Revision ID: 5c3dca366b35
|
||||
Revises: 9087b548dd69
|
||||
Create Date: 2026-01-06 16:03:11.413724
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5c3dca366b35"
|
||||
down_revision = "9087b548dd69"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column(
|
||||
"title", sa.String(), nullable=False, server_default="New Notification"
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"notification",
|
||||
sa.Column("description", sa.String(), nullable=True, server_default=""),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("notification", "title")
|
||||
op.drop_column("notification", "description")
|
||||
@@ -0,0 +1,47 @@
|
||||
"""add_search_query_table
|
||||
|
||||
Revision ID: 73e9983e5091
|
||||
Revises: d1b637d7050a
|
||||
Create Date: 2026-01-14 14:16:52.837489
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "73e9983e5091"
|
||||
down_revision = "d1b637d7050a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"search_query",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("query", sa.String(), nullable=False),
|
||||
sa.Column("query_expansions", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index("ix_search_query_user_id", "search_query", ["user_id"])
|
||||
op.create_index("ix_search_query_created_at", "search_query", ["created_at"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_search_query_created_at", table_name="search_query")
|
||||
op.drop_index("ix_search_query_user_id", table_name="search_query")
|
||||
op.drop_table("search_query")
|
||||
@@ -10,8 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.enums import RecencyBiasSetting, SearchType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "776b3bbe9092"
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
"""notifications constraint, sort index, and cleanup old notifications
|
||||
|
||||
Revision ID: 8405ca81cc83
|
||||
Revises: a3c1a7904cd0
|
||||
Create Date: 2026-01-07 16:43:44.855156
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8405ca81cc83"
|
||||
down_revision = "a3c1a7904cd0"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create unique index for notification deduplication.
|
||||
# This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
|
||||
#
|
||||
# Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
|
||||
# in unique constraints, but we want NULL == NULL for deduplication).
|
||||
# The '{}' represents an empty JSONB object as the NULL replacement.
|
||||
|
||||
# Clean up legacy notifications first
|
||||
op.execute("DELETE FROM notification WHERE title = 'New Notification'")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
|
||||
ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
|
||||
"""
|
||||
)
|
||||
|
||||
# Create index for efficient notification sorting by user
|
||||
# Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS ix_notification_user_sort
|
||||
ON notification (user_id, dismissed, first_shown DESC)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
|
||||
op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")
|
||||
116
backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
Normal file
116
backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Add Discord bot tables
|
||||
|
||||
Revision ID: 8b5ce697290e
|
||||
Revises: a1b2c3d4e5f7
|
||||
Create Date: 2025-01-14
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8b5ce697290e"
|
||||
down_revision = "a1b2c3d4e5f7"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# DiscordBotConfig (singleton table - one per tenant)
|
||||
op.create_table(
|
||||
"discord_bot_config",
|
||||
sa.Column(
|
||||
"id",
|
||||
sa.String(),
|
||||
primary_key=True,
|
||||
server_default=sa.text("'SINGLETON'"),
|
||||
),
|
||||
sa.Column("bot_token", sa.LargeBinary(), nullable=False), # EncryptedString
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
|
||||
)
|
||||
|
||||
# DiscordGuildConfig
|
||||
op.create_table(
|
||||
"discord_guild_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
|
||||
sa.Column("guild_name", sa.String(), nullable=True),
|
||||
sa.Column("registration_key", sa.String(), nullable=False, unique=True),
|
||||
sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column(
|
||||
"default_persona_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
# DiscordChannelConfig
|
||||
op.create_table(
|
||||
"discord_channel_config",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column(
|
||||
"guild_config_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("channel_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("channel_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"channel_type",
|
||||
sa.String(20),
|
||||
server_default=sa.text("'text'"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_private",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"thread_only_mode",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("false"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"require_bot_invocation",
|
||||
sa.Boolean(),
|
||||
server_default=sa.text("true"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"persona_override_id",
|
||||
sa.Integer(),
|
||||
sa.ForeignKey("persona.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
# Unique constraint: one config per channel per guild
|
||||
op.create_unique_constraint(
|
||||
"uq_discord_channel_guild_channel",
|
||||
"discord_channel_config",
|
||||
["guild_config_id", "channel_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("discord_channel_config")
|
||||
op.drop_table("discord_guild_config")
|
||||
op.drop_table("discord_bot_config")
|
||||
@@ -0,0 +1,136 @@
|
||||
"""seed_default_image_gen_config
|
||||
|
||||
Revision ID: 9087b548dd69
|
||||
Revises: 2b90f3af54b8
|
||||
Create Date: 2026-01-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9087b548dd69"
|
||||
down_revision = "2b90f3af54b8"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# Constants for default image generation config
|
||||
# Source: web/src/app/admin/configuration/image-generation/constants.ts
|
||||
IMAGE_PROVIDER_ID = "openai_gpt_image_1"
|
||||
MODEL_NAME = "gpt-image-1"
|
||||
PROVIDER_NAME = "openai"
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Check if image_generation_config table already has records
|
||||
existing_configs = (
|
||||
conn.execute(sa.text("SELECT COUNT(*) FROM image_generation_config")).scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
if existing_configs > 0:
|
||||
# Skip if configs already exist - user may have configured manually
|
||||
return
|
||||
|
||||
# Find the first OpenAI LLM provider
|
||||
openai_provider = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, api_key
|
||||
FROM llm_provider
|
||||
WHERE provider = :provider
|
||||
ORDER BY id
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"provider": PROVIDER_NAME},
|
||||
).fetchone()
|
||||
|
||||
if not openai_provider:
|
||||
# No OpenAI provider found - nothing to do
|
||||
return
|
||||
|
||||
source_provider_id, api_key = openai_provider
|
||||
|
||||
# Create new LLM provider for image generation (clone only api_key)
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO llm_provider (
|
||||
name, provider, api_key, api_base, api_version,
|
||||
deployment_name, default_model_name, is_public,
|
||||
is_default_provider, is_default_vision_provider, is_auto_mode
|
||||
)
|
||||
VALUES (
|
||||
:name, :provider, :api_key, NULL, NULL,
|
||||
NULL, :default_model_name, :is_public,
|
||||
NULL, NULL, :is_auto_mode
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"name": f"Image Gen - {IMAGE_PROVIDER_ID}",
|
||||
"provider": PROVIDER_NAME,
|
||||
"api_key": api_key,
|
||||
"default_model_name": MODEL_NAME,
|
||||
"is_public": True,
|
||||
"is_auto_mode": False,
|
||||
},
|
||||
)
|
||||
new_provider_id = result.scalar()
|
||||
|
||||
# Create model configuration
|
||||
result = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO model_configuration (
|
||||
llm_provider_id, name, is_visible, max_input_tokens,
|
||||
supports_image_input, display_name
|
||||
)
|
||||
VALUES (
|
||||
:llm_provider_id, :name, :is_visible, :max_input_tokens,
|
||||
:supports_image_input, :display_name
|
||||
)
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"llm_provider_id": new_provider_id,
|
||||
"name": MODEL_NAME,
|
||||
"is_visible": True,
|
||||
"max_input_tokens": None,
|
||||
"supports_image_input": False,
|
||||
"display_name": None,
|
||||
},
|
||||
)
|
||||
model_config_id = result.scalar()
|
||||
|
||||
# Create image generation config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO image_generation_config (
|
||||
image_provider_id, model_configuration_id, is_default
|
||||
)
|
||||
VALUES (
|
||||
:image_provider_id, :model_configuration_id, :is_default
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"image_provider_id": IMAGE_PROVIDER_ID,
|
||||
"model_configuration_id": model_config_id,
|
||||
"is_default": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# We don't remove the config on downgrade since it's safe to keep around
|
||||
# If we upgrade again, it will be a no-op due to the existing records check
|
||||
pass
|
||||
@@ -42,20 +42,13 @@ TOOL_DESCRIPTIONS = {
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
|
||||
try:
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
except Exception as e:
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
for tool_id, description in TOOL_DESCRIPTIONS.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
|
||||
),
|
||||
{"description": description, "tool_id": tool_id},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
"""drop agent_search_metrics table
|
||||
|
||||
Revision ID: a1b2c3d4e5f7
|
||||
Revises: 73e9983e5091
|
||||
Create Date: 2026-01-17
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a1b2c3d4e5f7"
|
||||
down_revision = "73e9983e5091"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_table("agent__search_metrics")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"agent__search_metrics",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("persona_id", sa.Integer(), nullable=True),
|
||||
sa.Column("agent_type", sa.String(), nullable=False),
|
||||
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("base_duration_s", sa.Float(), nullable=False),
|
||||
sa.Column("full_duration_s", sa.Float(), nullable=False),
|
||||
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["persona_id"],
|
||||
["persona.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
"""remove userfile related deprecated fields
|
||||
|
||||
Revision ID: a3c1a7904cd0
|
||||
Revises: 5c3dca366b35
|
||||
Create Date: 2026-01-06 13:00:30.634396
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a3c1a7904cd0"
|
||||
down_revision = "5c3dca366b35"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user_file", "document_id")
|
||||
op.drop_column("user_file", "document_id_migrated")
|
||||
op.drop_column("connector_credential_pair", "is_user_file")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column("document_id", sa.String(), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"user_file",
|
||||
sa.Column(
|
||||
"document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
|
||||
),
|
||||
)
|
||||
@@ -7,7 +7,6 @@ Create Date: 2025-12-18 16:00:00.000000
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -19,7 +18,7 @@ depends_on = None
|
||||
|
||||
|
||||
DEEP_RESEARCH_TOOL = {
|
||||
"name": RESEARCH_AGENT_DB_NAME,
|
||||
"name": "ResearchAgent",
|
||||
"display_name": "Research Agent",
|
||||
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
|
||||
"in_code_tool_id": "ResearchAgent",
|
||||
|
||||
@@ -70,80 +70,66 @@ BUILT_IN_TOOLS = [
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Start transaction
|
||||
conn.execute(sa.text("BEGIN"))
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL")
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
|
||||
try:
|
||||
# Get existing tools to check what already exists
|
||||
existing_tools = conn.execute(
|
||||
sa.text(
|
||||
"SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL"
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
).fetchall()
|
||||
existing_tool_ids = {row[0] for row in existing_tools}
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
# Insert or update built-in tools
|
||||
for tool in BUILT_IN_TOOLS:
|
||||
in_code_id = tool["in_code_tool_id"]
|
||||
|
||||
# Handle historical rename: InternetSearchTool -> WebSearchTool
|
||||
if (
|
||||
in_code_id == "WebSearchTool"
|
||||
and "WebSearchTool" not in existing_tool_ids
|
||||
and "InternetSearchTool" in existing_tool_ids
|
||||
):
|
||||
# Rename the existing InternetSearchTool row in place and update fields
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description,
|
||||
in_code_tool_id = :in_code_tool_id
|
||||
WHERE in_code_tool_id = 'InternetSearchTool'
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
# Keep the local view of existing ids in sync to avoid duplicate insert
|
||||
existing_tool_ids.discard("InternetSearchTool")
|
||||
existing_tool_ids.add("WebSearchTool")
|
||||
continue
|
||||
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
# Commit transaction
|
||||
conn.execute(sa.text("COMMIT"))
|
||||
|
||||
except Exception as e:
|
||||
# Rollback on error
|
||||
conn.execute(sa.text("ROLLBACK"))
|
||||
raise e
|
||||
if in_code_id in existing_tool_ids:
|
||||
# Update existing tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :name,
|
||||
display_name = :display_name,
|
||||
description = :description
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
else:
|
||||
# Insert new tool
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO tool (name, display_name, description, in_code_tool_id)
|
||||
VALUES (:name, :display_name, :description, :in_code_tool_id)
|
||||
"""
|
||||
),
|
||||
tool,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
"""sync_exa_api_key_to_content_provider
|
||||
|
||||
Revision ID: d1b637d7050a
|
||||
Revises: d25168c2beee
|
||||
Create Date: 2026-01-09 15:54:15.646249
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d1b637d7050a"
|
||||
down_revision = "d25168c2beee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Exa uses a shared API key between search and content providers.
|
||||
# For existing Exa search providers with API keys, create the corresponding
|
||||
# content provider if it doesn't exist yet.
|
||||
connection = op.get_bind()
|
||||
|
||||
# Check if Exa search provider exists with an API key
|
||||
result = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT api_key FROM internet_search_provider
|
||||
WHERE provider_type = 'exa' AND api_key IS NOT NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
api_key = row[0]
|
||||
# Create Exa content provider with the shared key
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO internet_content_provider
|
||||
(name, provider_type, api_key, is_active)
|
||||
VALUES ('Exa', 'exa', :api_key, false)
|
||||
ON CONFLICT (name) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{"api_key": api_key},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the Exa content provider that was created by this migration
|
||||
connection = op.get_bind()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM internet_content_provider
|
||||
WHERE provider_type = 'exa'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,86 @@
|
||||
"""tool_name_consistency
|
||||
|
||||
Revision ID: d25168c2beee
|
||||
Revises: 8405ca81cc83
|
||||
Create Date: 2026-01-11 17:54:40.135777
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d25168c2beee"
|
||||
down_revision = "8405ca81cc83"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# Currently the seeded tools have the in_code_tool_id == name
|
||||
CURRENT_TOOL_NAME_MAPPING = [
|
||||
"SearchTool",
|
||||
"WebSearchTool",
|
||||
"ImageGenerationTool",
|
||||
"PythonTool",
|
||||
"OpenURLTool",
|
||||
"KnowledgeGraphTool",
|
||||
"ResearchAgent",
|
||||
]
|
||||
|
||||
# Mapping of in_code_tool_id -> name
|
||||
# These are the expected names that we want in the database
|
||||
EXPECTED_TOOL_NAME_MAPPING = {
|
||||
"SearchTool": "internal_search",
|
||||
"WebSearchTool": "web_search",
|
||||
"ImageGenerationTool": "generate_image",
|
||||
"PythonTool": "python",
|
||||
"OpenURLTool": "open_url",
|
||||
"KnowledgeGraphTool": "run_kg_search",
|
||||
"ResearchAgent": "research_agent",
|
||||
}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Mapping of in_code_tool_id to the NAME constant from each tool class
|
||||
# These match the .name property of each tool implementation
|
||||
tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING
|
||||
|
||||
# Update the name column for each tool based on its in_code_tool_id
|
||||
for in_code_tool_id, expected_name in tool_name_mapping.items():
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :expected_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"expected_name": expected_name,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Reverse the migration by setting name back to in_code_tool_id
|
||||
# This matches the original pattern where name was the class name
|
||||
for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tool
|
||||
SET name = :current_name
|
||||
WHERE in_code_tool_id = :in_code_tool_id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"current_name": in_code_tool_id,
|
||||
"in_code_tool_id": in_code_tool_id,
|
||||
},
|
||||
)
|
||||
@@ -109,7 +109,6 @@ CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
|
||||
|
||||
|
||||
STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")
|
||||
STRIPE_PRICE_ID = os.environ.get("STRIPE_PRICE")
|
||||
|
||||
# JWT Public Key URL
|
||||
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
|
||||
@@ -129,3 +128,8 @@ MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
# License enforcement - when True, blocks API access for gated/expired licenses
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -3,30 +3,42 @@ from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
||||
|
||||
|
||||
def make_persona_private(
|
||||
def update_persona_access(
|
||||
persona_id: int,
|
||||
creator_user_id: UUID | None,
|
||||
user_ids: list[UUID] | None,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
is_public: bool | None = None,
|
||||
user_ids: list[UUID] | None = None,
|
||||
group_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
"""Updates the access settings for a persona including public status, user shares,
|
||||
and group shares.
|
||||
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
NOTE: This function batches all updates. If we don't dedupe the inputs,
|
||||
the commit will exception.
|
||||
|
||||
NOTE: Callers are responsible for committing."""
|
||||
|
||||
if is_public is not None:
|
||||
persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
|
||||
if persona:
|
||||
persona.is_public = is_public
|
||||
|
||||
# NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
|
||||
# and a non-empty list means "replace with these shares".
|
||||
|
||||
if user_ids is not None:
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
@@ -34,17 +46,20 @@ def make_persona_private(
|
||||
create_notification(
|
||||
user_id=user_id,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
title="A new agent was shared with you!",
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids:
|
||||
if group_ids is not None:
|
||||
db_session.query(Persona__UserGroup).filter(
|
||||
Persona__UserGroup.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
64
backend/ee/onyx/db/search.py
Normal file
64
backend/ee/onyx/db/search.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.models import SearchQuery
|
||||
|
||||
|
||||
def create_search_query(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
query: str,
|
||||
query_expansions: list[str] | None = None,
|
||||
) -> SearchQuery:
|
||||
"""Create and persist a `SearchQuery` row.
|
||||
|
||||
Notes:
|
||||
- `SearchQuery.id` is a UUID PK without a server-side default, so we generate it.
|
||||
- `created_at` is filled by the DB (server_default=now()).
|
||||
"""
|
||||
search_query = SearchQuery(
|
||||
id=uuid.uuid4(),
|
||||
user_id=user_id,
|
||||
query=query,
|
||||
query_expansions=query_expansions,
|
||||
)
|
||||
db_session.add(search_query)
|
||||
db_session.commit()
|
||||
db_session.refresh(search_query)
|
||||
return search_query
|
||||
|
||||
|
||||
def fetch_search_queries_for_user(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
filter_days: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[SearchQuery]:
|
||||
"""Fetch `SearchQuery` rows for a user.
|
||||
|
||||
Args:
|
||||
user_id: User UUID.
|
||||
filter_days: Optional time filter. If provided, only rows created within
|
||||
the last `filter_days` days are returned.
|
||||
limit: Optional max number of rows to return.
|
||||
"""
|
||||
if filter_days is not None and filter_days <= 0:
|
||||
raise ValueError("filter_days must be > 0")
|
||||
|
||||
stmt = select(SearchQuery).where(SearchQuery.user_id == user_id)
|
||||
|
||||
if filter_days is not None and filter_days > 0:
|
||||
cutoff = get_db_current_time(db_session) - timedelta(days=filter_days)
|
||||
stmt = stmt.where(SearchQuery.created_at >= cutoff)
|
||||
|
||||
stmt = stmt.order_by(SearchQuery.created_at.desc())
|
||||
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
@@ -16,16 +16,17 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
from ee.onyx.server.evals.api import router as evals_router
|
||||
from ee.onyx.server.license.api import router as license_router
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.license_enforcement import (
|
||||
add_license_enforcement_middleware,
|
||||
)
|
||||
from ee.onyx.server.middleware.tenant_tracking import (
|
||||
add_api_server_tenant_id_middleware,
|
||||
)
|
||||
from ee.onyx.server.oauth.api import router as ee_oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.query_backend import (
|
||||
basic_router as ee_query_router,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.search_backend import router as search_router
|
||||
from ee.onyx.server.query_history.api import router as query_history_router
|
||||
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
|
||||
from ee.onyx.server.seeding import seed_db
|
||||
@@ -85,6 +86,10 @@ def get_application() -> FastAPI:
|
||||
if MULTI_TENANT:
|
||||
add_api_server_tenant_id_middleware(application, logger)
|
||||
|
||||
# Add license enforcement middleware (runs after tenant tracking)
|
||||
# This blocks access when license is expired/gated
|
||||
add_license_enforcement_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
@@ -124,7 +129,7 @@ def get_application() -> FastAPI:
|
||||
# EE only backend APIs
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, search_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
|
||||
|
||||
27
backend/ee/onyx/prompts/query_expansion.py
Normal file
27
backend/ee/onyx/prompts/query_expansion.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Single message is likely most reliable and generally better for this task
|
||||
# No final reminders at the end since the user query is expected to be short
|
||||
# If it is not short, it should go into the chat flow so we do not need to account for this.
|
||||
KEYWORD_EXPANSION_PROMPT = """
|
||||
Generate a set of keyword-only queries to help find relevant documents for the provided query. \
|
||||
These queries will be passed to a bm25-based keyword search engine. \
|
||||
Provide a single query per line (where each query consists of one or more keywords). \
|
||||
The queries must be purely keywords and not contain any filler natural language. \
|
||||
The each query should have as few keywords as necessary to represent the user's search intent. \
|
||||
If there are no useful expansions, simply return the original query with no additional keyword queries. \
|
||||
CRITICAL: Do not include any additional formatting, comments, or anything aside from the keyword queries.
|
||||
|
||||
The user query is:
|
||||
{user_query}
|
||||
""".strip()
|
||||
|
||||
|
||||
QUERY_TYPE_PROMPT = """
|
||||
Determine if the provided query is better suited for a keyword search or a semantic search.
|
||||
Respond with "keyword" or "semantic" literally and nothing else.
|
||||
Do not provide any additional text or reasoning to your response.
|
||||
|
||||
CRITICAL: It must only be 1 single word - EITHER "keyword" or "semantic".
|
||||
|
||||
The user query is:
|
||||
{user_query}
|
||||
""".strip()
|
||||
42
backend/ee/onyx/prompts/search_flow_classification.py
Normal file
42
backend/ee/onyx/prompts/search_flow_classification.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
SEARCH_CLASS = "search"
|
||||
CHAT_CLASS = "chat"
|
||||
|
||||
# Will note that with many larger LLMs the latency on running this prompt via third party APIs is as high as 2 seconds which is too slow for many
|
||||
# use cases.
|
||||
SEARCH_CHAT_PROMPT = f"""
|
||||
Determine if the following query is better suited for a search UI or a chat UI. Respond with "{SEARCH_CLASS}" or "{CHAT_CLASS}" literally and nothing else. \
|
||||
Do not provide any additional text or reasoning to your response. CRITICAL, IT MUST ONLY BE 1 SINGLE WORD - EITHER "{SEARCH_CLASS}" or "{CHAT_CLASS}".
|
||||
|
||||
# Classification Guidelines:
|
||||
## {SEARCH_CLASS}
|
||||
- If the query consists entirely of keywords or query doesn't require any answer from the AI
|
||||
- If the query is a short statement that seems like a search query rather than a question
|
||||
- If the query feels nonsensical or is a short phrase that possibly describes a document or information that could be found in a internal document
|
||||
|
||||
### Examples of {SEARCH_CLASS} queries:
|
||||
- Find me the document that goes over the onboarding process for a new hire
|
||||
- Pull requests since last week
|
||||
- Sales Runbook AMEA Region
|
||||
- Procurement process
|
||||
- Retrieve the PRD for project X
|
||||
|
||||
## {CHAT_CLASS}
|
||||
- If the query is asking a question that requires an answer rather than a document
|
||||
- If the query is asking for a solution, suggestion, or general help
|
||||
- If the query is seeking information that is on the web and likely not in a company internal document
|
||||
- If the query should be answered without any context from additional documents or searches
|
||||
|
||||
### Examples of {CHAT_CLASS} queries:
|
||||
- What led us to win the deal with company X? (seeking answer)
|
||||
- Google Drive not sync-ing files to my computer (seeking solution)
|
||||
- Review my email: <whatever the email is> (general help)
|
||||
- Write me a script to... (general help)
|
||||
- Cheap flights Europe to Tokyo (information likely found on the web, not internal)
|
||||
|
||||
# User Query:
|
||||
{{user_query}}
|
||||
|
||||
REMEMBER TO ONLY RESPOND WITH "{SEARCH_CLASS}" OR "{CHAT_CLASS}" AND NOTHING ELSE.
|
||||
""".strip()
|
||||
# ruff: noqa: E501, W605 end
|
||||
270
backend/ee/onyx/search/process_search_query.py
Normal file
270
backend/ee/onyx/search/process_search_query.py
Normal file
@@ -0,0 +1,270 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.search import create_search_query
|
||||
from ee.onyx.secondary_llm_flows.query_expansion import expand_keywords
|
||||
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import LLMSelectedDocsPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchDocsPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchQueriesPacket
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.factory import get_current_primary_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
weighted_reciprocal_rank_fusion,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# This is just a heuristic that also happens to work well for the UI/UX
|
||||
# Users would not find it useful to see a huge list of suggested docs
|
||||
# but more than 1 is also likely good as many questions may target more than 1 doc.
|
||||
TARGET_NUM_SECTIONS_FOR_LLM_SELECTION = 3
|
||||
|
||||
|
||||
def _run_single_search(
|
||||
query: str,
|
||||
filters: BaseFilters | None,
|
||||
document_index: DocumentIndex,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Execute a single search query and return chunks."""
|
||||
chunk_search_request = ChunkSearchRequest(
|
||||
query=query,
|
||||
user_selected_filters=filters,
|
||||
)
|
||||
|
||||
return search_pipeline(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def stream_search_query(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> Generator[
|
||||
SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
|
||||
None,
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
Core search function that yields streaming packets.
|
||||
Used by both streaming and non-streaming endpoints.
|
||||
"""
|
||||
# Get document index
|
||||
document_index = get_current_primary_default_document_index(db_session)
|
||||
|
||||
# Determine queries to execute
|
||||
original_query = request.search_query
|
||||
keyword_expansions: list[str] = []
|
||||
|
||||
if request.run_query_expansion:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
keyword_expansions = expand_keywords(
|
||||
user_query=original_query,
|
||||
llm=llm,
|
||||
)
|
||||
if keyword_expansions:
|
||||
logger.debug(
|
||||
f"Query expansion generated {len(keyword_expansions)} keyword queries"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Query expansion failed: {e}; using original query only.")
|
||||
keyword_expansions = []
|
||||
|
||||
# Build list of all executed queries for tracking
|
||||
all_executed_queries = [original_query] + keyword_expansions
|
||||
|
||||
# TODO remove this check, user should not be None
|
||||
if user is not None:
|
||||
create_search_query(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
query=request.search_query,
|
||||
query_expansions=keyword_expansions if keyword_expansions else None,
|
||||
)
|
||||
|
||||
# Execute search(es)
|
||||
if not keyword_expansions:
|
||||
# Single query (original only) - no threading needed
|
||||
chunks = _run_single_search(
|
||||
query=original_query,
|
||||
filters=request.filters,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
# Multiple queries - run in parallel and merge with RRF
|
||||
# First query is the original (semantic), rest are keyword expansions
|
||||
search_functions = [
|
||||
(
|
||||
_run_single_search,
|
||||
(query, request.filters, document_index, user, db_session),
|
||||
)
|
||||
for query in all_executed_queries
|
||||
]
|
||||
|
||||
# Run all searches in parallel
|
||||
all_search_results: list[list[InferenceChunk]] = (
|
||||
run_functions_tuples_in_parallel(
|
||||
search_functions,
|
||||
allow_failures=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Separate original query results from keyword expansion results
|
||||
# Note that in rare cases, the original query may have failed and so we may be
|
||||
# just overweighting one set of keyword results, should be not a big deal though.
|
||||
original_result = all_search_results[0] if all_search_results else []
|
||||
keyword_results = all_search_results[1:] if len(all_search_results) > 1 else []
|
||||
|
||||
# Build valid results and weights
|
||||
# Original query (semantic): weight 2.0
|
||||
# Keyword expansions: weight 1.0 each
|
||||
valid_results: list[list[InferenceChunk]] = []
|
||||
weights: list[float] = []
|
||||
|
||||
if original_result:
|
||||
valid_results.append(original_result)
|
||||
weights.append(2.0)
|
||||
|
||||
for keyword_result in keyword_results:
|
||||
if keyword_result:
|
||||
valid_results.append(keyword_result)
|
||||
weights.append(1.0)
|
||||
|
||||
if not valid_results:
|
||||
logger.warning("All parallel searches returned empty results")
|
||||
chunks = []
|
||||
else:
|
||||
chunks = weighted_reciprocal_rank_fusion(
|
||||
ranked_results=valid_results,
|
||||
weights=weights,
|
||||
id_extractor=lambda chunk: f"{chunk.document_id}_{chunk.chunk_id}",
|
||||
)
|
||||
|
||||
# Merge chunks into sections
|
||||
sections = merge_individual_chunks(chunks)
|
||||
|
||||
# Apply LLM document selection if requested
|
||||
# num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection
|
||||
# The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it
|
||||
# llm_selected_doc_ids will be:
|
||||
# - None if LLM selection was not requested or failed
|
||||
# - Empty list if LLM selection ran but selected nothing
|
||||
# - List of doc IDs if LLM selection succeeded
|
||||
run_llm_selection = (
|
||||
request.num_docs_fed_to_llm_selection is not None
|
||||
and request.num_docs_fed_to_llm_selection >= 1
|
||||
)
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
llm_selection_failed = False
|
||||
if run_llm_selection and sections:
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
sections_to_evaluate = sections[: request.num_docs_fed_to_llm_selection]
|
||||
selected_sections, _ = select_sections_for_expansion(
|
||||
sections=sections_to_evaluate,
|
||||
user_query=original_query,
|
||||
llm=llm,
|
||||
max_sections=TARGET_NUM_SECTIONS_FOR_LLM_SELECTION,
|
||||
try_to_fill_to_max=True,
|
||||
)
|
||||
# Extract unique document IDs from selected sections (may be empty)
|
||||
llm_selected_doc_ids = list(
|
||||
dict.fromkeys(
|
||||
section.center_chunk.document_id for section in selected_sections
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f"LLM document selection evaluated {len(sections_to_evaluate)} sections, "
|
||||
f"selected {len(selected_sections)} sections with doc IDs: {llm_selected_doc_ids}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Allowing a blanket exception here as this step is not critical and the rest of the results are still valid
|
||||
logger.warning(f"LLM document selection failed: {e}")
|
||||
llm_selection_failed = True
|
||||
elif run_llm_selection and not sections:
|
||||
# LLM selection requested but no sections to evaluate
|
||||
llm_selected_doc_ids = []
|
||||
|
||||
# Convert to SearchDocWithContent list, optionally including content
|
||||
search_docs = SearchDocWithContent.from_inference_sections(
|
||||
sections,
|
||||
include_content=request.include_content,
|
||||
is_internet=False,
|
||||
)
|
||||
|
||||
# Yield queries packet
|
||||
yield SearchQueriesPacket(all_executed_queries=all_executed_queries)
|
||||
|
||||
# Yield docs packet
|
||||
yield SearchDocsPacket(search_docs=search_docs)
|
||||
|
||||
# Yield LLM selected docs packet if LLM selection was requested
|
||||
# - llm_selected_doc_ids is None if selection failed
|
||||
# - llm_selected_doc_ids is empty list if no docs were selected
|
||||
# - llm_selected_doc_ids is list of IDs if docs were selected
|
||||
if run_llm_selection:
|
||||
yield LLMSelectedDocsPacket(
|
||||
llm_selected_doc_ids=None if llm_selection_failed else llm_selected_doc_ids
|
||||
)
|
||||
|
||||
|
||||
def gather_search_stream(
|
||||
packets: Generator[
|
||||
SearchQueriesPacket
|
||||
| SearchDocsPacket
|
||||
| LLMSelectedDocsPacket
|
||||
| SearchErrorPacket,
|
||||
None,
|
||||
None,
|
||||
],
|
||||
) -> SearchFullResponse:
|
||||
"""
|
||||
Aggregate all streaming packets into SearchFullResponse.
|
||||
"""
|
||||
all_executed_queries: list[str] = []
|
||||
search_docs: list[SearchDocWithContent] = []
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
error: str | None = None
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, SearchQueriesPacket):
|
||||
all_executed_queries = packet.all_executed_queries
|
||||
elif isinstance(packet, SearchDocsPacket):
|
||||
search_docs = packet.search_docs
|
||||
elif isinstance(packet, LLMSelectedDocsPacket):
|
||||
llm_selected_doc_ids = packet.llm_selected_doc_ids
|
||||
elif isinstance(packet, SearchErrorPacket):
|
||||
error = packet.error
|
||||
|
||||
return SearchFullResponse(
|
||||
all_executed_queries=all_executed_queries,
|
||||
search_docs=search_docs,
|
||||
doc_selection_reasoning=None,
|
||||
llm_selected_doc_ids=llm_selected_doc_ids,
|
||||
error=error,
|
||||
)
|
||||
0
backend/ee/onyx/secondary_llm_flows/__init__.py
Normal file
0
backend/ee/onyx/secondary_llm_flows/__init__.py
Normal file
92
backend/ee/onyx/secondary_llm_flows/query_expansion.py
Normal file
92
backend/ee/onyx/secondary_llm_flows/query_expansion.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import re
|
||||
|
||||
from ee.onyx.prompts.query_expansion import KEYWORD_EXPANSION_PROMPT
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Pattern to remove common LLM artifacts: brackets, quotes, list markers, etc.
|
||||
CLEANUP_PATTERN = re.compile(r'[\[\]"\'`]')
|
||||
|
||||
|
||||
def _clean_keyword_line(line: str) -> str:
|
||||
"""Clean a keyword line by removing common LLM artifacts.
|
||||
|
||||
Removes brackets, quotes, and other characters that LLMs may accidentally
|
||||
include in their output.
|
||||
"""
|
||||
# Remove common artifacts
|
||||
cleaned = CLEANUP_PATTERN.sub("", line)
|
||||
# Remove leading list markers like "1.", "2.", "-", "*"
|
||||
cleaned = re.sub(r"^\s*(?:\d+[\.\)]\s*|[-*]\s*)", "", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def expand_keywords(
|
||||
user_query: str,
|
||||
llm: LLM,
|
||||
) -> list[str]:
|
||||
"""Expand a user query into multiple keyword-only queries for BM25 search.
|
||||
|
||||
Uses an LLM to generate keyword-based search queries that capture different
|
||||
aspects of the user's search intent. Returns only the expanded queries,
|
||||
not the original query.
|
||||
|
||||
Args:
|
||||
user_query: The original search query from the user
|
||||
llm: Language model to use for keyword expansion
|
||||
|
||||
Returns:
|
||||
List of expanded keyword queries (excluding the original query).
|
||||
Returns empty list if expansion fails or produces no useful expansions.
|
||||
"""
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content=KEYWORD_EXPANSION_PROMPT.format(user_query=user_query))
|
||||
]
|
||||
|
||||
try:
|
||||
response = llm.invoke(
|
||||
prompt=messages,
|
||||
reasoning_effort=ReasoningEffort.OFF,
|
||||
# Limit output - we only expect a few short keyword queries
|
||||
max_tokens=150,
|
||||
)
|
||||
|
||||
content = llm_response_to_string(response).strip()
|
||||
|
||||
if not content:
|
||||
logger.warning("Keyword expansion returned empty response.")
|
||||
return []
|
||||
|
||||
# Parse response - each line is a separate keyword query
|
||||
# Clean each line to remove LLM artifacts and drop empty lines
|
||||
parsed_queries = []
|
||||
for line in content.strip().split("\n"):
|
||||
cleaned = _clean_keyword_line(line)
|
||||
if cleaned:
|
||||
parsed_queries.append(cleaned)
|
||||
|
||||
if not parsed_queries:
|
||||
logger.warning("Keyword expansion parsing returned no queries.")
|
||||
return []
|
||||
|
||||
# Filter out duplicates and queries that match the original
|
||||
expanded_queries: list[str] = []
|
||||
seen_lower: set[str] = {user_query.lower()}
|
||||
for query in parsed_queries:
|
||||
query_lower = query.lower()
|
||||
if query_lower not in seen_lower:
|
||||
seen_lower.add(query_lower)
|
||||
expanded_queries.append(query)
|
||||
|
||||
logger.debug(f"Keyword expansion generated {len(expanded_queries)} queries")
|
||||
return expanded_queries
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Keyword expansion failed: {e}")
|
||||
return []
|
||||
@@ -0,0 +1,50 @@
|
||||
from ee.onyx.prompts.search_flow_classification import CHAT_CLASS
|
||||
from ee.onyx.prompts.search_flow_classification import SEARCH_CHAT_PROMPT
|
||||
from ee.onyx.prompts.search_flow_classification import SEARCH_CLASS
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import LanguageModelInput
|
||||
from onyx.llm.models import ReasoningEffort
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def classify_is_search_flow(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
) -> bool:
|
||||
messages: LanguageModelInput = [
|
||||
UserMessage(content=SEARCH_CHAT_PROMPT.format(user_query=query))
|
||||
]
|
||||
response = llm.invoke(
|
||||
prompt=messages,
|
||||
reasoning_effort=ReasoningEffort.OFF,
|
||||
# Nothing can happen in the UI until this call finishes so we need to be aggressive with the timeout
|
||||
timeout_override=2,
|
||||
# Well more than necessary but just to ensure completion and in case it succeeds with classifying but
|
||||
# ends up rambling
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
content = llm_response_to_string(response).strip().lower()
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Search flow classification returned empty response; defaulting to chat flow."
|
||||
)
|
||||
return False
|
||||
|
||||
# Prefer chat if both appear.
|
||||
if CHAT_CLASS in content:
|
||||
return False
|
||||
if SEARCH_CLASS in content:
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
"Search flow classification returned unexpected response; defaulting to chat flow. Response=%r",
|
||||
content,
|
||||
)
|
||||
return False
|
||||
@@ -19,10 +19,11 @@ from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
|
||||
router = APIRouter(prefix="/analytics")
|
||||
router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
_DEFAULT_LOOKBACK_DAYS = 30
|
||||
|
||||
102
backend/ee/onyx/server/middleware/license_enforcement.py
Normal file
102
backend/ee/onyx/server/middleware/license_enforcement.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Middleware to enforce license status application-wide."""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.server.tenants.product_gating import is_tenant_gated
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
# Paths that are ALWAYS accessible, even when license is expired/gated.
|
||||
# These enable users to:
|
||||
# /auth - Log in/out (users can't fix billing if locked out of auth)
|
||||
# /license - Fetch, upload, or check license status
|
||||
# /health - Health checks for load balancers/orchestrators
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /tenants/billing-* - Manage subscription to resolve gating
|
||||
ALLOWED_PATH_PREFIXES = {
|
||||
"/auth",
|
||||
"/license",
|
||||
"/health",
|
||||
"/me",
|
||||
"/settings",
|
||||
"/enterprise-settings",
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
}
|
||||
|
||||
|
||||
def _is_path_allowed(path: str) -> bool:
|
||||
"""Check if path is in allowlist (prefix match)."""
|
||||
return any(path.startswith(prefix) for prefix in ALLOWED_PATH_PREFIXES)
|
||||
|
||||
|
||||
def add_license_enforcement_middleware(
|
||||
app: FastAPI, logger: logging.LoggerAdapter
|
||||
) -> None:
|
||||
logger.info("License enforcement middleware registered")
|
||||
|
||||
@app.middleware("http")
|
||||
async def enforce_license(
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""Block requests when license is expired/gated."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return await call_next(request)
|
||||
|
||||
path = request.url.path
|
||||
if path.startswith("/api"):
|
||||
path = path[4:]
|
||||
|
||||
if _is_path_allowed(path):
|
||||
return await call_next(request)
|
||||
|
||||
is_gated = False
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
is_gated = is_tenant_gated(tenant_id)
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check tenant gating status: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
else:
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata:
|
||||
if metadata.status == ApplicationStatus.GATED_ACCESS:
|
||||
is_gated = True
|
||||
else:
|
||||
# No license metadata = gated for self-hosted EE
|
||||
is_gated = True
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata: {e}")
|
||||
# Fail open - don't block users due to Redis connectivity issues
|
||||
is_gated = False
|
||||
|
||||
if is_gated:
|
||||
logger.info(f"Blocking request for gated tenant: {tenant_id}, path={path}")
|
||||
return JSONResponse(
|
||||
status_code=402,
|
||||
content={
|
||||
"detail": {
|
||||
"error": "license_expired",
|
||||
"message": "Your subscription has expired. Please update your billing.",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
@@ -1,214 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
)
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import create_chat_history_chain
|
||||
from onyx.chat.models import ChatBasicResponse
|
||||
from onyx.chat.process_message import gather_stream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import OptionalSearchSetting
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
@router.post("/send-message-simple-api")
|
||||
def handle_simplified_chat_message(
|
||||
chat_message_req: BasicCreateChatMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatBasicResponse:
|
||||
"""This is a Non-Streaming version that only gives back a minimal set of information"""
|
||||
logger.notice(f"Received new simple api chat message: {chat_message_req.message}")
|
||||
|
||||
if not chat_message_req.message:
|
||||
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
|
||||
|
||||
# Handle chat session creation if chat_session_id is not provided
|
||||
if chat_message_req.chat_session_id is None:
|
||||
if chat_message_req.persona_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Either chat_session_id or persona_id must be provided",
|
||||
)
|
||||
|
||||
# Create a new chat session with the provided persona_id
|
||||
try:
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # Leave empty for simple API
|
||||
user_id=user.id if user else None,
|
||||
persona_id=chat_message_req.persona_id,
|
||||
)
|
||||
chat_session_id = new_chat_session.id
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
|
||||
else:
|
||||
chat_session_id = chat_message_req.chat_session_id
|
||||
|
||||
try:
|
||||
parent_message = create_chat_history_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)[-1]
|
||||
except Exception:
|
||||
parent_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
|
||||
if (
|
||||
chat_message_req.retrieval_options is None
|
||||
and chat_message_req.search_doc_ids is None
|
||||
):
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
)
|
||||
else:
|
||||
retrieval_options = chat_message_req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message.id,
|
||||
message=chat_message_req.message,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=chat_message_req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=chat_message_req.query_override,
|
||||
# Currently only applies to search flow not chat
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=chat_message_req.full_doc,
|
||||
structured_response_format=chat_message_req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
|
||||
|
||||
@router.post("/send-message-simple-with-history")
|
||||
def handle_send_message_simple_with_history(
|
||||
req: BasicCreateChatMessageWithHistoryRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatBasicResponse:
|
||||
"""This is a Non-Streaming version that only gives back a minimal set of information.
|
||||
takes in chat history maintained by the caller
|
||||
and does query rephrasing similar to answer-with-quote"""
|
||||
|
||||
if len(req.messages) == 0:
|
||||
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
|
||||
|
||||
# This is a sanity check to make sure the chat history is valid
|
||||
# It must start with a user message and alternate beteen user and assistant
|
||||
expected_role = MessageType.USER
|
||||
for msg in req.messages:
|
||||
if not msg.message:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="One or more chat messages were empty"
|
||||
)
|
||||
|
||||
if msg.role != expected_role:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Message roles must start and end with MessageType.USER and alternate in-between.",
|
||||
)
|
||||
if expected_role == MessageType.USER:
|
||||
expected_role = MessageType.ASSISTANT
|
||||
else:
|
||||
expected_role = MessageType.USER
|
||||
|
||||
query = req.messages[-1].message
|
||||
msg_history = req.messages[:-1]
|
||||
|
||||
logger.notice(f"Received new simple with history chat message: {query}")
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="handle_send_message_simple_with_history",
|
||||
user_id=user_id,
|
||||
persona_id=req.persona_id,
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
chat_message = root_message
|
||||
for msg in msg_history:
|
||||
chat_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=chat_message,
|
||||
message=msg.message,
|
||||
token_count=len(llm_tokenizer.encode(msg.message)),
|
||||
message_type=msg.role,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
if req.retrieval_options is None and req.search_doc_ids is None:
|
||||
retrieval_options: RetrievalDetails | None = RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=False,
|
||||
)
|
||||
else:
|
||||
retrieval_options = req.retrieval_options
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=chat_message.id,
|
||||
message=query,
|
||||
file_descriptors=[],
|
||||
search_doc_ids=req.search_doc_ids,
|
||||
retrieval_options=retrieval_options,
|
||||
# Simple API does not support reranking, hide complexity from user
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
full_doc=req.full_doc,
|
||||
structured_response_format=req.structured_response_format,
|
||||
)
|
||||
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=full_chat_msg_info,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return gather_stream(packets)
|
||||
@@ -1,18 +1,12 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import BasicChunkRequest
|
||||
from onyx.context.search.models import ChunkContext
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.server.manage.models import StandardAnswer
|
||||
|
||||
|
||||
@@ -25,119 +19,88 @@ class StandardAnswerResponse(BaseModel):
|
||||
standard_answers: list[StandardAnswer] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DocumentSearchRequest(BasicChunkRequest):
|
||||
user_selected_filters: BaseFilters | None = None
|
||||
class SearchFlowClassificationRequest(BaseModel):
|
||||
user_query: str
|
||||
|
||||
|
||||
class DocumentSearchResponse(BaseModel):
|
||||
top_documents: list[InferenceChunk]
|
||||
class SearchFlowClassificationResponse(BaseModel):
|
||||
is_search_flow: bool
|
||||
|
||||
|
||||
class BasicCreateChatMessageRequest(ChunkContext):
|
||||
"""If a chat_session_id is not provided, a persona_id must be provided to automatically create a new chat session
|
||||
Note, for simplicity this option only allows for a single linear chain of messages
|
||||
"""
|
||||
class SendSearchQueryRequest(BaseModel):
|
||||
search_query: str
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
|
||||
chat_session_id: UUID | None = None
|
||||
# Optional persona_id to create a new chat session if chat_session_id is not provided
|
||||
persona_id: int | None = None
|
||||
# New message contents
|
||||
message: str
|
||||
# Defaults to using retrieval with no additional filters
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
# Allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
|
||||
if self.chat_session_id is None and self.persona_id is None:
|
||||
raise ValueError("Either chat_session_id or persona_id must be provided")
|
||||
return self
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
|
||||
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
|
||||
# Last element is the new query. All previous elements are historical context
|
||||
messages: list[ThreadMessage]
|
||||
persona_id: int
|
||||
retrieval_options: RetrievalDetails | None = None
|
||||
query_override: str | None = None
|
||||
skip_rerank: bool | None = None
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None = None
|
||||
# only works if using an OpenAI model. See the following for more details:
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
class SearchDocWithContent(SearchDoc):
|
||||
# Allows None because this is determined by a flag but the object used in code
|
||||
# of the search path uses this type
|
||||
content: str | None
|
||||
|
||||
@classmethod
|
||||
def from_inference_sections(
|
||||
cls,
|
||||
sections: Sequence[InferenceSection],
|
||||
include_content: bool = False,
|
||||
is_internet: bool = False,
|
||||
) -> list["SearchDocWithContent"]:
|
||||
"""Convert InferenceSections to SearchDocWithContent objects.
|
||||
|
||||
class SimpleDoc(BaseModel):
|
||||
id: str
|
||||
semantic_identifier: str
|
||||
link: str | None
|
||||
blurb: str
|
||||
match_highlights: list[str]
|
||||
source_type: DocumentSource
|
||||
metadata: dict | None
|
||||
Args:
|
||||
sections: Sequence of InferenceSection objects
|
||||
include_content: If True, populate content field with combined_content
|
||||
is_internet: Whether these are internet search results
|
||||
|
||||
|
||||
class AgentSubQuestion(BaseModel):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(BaseModel):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(BaseModel):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"],
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
Returns:
|
||||
List of SearchDocWithContent with optional content
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
return [
|
||||
cls(
|
||||
document_id=(chunk := section.center_chunk).document_id,
|
||||
chunk_ind=chunk.chunk_id,
|
||||
semantic_identifier=chunk.semantic_identifier or "Unknown",
|
||||
link=chunk.source_links[0] if chunk.source_links else None,
|
||||
blurb=chunk.blurb,
|
||||
source_type=chunk.source_type,
|
||||
boost=chunk.boost,
|
||||
hidden=chunk.hidden,
|
||||
metadata=chunk.metadata,
|
||||
score=chunk.score,
|
||||
match_highlights=chunk.match_highlights,
|
||||
updated_at=chunk.updated_at,
|
||||
primary_owners=chunk.primary_owners,
|
||||
secondary_owners=chunk.secondary_owners,
|
||||
is_internet=is_internet,
|
||||
content=section.combined_content if include_content else None,
|
||||
)
|
||||
for section in sections
|
||||
]
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
class SearchFullResponse(BaseModel):
|
||||
all_executed_queries: list[str]
|
||||
search_docs: list[SearchDocWithContent]
|
||||
# Reasoning tokens output by the LLM for the document selection
|
||||
doc_selection_reasoning: str | None = None
|
||||
# This a list of document ids that are in the search_docs list
|
||||
llm_selected_doc_ids: list[str] | None = None
|
||||
# Error message if the search failed partway through
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class SearchQueryResponse(BaseModel):
|
||||
query: str
|
||||
query_expansions: list[str] | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class SearchHistoryResponse(BaseModel):
|
||||
search_queries: list[SearchQueryResponse]
|
||||
|
||||
170
backend/ee/onyx/server/query_and_chat/search_backend.py
Normal file
170
backend/ee/onyx/server/query_and_chat/search_backend.py
Normal file
@@ -0,0 +1,170 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.search import fetch_search_queries_for_user
|
||||
from ee.onyx.search.process_search_query import gather_search_stream
|
||||
from ee.onyx.search.process_search_query import stream_search_query
|
||||
from ee.onyx.secondary_llm_flows.search_flow_classification import (
|
||||
classify_is_search_flow,
|
||||
)
|
||||
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationRequest
|
||||
from ee.onyx.server.query_and_chat.models import SearchFlowClassificationResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchFullResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchHistoryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SearchQueryResponse
|
||||
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
|
||||
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llm
|
||||
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
|
||||
from onyx.server.utils import get_json_line
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/search")
|
||||
|
||||
|
||||
@router.post("/search-flow-classification")
|
||||
def search_flow_classification(
|
||||
request: SearchFlowClassificationRequest,
|
||||
# This is added just to ensure this endpoint isn't spammed by non-authorized users since there's an LLM call underneath it
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchFlowClassificationResponse:
|
||||
query = request.user_query
|
||||
# This is a heuristic that if the user is typing a lot of text, it's unlikely they're looking for some specific document
|
||||
# Most likely something needs to be done with the text included so we'll just classify it as a chat flow
|
||||
if len(query) > 200:
|
||||
return SearchFlowClassificationResponse(is_search_flow=False)
|
||||
|
||||
llm = get_default_llm()
|
||||
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=get_current_tenant_id(),
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
|
||||
try:
|
||||
is_search_flow = classify_is_search_flow(query=query, llm=llm)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Search flow classification failed; defaulting to chat flow",
|
||||
exc_info=e,
|
||||
)
|
||||
is_search_flow = False
|
||||
|
||||
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
|
||||
|
||||
|
||||
@router.post("/send-search-message", response_model=None)
|
||||
def handle_send_search_message(
|
||||
request: SendSearchQueryRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | SearchFullResponse:
|
||||
"""
|
||||
Execute a search query with optional streaming.
|
||||
|
||||
When stream=True: Returns StreamingResponse with SSE
|
||||
When stream=False: Returns SearchFullResponse
|
||||
"""
|
||||
logger.debug(f"Received search query: {request.search_query}")
|
||||
|
||||
# Non-streaming path
|
||||
if not request.stream:
|
||||
try:
|
||||
packets = stream_search_query(request, user, db_session)
|
||||
return gather_search_stream(packets)
|
||||
except NotImplementedError as e:
|
||||
return SearchFullResponse(
|
||||
all_executed_queries=[],
|
||||
search_docs=[],
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
# Streaming path
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as streaming_db_session:
|
||||
for packet in stream_search_query(request, user, streaming_db_session):
|
||||
yield get_json_line(packet.model_dump())
|
||||
except NotImplementedError as e:
|
||||
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in search streaming")
|
||||
yield get_json_line(SearchErrorPacket(error=str(e)).model_dump())
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/search-history")
|
||||
def get_search_history(
|
||||
limit: int = 100,
|
||||
filter_days: int | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchHistoryResponse:
|
||||
"""
|
||||
Fetch past search queries for the authenticated user.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of queries to return (default 100)
|
||||
filter_days: Only return queries from the last N days (optional)
|
||||
|
||||
Returns:
|
||||
SearchHistoryResponse with list of search queries, ordered by most recent first.
|
||||
"""
|
||||
# Validate limit
|
||||
if limit <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be greater than 0",
|
||||
)
|
||||
if limit > 1000:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="limit must be at most 1000",
|
||||
)
|
||||
|
||||
# Validate filter_days
|
||||
if filter_days is not None and filter_days <= 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="filter_days must be greater than 0",
|
||||
)
|
||||
|
||||
# TODO(yuhong) remove this
|
||||
if user is None:
|
||||
# Return empty list for unauthenticated users
|
||||
return SearchHistoryResponse(search_queries=[])
|
||||
|
||||
search_queries = fetch_search_queries_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
filter_days=filter_days,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return SearchHistoryResponse(
|
||||
search_queries=[
|
||||
SearchQueryResponse(
|
||||
query=sq.query,
|
||||
query_expansions=sq.query_expansions,
|
||||
created_at=sq.created_at,
|
||||
)
|
||||
for sq in search_queries
|
||||
]
|
||||
)
|
||||
35
backend/ee/onyx/server/query_and_chat/streaming_models.py
Normal file
35
backend/ee/onyx/server/query_and_chat/streaming_models.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
|
||||
|
||||
|
||||
class SearchQueriesPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["search_queries"] = "search_queries"
|
||||
all_executed_queries: list[str]
|
||||
|
||||
|
||||
class SearchDocsPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["search_docs"] = "search_docs"
|
||||
search_docs: list[SearchDocWithContent]
|
||||
|
||||
|
||||
class SearchErrorPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["search_error"] = "search_error"
|
||||
error: str
|
||||
|
||||
|
||||
class LLMSelectedDocsPacket(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
type: Literal["llm_selected_docs"] = "llm_selected_docs"
|
||||
# None if LLM selection failed, empty list if no docs selected, list of IDs otherwise
|
||||
llm_selected_doc_ids: list[str] | None
|
||||
@@ -32,6 +32,7 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
@@ -294,7 +295,7 @@ def list_all_query_history_exports(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/query-history/start-export")
|
||||
@router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS)
|
||||
def start_query_history_export(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -340,7 +341,7 @@ def start_query_history_export(
|
||||
return {"request_id": task_id}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/export-status")
|
||||
@router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS)
|
||||
def get_query_history_export_status(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
@@ -374,7 +375,7 @@ def get_query_history_export_status(
|
||||
return {"status": TaskStatus.SUCCESS}
|
||||
|
||||
|
||||
@router.get("/admin/query-history/download")
|
||||
@router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS)
|
||||
def download_query_history_csv(
|
||||
request_id: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
|
||||
0
backend/ee/onyx/server/settings/__init__.py
Normal file
0
backend/ee/onyx/server/settings/__init__.py
Normal file
54
backend/ee/onyx/server/settings/api.py
Normal file
54
backend/ee/onyx/server/settings/api.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Statuses that indicate a billing/license problem - propagate these to settings
|
||||
_GATED_STATUSES = frozenset(
|
||||
{
|
||||
ApplicationStatus.GATED_ACCESS,
|
||||
ApplicationStatus.GRACE_PERIOD,
|
||||
ApplicationStatus.PAYMENT_REMINDER,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
"""EE version: checks license status for self-hosted deployments.
|
||||
|
||||
For self-hosted, looks up license metadata and overrides application_status
|
||||
if the license is missing or indicates a problem (expired, grace period, etc.).
|
||||
|
||||
For multi-tenant (cloud), the settings already have the correct status
|
||||
from the control plane, so no override is needed.
|
||||
|
||||
If LICENSE_ENFORCEMENT_ENABLED is false, settings are returned unchanged,
|
||||
allowing the product to function normally without license checks.
|
||||
"""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
return settings
|
||||
|
||||
if MULTI_TENANT:
|
||||
return settings
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if metadata and metadata.status in _GATED_STATUSES:
|
||||
settings.application_status = metadata.status
|
||||
elif not metadata:
|
||||
# No license = gated access for self-hosted EE
|
||||
settings.application_status = ApplicationStatus.GATED_ACCESS
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
return settings
|
||||
133
backend/ee/onyx/server/tenant_usage_limits.py
Normal file
133
backend/ee/onyx/server/tenant_usage_limits.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tenant-specific usage limit overrides from the control plane (EE version)."""
|
||||
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides
|
||||
from onyx.server.usage_limits import NO_LIMIT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# In-memory storage for tenant overrides (populated at startup)
|
||||
_tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None
|
||||
_last_fetch_time: float = 0.0
|
||||
_FETCH_INTERVAL = 60 * 60 * 24 # 24 hours
|
||||
_ERROR_FETCH_INTERVAL = 30 * 60 # 30 minutes (if the last fetch failed)
|
||||
|
||||
|
||||
def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None:
|
||||
"""
|
||||
Fetch tenant-specific usage limit overrides from the control plane.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping tenant_id to their specific limit overrides.
|
||||
Returns empty dict on any error (falls back to defaults).
|
||||
"""
|
||||
try:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/usage-limit-overrides"
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
tenant_overrides = response.json()
|
||||
|
||||
# Parse each tenant's overrides
|
||||
result: dict[str, TenantUsageLimitOverrides] = {}
|
||||
for override_data in tenant_overrides:
|
||||
tenant_id = override_data["tenant_id"]
|
||||
try:
|
||||
result[tenant_id] = TenantUsageLimitOverrides(**override_data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}"
|
||||
)
|
||||
|
||||
return (
|
||||
result or None
|
||||
) # if empty dictionary, something went wrong and we shouldn't enforce limits
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing usage limit overrides: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def load_usage_limit_overrides() -> None:
|
||||
"""
|
||||
Load tenant usage limit overrides from the control plane.
|
||||
"""
|
||||
global _tenant_usage_limit_overrides
|
||||
global _last_fetch_time
|
||||
|
||||
logger.info("Loading tenant usage limit overrides from control plane...")
|
||||
overrides = fetch_usage_limit_overrides()
|
||||
|
||||
_last_fetch_time = time.time()
|
||||
|
||||
# use the new result if it exists, otherwise use the old result
|
||||
# (prevents us from updating to a failed fetch result)
|
||||
_tenant_usage_limit_overrides = overrides or _tenant_usage_limit_overrides
|
||||
|
||||
if overrides:
|
||||
logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants")
|
||||
else:
|
||||
logger.info("No tenant-specific usage limit overrides found")
|
||||
|
||||
|
||||
def unlimited(tenant_id: str) -> TenantUsageLimitOverrides:
|
||||
return TenantUsageLimitOverrides(
|
||||
tenant_id=tenant_id,
|
||||
llm_cost_cents_trial=NO_LIMIT,
|
||||
llm_cost_cents_paid=NO_LIMIT,
|
||||
chunks_indexed_trial=NO_LIMIT,
|
||||
chunks_indexed_paid=NO_LIMIT,
|
||||
api_calls_trial=NO_LIMIT,
|
||||
api_calls_paid=NO_LIMIT,
|
||||
non_streaming_calls_trial=NO_LIMIT,
|
||||
non_streaming_calls_paid=NO_LIMIT,
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_usage_limit_overrides(
|
||||
tenant_id: str,
|
||||
) -> TenantUsageLimitOverrides | None:
|
||||
"""
|
||||
Get the usage limit overrides for a specific tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to look up
|
||||
|
||||
Returns:
|
||||
TenantUsageLimitOverrides if the tenant has overrides, None otherwise.
|
||||
"""
|
||||
|
||||
if DEV_MODE: # in dev mode, we return unlimited limits for all tenants
|
||||
return unlimited(tenant_id)
|
||||
|
||||
global _tenant_usage_limit_overrides
|
||||
time_since = time.time() - _last_fetch_time
|
||||
if (
|
||||
_tenant_usage_limit_overrides is None and time_since > _ERROR_FETCH_INTERVAL
|
||||
) or (time_since > _FETCH_INTERVAL):
|
||||
logger.debug(
|
||||
f"Last fetch time: {_last_fetch_time}, time since last fetch: {time_since}"
|
||||
)
|
||||
|
||||
load_usage_limit_overrides()
|
||||
|
||||
# If we have failed to fetch from the control plane or we're in dev mode, don't usage limit anyone.
|
||||
if _tenant_usage_limit_overrides is None or DEV_MODE:
|
||||
return unlimited(tenant_id)
|
||||
return _tenant_usage_limit_overrides.get(tenant_id)
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import cast
|
||||
from typing import Literal
|
||||
|
||||
import requests
|
||||
import stripe
|
||||
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
@@ -16,15 +16,21 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
def fetch_stripe_checkout_session(
|
||||
tenant_id: str,
|
||||
billing_period: Literal["monthly", "annual"] = "monthly",
|
||||
) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
"billing_period": billing_period,
|
||||
}
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
@@ -70,24 +76,46 @@ def fetch_billing_information(
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str:
|
||||
"""
|
||||
Fetch a Stripe customer portal session URL from the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session"
|
||||
payload = {"tenant_id": tenant_id}
|
||||
if return_url:
|
||||
payload["return_url"] = return_url
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["url"]
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
Update the number of seats for a tenant's subscription.
|
||||
Preserves the existing price (monthly, annual, or grandfathered).
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
response = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_subscription_id = cast(str, response.get("stripe_subscription_id"))
|
||||
|
||||
subscription = stripe.Subscription.retrieve(stripe_subscription_id)
|
||||
subscription_item = subscription["items"]["data"][0]
|
||||
|
||||
# Use existing price to preserve the customer's current plan
|
||||
current_price_id = subscription_item.price.id
|
||||
|
||||
updated_subscription = stripe.Subscription.modify(
|
||||
stripe_subscription_id,
|
||||
items=[
|
||||
{
|
||||
"id": subscription["items"]["data"][0].id,
|
||||
"price": STRIPE_PRICE_ID,
|
||||
"id": subscription_item.id,
|
||||
"price": current_price_id,
|
||||
"quantity": number_of_users,
|
||||
}
|
||||
],
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_customer_portal_session
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
@@ -23,7 +22,6 @@ from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
@@ -82,21 +80,17 @@ async def billing_information(
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Create a Stripe customer portal session via the control plane.
|
||||
NOTE: This is currently only used for multi-tenant (cloud) deployments.
|
||||
Self-hosted proxy endpoints will be added in a future phase.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
return_url = f"{WEB_DOMAIN}/admin/billing"
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"url": portal_url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -104,15 +98,18 @@ async def create_customer_portal_session(
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
request: CreateSubscriptionSessionRequest | None = None,
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -73,6 +74,12 @@ class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class CreateSubscriptionSessionRequest(BaseModel):
|
||||
"""Request to create a subscription checkout session."""
|
||||
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
@@ -65,3 +65,9 @@ def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
|
||||
|
||||
def is_tenant_gated(tenant_id: str) -> bool:
|
||||
"""Fast O(1) check if tenant is in gated set (multi-tenant only)."""
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id))
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
@@ -330,6 +331,14 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
is_auto_mode=True,
|
||||
)
|
||||
_upsert(openai_provider)
|
||||
|
||||
# Create default image generation config using the OpenAI API key
|
||||
try:
|
||||
create_default_image_gen_config_from_api_key(
|
||||
db_session, OPENAI_DEFAULT_API_KEY
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create default image gen config: {e}")
|
||||
else:
|
||||
logger.info(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
|
||||
@@ -9,6 +9,7 @@ from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
|
||||
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.token_limit import fetch_all_user_token_rate_limits
|
||||
@@ -17,7 +18,7 @@ from onyx.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from onyx.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
"""EE Usage limits - trial detection via billing information."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
@@ -31,13 +28,7 @@ def is_tenant_on_trial(tenant_id: str) -> bool:
|
||||
return True
|
||||
|
||||
if isinstance(billing_info, BillingInformation):
|
||||
# Check if trial is active
|
||||
if billing_info.trial_end is not None:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Trial active if trial_end is in the future
|
||||
# and subscription status indicates trialing
|
||||
if billing_info.trial_end > now and billing_info.status == "trialing":
|
||||
return True
|
||||
return billing_info.status == "trialing"
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
@@ -25,7 +26,7 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
|
||||
@@ -105,6 +105,8 @@ class DocExternalAccess:
|
||||
)
|
||||
|
||||
|
||||
# TODO(andrei): First refactor this into a pydantic model, then get rid of
|
||||
# duplicate fields.
|
||||
@dataclass(frozen=True, init=False)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Onyx users, None indicates admin
|
||||
|
||||
192
backend/onyx/auth/disposable_email_validator.py
Normal file
192
backend/onyx/auth/disposable_email_validator.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Utility to validate and block disposable/temporary email addresses.
|
||||
|
||||
This module fetches a list of known disposable email domains from a remote source
|
||||
and caches them for performance. It's used during user registration to prevent
|
||||
abuse from temporary email services.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from typing import Set
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.app_configs import DISPOSABLE_EMAIL_DOMAINS_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DisposableEmailValidator:
|
||||
"""
|
||||
Thread-safe singleton validator for disposable email domains.
|
||||
|
||||
Fetches and caches the list of disposable domains, with periodic refresh.
|
||||
"""
|
||||
|
||||
_instance: "DisposableEmailValidator | None" = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls) -> "DisposableEmailValidator":
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
# Check if already initialized using a try/except to avoid type issues
|
||||
try:
|
||||
if self._initialized:
|
||||
return
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self._domains: Set[str] = set()
|
||||
self._last_fetch_time: float = 0
|
||||
self._fetch_lock = threading.Lock()
|
||||
# Cache for 1 hour
|
||||
self._cache_duration = 3600
|
||||
# Hardcoded fallback list of common disposable domains
|
||||
# This ensures we block at least these even if the remote fetch fails
|
||||
self._fallback_domains = {
|
||||
"trashlify.com",
|
||||
"10minutemail.com",
|
||||
"guerrillamail.com",
|
||||
"mailinator.com",
|
||||
"tempmail.com",
|
||||
"throwaway.email",
|
||||
"yopmail.com",
|
||||
"temp-mail.org",
|
||||
"getnada.com",
|
||||
"maildrop.cc",
|
||||
}
|
||||
# Set initialized flag last to prevent race conditions
|
||||
self._initialized: bool = True
|
||||
|
||||
def _should_refresh(self) -> bool:
|
||||
"""Check if the cached domains should be refreshed."""
|
||||
return (time.time() - self._last_fetch_time) > self._cache_duration
|
||||
|
||||
def _fetch_domains(self) -> Set[str]:
|
||||
"""
|
||||
Fetch disposable email domains from the configured URL.
|
||||
|
||||
Returns:
|
||||
Set of domain strings (lowercased)
|
||||
"""
|
||||
if not DISPOSABLE_EMAIL_DOMAINS_URL:
|
||||
logger.debug("DISPOSABLE_EMAIL_DOMAINS_URL not configured")
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Fetching disposable email domains from {DISPOSABLE_EMAIL_DOMAINS_URL}"
|
||||
)
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get(DISPOSABLE_EMAIL_DOMAINS_URL)
|
||||
response.raise_for_status()
|
||||
|
||||
domains_list = response.json()
|
||||
|
||||
if not isinstance(domains_list, list):
|
||||
logger.error(
|
||||
f"Expected list from disposable domains URL, got {type(domains_list)}"
|
||||
)
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
# Convert all to lowercase and create set
|
||||
domains = {domain.lower().strip() for domain in domains_list if domain}
|
||||
|
||||
# Always include fallback domains
|
||||
domains.update(self._fallback_domains)
|
||||
|
||||
logger.info(
|
||||
f"Successfully fetched {len(domains)} disposable email domains"
|
||||
)
|
||||
return domains
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
logger.warning(f"Failed to fetch disposable domains (HTTP error): {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch disposable domains: {e}")
|
||||
|
||||
# On error, return fallback domains
|
||||
return self._fallback_domains.copy()
|
||||
|
||||
def get_domains(self) -> Set[str]:
|
||||
"""
|
||||
Get the cached set of disposable email domains.
|
||||
Refreshes the cache if needed.
|
||||
|
||||
Returns:
|
||||
Set of disposable domain strings (lowercased)
|
||||
"""
|
||||
# Fast path: return cached domains if still fresh
|
||||
if self._domains and not self._should_refresh():
|
||||
return self._domains.copy()
|
||||
|
||||
# Slow path: need to refresh
|
||||
with self._fetch_lock:
|
||||
# Double-check after acquiring lock
|
||||
if self._domains and not self._should_refresh():
|
||||
return self._domains.copy()
|
||||
|
||||
self._domains = self._fetch_domains()
|
||||
self._last_fetch_time = time.time()
|
||||
return self._domains.copy()
|
||||
|
||||
def is_disposable(self, email: str) -> bool:
|
||||
"""
|
||||
Check if an email address uses a disposable domain.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if the email domain is disposable, False otherwise
|
||||
"""
|
||||
if not email or "@" not in email:
|
||||
return False
|
||||
|
||||
parts = email.split("@")
|
||||
if len(parts) != 2 or not parts[0]: # Must have user@domain with non-empty user
|
||||
return False
|
||||
|
||||
domain = parts[1].lower().strip()
|
||||
if not domain: # Domain part must not be empty
|
||||
return False
|
||||
|
||||
disposable_domains = self.get_domains()
|
||||
return domain in disposable_domains
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_validator = DisposableEmailValidator()
|
||||
|
||||
|
||||
def is_disposable_email(email: str) -> bool:
|
||||
"""
|
||||
Check if an email address uses a disposable/temporary domain.
|
||||
|
||||
This is a convenience function that uses the global validator instance.
|
||||
|
||||
Args:
|
||||
email: The email address to check
|
||||
|
||||
Returns:
|
||||
True if the email uses a disposable domain, False otherwise
|
||||
"""
|
||||
return _validator.is_disposable(email)
|
||||
|
||||
|
||||
def refresh_disposable_domains() -> None:
|
||||
"""
|
||||
Force a refresh of the disposable domains list.
|
||||
|
||||
This can be called manually if you want to update the list
|
||||
without waiting for the cache to expire.
|
||||
"""
|
||||
_validator._last_fetch_time = 0
|
||||
_validator.get_domains()
|
||||
@@ -60,6 +60,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.disposable_email_validator import is_disposable_email
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
@@ -248,13 +249,23 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
|
||||
domain = email.split("@")[-1].lower()
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
if VALID_EMAIL_DOMAINS:
|
||||
if email.count("@") != 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
domain = email.split("@")[-1].lower()
|
||||
if domain not in VALID_EMAIL_DOMAINS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -322,6 +333,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.password, cast(schemas.UC, user_create)
|
||||
)
|
||||
|
||||
# Check for disposable emails BEFORE provisioning tenant
|
||||
# This prevents creating tenants for throwaway email addresses
|
||||
try:
|
||||
verify_email_domain(user_create.email)
|
||||
except HTTPException as e:
|
||||
# Log blocked disposable email attempts
|
||||
if (
|
||||
e.status_code == status.HTTP_400_BAD_REQUEST
|
||||
and "Disposable email" in str(e.detail)
|
||||
):
|
||||
domain = (
|
||||
user_create.email.split("@")[-1]
|
||||
if "@" in user_create.email
|
||||
else "unknown"
|
||||
)
|
||||
logger.warning(
|
||||
f"Blocked disposable email registration attempt: {domain}",
|
||||
extra={"email_domain": domain},
|
||||
)
|
||||
raise
|
||||
|
||||
user_count: int | None = None
|
||||
referral_source = (
|
||||
request.cookies.get("referral_source", None)
|
||||
@@ -343,8 +375,17 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
async with get_async_session_context_manager(tenant_id) as db_session:
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
# Check invite list based on deployment mode
|
||||
if MULTI_TENANT:
|
||||
# Multi-tenant: Only require invite for existing tenants
|
||||
# New tenant creation (first user) doesn't require an invite
|
||||
user_count = await get_user_count()
|
||||
if user_count > 0:
|
||||
# Tenant already has users - require invite for new users
|
||||
verify_email_is_invited(user_create.email)
|
||||
else:
|
||||
# Single-tenant: Check invite list (skips if SAML/OIDC or no list configured)
|
||||
verify_email_is_invited(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
|
||||
db_session, User, OAuthAccount
|
||||
|
||||
@@ -26,6 +26,7 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.background.celery.celery_utils import make_probe_path
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX
|
||||
from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_FOR_ONYX
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_sqlalchemy_engine
|
||||
@@ -515,6 +516,10 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for Vespa to become ready subject to a timeout.
|
||||
Raises WorkerShutdown if the timeout is reached."""
|
||||
|
||||
if ENABLE_OPENSEARCH_FOR_ONYX:
|
||||
# TODO(andrei): Do some similar liveness checking for OpenSearch.
|
||||
return
|
||||
|
||||
if not wait_for_vespa_with_timeout():
|
||||
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
|
||||
logger.error(msg)
|
||||
|
||||
@@ -124,6 +124,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.kg_processing",
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
# Light worker tasks
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -98,8 +98,5 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Ensure the user files indexing worker registers the doc_id migration task
|
||||
# TODO(subash): remove this once the doc_id migration is complete
|
||||
"onyx.background.celery.tasks.user_file_processing",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -2,9 +2,12 @@ import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from celery.schedules import crontab
|
||||
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -54,16 +57,6 @@ beat_task_templates: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "user-file-docid-migration",
|
||||
"task": OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.USER_FILES_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-kg-processing",
|
||||
"task": OnyxCeleryTask.CHECK_KG_PROCESSING,
|
||||
@@ -181,7 +174,26 @@ if AUTO_LLM_CONFIG_URL:
|
||||
"schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Add scheduled eval task if datasets are configured
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "scheduled-eval-pipeline",
|
||||
"task": OnyxCeleryTask.SCHEDULED_EVAL_TASK,
|
||||
# run every Sunday at midnight UTC
|
||||
"schedule": crontab(
|
||||
hour=0,
|
||||
minute=0,
|
||||
day_of_week=0,
|
||||
),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -72,15 +72,6 @@ def try_creating_docfetching_task(
|
||||
# Another indexing attempt is already running
|
||||
return None
|
||||
|
||||
# Determine which queue to use based on whether this is a user file
|
||||
# TODO: at the moment the indexing pipeline is
|
||||
# shared between user files and connectors
|
||||
queue = (
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING
|
||||
if cc_pair.is_user_file
|
||||
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
|
||||
)
|
||||
|
||||
# Use higher priority for first-time indexing to ensure new connectors
|
||||
# get processed before re-indexing of existing connectors
|
||||
has_successful_attempt = cc_pair.last_successful_index_time is not None
|
||||
@@ -99,7 +90,7 @@ def try_creating_docfetching_task(
|
||||
search_settings_id=search_settings.id,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
task_id=custom_task_id,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
@@ -40,9 +41,11 @@ from onyx.background.indexing.checkpointing_utils import (
|
||||
)
|
||||
from onyx.background.indexing.index_attempt_utils import cleanup_index_attempts
|
||||
from onyx.background.indexing.index_attempt_utils import get_old_index_attempts
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
@@ -59,11 +62,9 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_standard_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -86,7 +87,6 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.swap_index import check_and_perform_index_swap
|
||||
from onyx.db.usage import UsageLimitExceededError
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
@@ -540,12 +540,7 @@ def check_indexing_completion(
|
||||
]:
|
||||
# User file connectors must be paused on success
|
||||
# NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model
|
||||
# TODO: figure out why this doesn't pause connectors during swap
|
||||
cc_pair.status = (
|
||||
ConnectorCredentialPairStatus.PAUSED
|
||||
if cc_pair.is_user_file
|
||||
else ConnectorCredentialPairStatus.ACTIVE
|
||||
)
|
||||
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
|
||||
db_session.commit()
|
||||
|
||||
mt_cloud_telemetry(
|
||||
@@ -811,13 +806,8 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
db_session, active_cc_pairs_only=True
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session, search_settings_id=current_search_settings.id
|
||||
)
|
||||
)
|
||||
|
||||
primary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids
|
||||
primary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Get CC pairs for secondary search settings
|
||||
secondary_cc_pair_ids: list[int] = []
|
||||
@@ -833,30 +823,47 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
db_session, active_cc_pairs_only=not include_paused
|
||||
)
|
||||
)
|
||||
user_file_cc_pair_ids = (
|
||||
fetch_indexable_user_file_connector_credential_pair_ids(
|
||||
db_session, search_settings_id=secondary_search_settings.id
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids + user_file_cc_pair_ids
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Flag CC pairs in repeated error state for primary/current search settings
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for cc_pair_id in primary_cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
if is_in_repeated_error_state(
|
||||
cc_pair_id=cc_pair_id,
|
||||
search_settings_id=current_search_settings.id,
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# if already in repeated error state, don't do anything
|
||||
# this is important so that we don't keep pausing the connector
|
||||
# immediately upon a user un-pausing it to manually re-trigger and
|
||||
# recover.
|
||||
if (
|
||||
cc_pair
|
||||
and not cc_pair.in_repeated_error_state
|
||||
and is_in_repeated_error_state(
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=current_search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
):
|
||||
set_cc_pair_repeated_error_state(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
in_repeated_error_state=True,
|
||||
)
|
||||
# When entering repeated error state, also pause the connector
|
||||
# to prevent continued indexing retry attempts burning through embedding credits.
|
||||
# NOTE: only for Cloud, since most self-hosted users use self-hosted embedding
|
||||
# models. Also, they are more prone to repeated failures -> eventual success.
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
|
||||
# NOTE: At this point, we haven't done heavy checks on whether or not the CC pairs should actually be indexed
|
||||
# Heavy check, should_index(), is called in _kickoff_indexing_tasks
|
||||
@@ -1289,19 +1296,14 @@ def _check_chunk_usage_limit(tenant_id: str) -> None:
|
||||
if not USAGE_LIMITS_ENABLED:
|
||||
return
|
||||
|
||||
from onyx.db.usage import check_usage_limit
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.server.usage_limits import get_limit_for_usage_type
|
||||
from onyx.server.usage_limits import is_tenant_on_trial
|
||||
|
||||
is_trial = is_tenant_on_trial(tenant_id)
|
||||
limit = get_limit_for_usage_type(UsageType.CHUNKS_INDEXED, is_trial)
|
||||
from onyx.server.usage_limits import check_usage_and_raise
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
check_usage_limit(
|
||||
check_usage_and_raise(
|
||||
db_session=db_session,
|
||||
usage_type=UsageType.CHUNKS_INDEXED,
|
||||
limit=limit,
|
||||
tenant_id=tenant_id,
|
||||
pending_amount=0, # Just check current usage
|
||||
)
|
||||
|
||||
@@ -1321,7 +1323,7 @@ def _docprocessing_task(
|
||||
if USAGE_LIMITS_ENABLED:
|
||||
try:
|
||||
_check_chunk_usage_limit(tenant_id)
|
||||
except UsageLimitExceededError as e:
|
||||
except HTTPException as e:
|
||||
# Log the error and fail the indexing attempt
|
||||
task_logger.error(
|
||||
f"Chunk indexing usage limit exceeded for tenant {tenant_id}: {e}"
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine.time_utils import get_db_current_time
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
@@ -126,18 +125,9 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def is_in_repeated_error_state(
|
||||
cc_pair_id: int, search_settings_id: int, db_session: Session
|
||||
cc_pair: ConnectorCredentialPair, search_settings_id: int, db_session: Session
|
||||
) -> bool:
|
||||
"""Checks if the cc pair / search setting combination is in a repeated error state."""
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise RuntimeError(
|
||||
f"is_in_repeated_error_state - could not find cc_pair with id={cc_pair_id}"
|
||||
)
|
||||
|
||||
# if the connector doesn't have a refresh_freq, a single failed attempt is enough
|
||||
number_of_failed_attempts_in_a_row_needed = (
|
||||
NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE
|
||||
@@ -146,7 +136,7 @@ def is_in_repeated_error_state(
|
||||
)
|
||||
|
||||
most_recent_index_attempts = get_recent_attempts_for_cc_pair(
|
||||
cc_pair_id=cc_pair_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
search_settings_id=search_settings_id,
|
||||
limit=number_of_failed_attempts_in_a_row_needed,
|
||||
db_session=db_session,
|
||||
@@ -180,7 +170,7 @@ def should_index(
|
||||
db_session=db_session,
|
||||
)
|
||||
all_recent_errored = is_in_repeated_error_state(
|
||||
cc_pair_id=cc_pair.id,
|
||||
cc_pair=cc_pair,
|
||||
search_settings_id=search_settings_instance.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
@@ -33,3 +39,109 @@ def eval_run_task(
|
||||
except Exception:
|
||||
logger.error("Failed to run eval task")
|
||||
raise
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.SCHEDULED_EVAL_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT * 5, # Allow more time for multiple datasets
|
||||
bind=True,
|
||||
trail=False,
|
||||
)
|
||||
def scheduled_eval_task(self: Task, **kwargs: Any) -> None:
|
||||
"""
|
||||
Scheduled task to run evaluations on configured datasets.
|
||||
Runs weekly on Sunday at midnight UTC.
|
||||
|
||||
Configure via environment variables (with defaults):
|
||||
- SCHEDULED_EVAL_DATASET_NAMES: Comma-separated list of Braintrust dataset names
|
||||
- SCHEDULED_EVAL_PERMISSIONS_EMAIL: Email for search permissions (default: roshan@onyx.app)
|
||||
- SCHEDULED_EVAL_PROJECT: Braintrust project name
|
||||
"""
|
||||
if not BRAINTRUST_API_KEY:
|
||||
logger.error("BRAINTRUST_API_KEY is not configured, cannot run scheduled evals")
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_PROJECT:
|
||||
logger.error(
|
||||
"SCHEDULED_EVAL_PROJECT is not configured, cannot run scheduled evals"
|
||||
)
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_DATASET_NAMES:
|
||||
logger.info("No scheduled eval datasets configured, skipping")
|
||||
return
|
||||
|
||||
if not SCHEDULED_EVAL_PERMISSIONS_EMAIL:
|
||||
logger.error("SCHEDULED_EVAL_PERMISSIONS_EMAIL not configured")
|
||||
return
|
||||
|
||||
project_name = SCHEDULED_EVAL_PROJECT
|
||||
dataset_names = SCHEDULED_EVAL_DATASET_NAMES
|
||||
permissions_email = SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
|
||||
# Create a timestamp for the scheduled run
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
|
||||
logger.info(
|
||||
f"Starting scheduled eval pipeline for project '{project_name}' "
|
||||
f"with {len(dataset_names)} dataset(s): {dataset_names}"
|
||||
)
|
||||
|
||||
pipeline_start = datetime.now(timezone.utc)
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for dataset_name in dataset_names:
|
||||
start_time = datetime.now(timezone.utc)
|
||||
error_message: str | None = None
|
||||
success = False
|
||||
|
||||
# Create informative experiment name for scheduled runs
|
||||
experiment_name = f"{dataset_name} - {run_timestamp}"
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"Running scheduled eval for dataset: {dataset_name} "
|
||||
f"(project: {project_name})"
|
||||
)
|
||||
|
||||
configuration = EvalConfigurationOptions(
|
||||
search_permissions_email=permissions_email,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=project_name,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
|
||||
result = run_eval(
|
||||
configuration=configuration,
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
success = result.success
|
||||
logger.info(f"Completed eval for {dataset_name}: success={success}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run scheduled eval for {dataset_name}")
|
||||
error_message = str(e)
|
||||
success = False
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
|
||||
results.append(
|
||||
{
|
||||
"dataset_name": dataset_name,
|
||||
"success": success,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"error_message": error_message,
|
||||
}
|
||||
)
|
||||
|
||||
pipeline_end = datetime.now(timezone.utc)
|
||||
total_duration = (pipeline_end - pipeline_start).total_seconds()
|
||||
|
||||
passed_count = sum(1 for r in results if r["success"])
|
||||
logger.info(
|
||||
f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed "
|
||||
f"in {total_duration:.1f}s"
|
||||
)
|
||||
|
||||
@@ -5,6 +5,9 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -26,24 +29,9 @@ def check_for_auto_llm_updates(self: Task, *, tenant_id: str) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
fetch_llm_recommendations_from_github,
|
||||
)
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
# Fetch config from GitHub
|
||||
config = fetch_llm_recommendations_from_github()
|
||||
|
||||
if not config:
|
||||
task_logger.warning("Failed to fetch GitHub config")
|
||||
return None
|
||||
|
||||
# Sync to database
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
results = sync_llm_models_from_github(db_session, config)
|
||||
results = sync_llm_models_from_github(db_session)
|
||||
|
||||
if results:
|
||||
task_logger.info(f"Auto mode sync results: {results}")
|
||||
|
||||
@@ -886,9 +886,7 @@ def monitor_celery_queues_helper(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
|
||||
)
|
||||
n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery)
|
||||
n_user_files_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILES_INDEXING, r_celery
|
||||
)
|
||||
|
||||
n_user_file_processing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
@@ -924,7 +922,6 @@ def monitor_celery_queues_helper(
|
||||
f"docfetching_prefetched={len(n_docfetching_prefetched)} "
|
||||
f"docprocessing={n_docprocessing} "
|
||||
f"docprocessing_prefetched={len(n_docprocessing_prefetched)} "
|
||||
f"user_files_indexing={n_user_files_indexing} "
|
||||
f"user_file_processing={n_user_file_processing} "
|
||||
f"user_file_project_sync={n_user_file_project_sync} "
|
||||
f"user_file_delete={n_user_file_delete} "
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -19,11 +18,9 @@ from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -32,20 +29,13 @@ from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import FileRecord
|
||||
from onyx.db.models import SearchDoc
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.document_index.interfaces import VespaDocumentUserFields
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
from onyx.file_store.utils import user_file_id_to_plaintext_file_name
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter
|
||||
@@ -618,315 +608,3 @@ def process_single_user_file_project_sync(
|
||||
file_lock.release()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_legacy_user_file_doc_id(old_id: str) -> str:
|
||||
# Convert USER_FILE_CONNECTOR__<uuid> -> FILE_CONNECTOR__<uuid> for legacy values
|
||||
user_prefix = "USER_FILE_CONNECTOR__"
|
||||
file_prefix = "FILE_CONNECTOR__"
|
||||
if old_id.startswith(user_prefix):
|
||||
remainder = old_id[len(user_prefix) :]
|
||||
return file_prefix + remainder
|
||||
return old_id
|
||||
|
||||
|
||||
def update_legacy_plaintext_file_records() -> None:
|
||||
"""Migrate legacy plaintext cache objects from int-based keys to UUID-based
|
||||
keys. Copies each S3 object to its expected UUID key and updates DB.
|
||||
|
||||
Examples:
|
||||
- Old key: bucket/schema/plaintext_<int>
|
||||
- New key: bucket/schema/plaintext_<uuid>
|
||||
"""
|
||||
|
||||
task_logger.info("update_legacy_plaintext_file_records - Starting")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
store = get_default_file_store()
|
||||
|
||||
if not isinstance(store, S3BackedFileStore):
|
||||
task_logger.info(
|
||||
"update_legacy_plaintext_file_records - Skipping non-S3 store"
|
||||
)
|
||||
return
|
||||
|
||||
s3_client = store._get_s3_client()
|
||||
bucket_name = store._get_bucket_name()
|
||||
|
||||
# Select PLAINTEXT_CACHE records whose object_key ends with 'plaintext_' + non-hyphen chars
|
||||
# Example: 'some/path/plaintext_abc123' matches; '.../plaintext_foo-bar' does not
|
||||
plaintext_records: Sequence[FileRecord] = (
|
||||
db_session.execute(
|
||||
sa.select(FileRecord).where(
|
||||
FileRecord.file_origin == FileOrigin.PLAINTEXT_CACHE,
|
||||
FileRecord.object_key.op("~")(r"plaintext_[^-]+$"),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"update_legacy_plaintext_file_records - Found {len(plaintext_records)} plaintext records to update"
|
||||
)
|
||||
|
||||
normalized = 0
|
||||
for fr in plaintext_records:
|
||||
try:
|
||||
expected_key = store._get_s3_key(fr.file_id)
|
||||
if fr.object_key == expected_key:
|
||||
continue
|
||||
|
||||
if fr.bucket_name is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Bucket name is None")
|
||||
continue
|
||||
|
||||
if fr.object_key is None:
|
||||
task_logger.warning(f"id={fr.file_id} - Object key is None")
|
||||
continue
|
||||
|
||||
# Copy old object to new key
|
||||
copy_source = f"{fr.bucket_name}/{fr.object_key}"
|
||||
s3_client.copy_object(
|
||||
CopySource=copy_source,
|
||||
Bucket=bucket_name,
|
||||
Key=expected_key,
|
||||
MetadataDirective="COPY",
|
||||
)
|
||||
|
||||
# Delete old object (best-effort)
|
||||
try:
|
||||
s3_client.delete_object(Bucket=fr.bucket_name, Key=fr.object_key)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update DB record with new key
|
||||
fr.object_key = expected_key
|
||||
db_session.add(fr)
|
||||
normalized += 1
|
||||
except Exception as e:
|
||||
task_logger.warning(f"id={fr.file_id} - {e.__class__.__name__}")
|
||||
|
||||
if normalized:
|
||||
db_session.commit()
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task normalized {normalized} plaintext objects"
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.USER_FILE_DOCID_MIGRATION,
|
||||
ignore_result=True,
|
||||
bind=True,
|
||||
)
|
||||
def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Starting for tenant={tenant_id}"
|
||||
)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.USER_FILE_DOCID_MIGRATION_LOCK,
|
||||
timeout=CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Lock held, skipping tenant={tenant_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
updated_count = 0
|
||||
try:
|
||||
update_legacy_plaintext_file_records()
|
||||
# Track lock renewal
|
||||
last_lock_time = time.monotonic()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
active_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
search_settings=active_settings.primary,
|
||||
secondary_search_settings=active_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(document_index)
|
||||
|
||||
# Select user files with a legacy doc id that have not been migrated
|
||||
user_files = (
|
||||
db_session.execute(
|
||||
sa.select(UserFile).where(
|
||||
sa.and_(
|
||||
UserFile.document_id.is_not(None),
|
||||
UserFile.document_id_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(user_files)} user files to migrate"
|
||||
)
|
||||
|
||||
# Query all SearchDocs that need updating
|
||||
search_docs = (
|
||||
db_session.execute(
|
||||
sa.select(SearchDoc).where(
|
||||
SearchDoc.document_id.like("%FILE_CONNECTOR__%")
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Found {len(search_docs)} search docs to update"
|
||||
)
|
||||
|
||||
# Build a map of normalized doc IDs to SearchDocs
|
||||
search_doc_map: dict[str, list[SearchDoc]] = {}
|
||||
for sd in search_docs:
|
||||
doc_id = sd.document_id
|
||||
if search_doc_map.get(doc_id) is None:
|
||||
search_doc_map[doc_id] = []
|
||||
search_doc_map[doc_id].append(sd)
|
||||
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Built search doc map with {len(search_doc_map)} entries"
|
||||
)
|
||||
|
||||
ids_preview = list(search_doc_map.keys())[:5]
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - First few search_doc_map ids: {ids_preview if ids_preview else 'No ids found'}"
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - search_doc_map total items: "
|
||||
f"{sum(len(docs) for docs in search_doc_map.values())}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
# Periodically renew the Redis lock to prevent expiry mid-run
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT / 4
|
||||
):
|
||||
renewed = False
|
||||
try:
|
||||
# extend lock ttl to full timeout window
|
||||
lock.extend(CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT)
|
||||
renewed = True
|
||||
except Exception:
|
||||
# if extend fails, best-effort reacquire as a fallback
|
||||
try:
|
||||
lock.reacquire()
|
||||
renewed = True
|
||||
except Exception:
|
||||
renewed = False
|
||||
last_lock_time = current_time
|
||||
if not renewed or not lock.owned():
|
||||
task_logger.error(
|
||||
"user_file_docid_migration_task - Lost lock ownership or failed to renew; aborting for safety"
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
clean_old_doc_id = replace_invalid_doc_id_characters(
|
||||
user_file.document_id
|
||||
)
|
||||
normalized_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
clean_old_doc_id
|
||||
)
|
||||
user_project_ids = [project.id for project in user_file.projects]
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Migrating user file {user_file.id} with doc_id {normalized_doc_id}"
|
||||
)
|
||||
|
||||
index_name = active_settings.primary.index_name
|
||||
|
||||
# First find the chunks count using direct Vespa query
|
||||
selection = f"{index_name}.document_id=='{normalized_doc_id}'"
|
||||
|
||||
# Count all chunks for this document
|
||||
chunk_count = _get_document_chunk_count(
|
||||
index_name=index_name,
|
||||
selection=selection,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Found {chunk_count} chunks for document {normalized_doc_id}"
|
||||
)
|
||||
|
||||
# Now update Vespa chunks with the found chunk count using retry_index
|
||||
# WARNING: In the future this will error; we no longer want
|
||||
# to support changing document ID.
|
||||
# TODO(andrei): Delete soon.
|
||||
retry_index.update_single(
|
||||
doc_id=str(normalized_doc_id),
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
fields=VespaDocumentFields(document_id=str(user_file.id)),
|
||||
user_fields=VespaDocumentUserFields(
|
||||
user_projects=user_project_ids
|
||||
),
|
||||
)
|
||||
user_file.chunk_count = chunk_count
|
||||
|
||||
# Update the SearchDocs
|
||||
actual_doc_id = str(user_file.document_id)
|
||||
normalized_actual_doc_id = _normalize_legacy_user_file_doc_id(
|
||||
actual_doc_id
|
||||
)
|
||||
if (
|
||||
normalized_doc_id in search_doc_map
|
||||
or normalized_actual_doc_id in search_doc_map
|
||||
):
|
||||
to_update = (
|
||||
search_doc_map[normalized_doc_id]
|
||||
if normalized_doc_id in search_doc_map
|
||||
else search_doc_map[normalized_actual_doc_id]
|
||||
)
|
||||
task_logger.debug(
|
||||
f"user_file_docid_migration_task - Updating {len(to_update)} search docs for user file {user_file.id}"
|
||||
)
|
||||
for search_doc in to_update:
|
||||
search_doc.document_id = str(user_file.id)
|
||||
db_session.add(search_doc)
|
||||
|
||||
user_file.document_id_migrated = True
|
||||
db_session.add(user_file)
|
||||
db_session.commit()
|
||||
updated_count += 1
|
||||
except Exception as per_file_exc:
|
||||
# Rollback the current transaction and continue with the next file
|
||||
db_session.rollback()
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error migrating user file {user_file.id} - "
|
||||
f"{per_file_exc.__class__.__name__}"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Updated {updated_count} user files"
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"user_file_docid_migration_task - Completed for tenant={tenant_id} (updated={updated_count})"
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"user_file_docid_migration_task - Error during execution for tenant={tenant_id} "
|
||||
f"(updated={updated_count}) exception={e.__class__.__name__}"
|
||||
)
|
||||
return False
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
57
backend/onyx/chat/chat_processing_checker.py
Normal file
57
backend/onyx/chat/chat_processing_checker.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
|
||||
# Redis key prefixes for chat message processing
|
||||
PREFIX = "chatprocessing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence"
|
||||
FENCE_TTL = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_key(chat_session_id: UUID) -> str:
|
||||
"""
|
||||
Generate the Redis key for a chat session processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
|
||||
Returns:
|
||||
The fence key string (tenant_id is automatically added by the Redis client)
|
||||
"""
|
||||
return f"{FENCE_PREFIX}_{chat_session_id}"
|
||||
|
||||
|
||||
def set_processing_status(
|
||||
chat_session_id: UUID, redis_client: Redis, value: bool
|
||||
) -> None:
|
||||
"""
|
||||
Set or clear the fence for a chat session processing a message.
|
||||
|
||||
If the key exists, we are processing a message. If the key does not exist, we are not processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
value: True to set the fence, False to clear it
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
|
||||
if value:
|
||||
redis_client.set(fence_key, 0, ex=FENCE_TTL)
|
||||
else:
|
||||
redis_client.delete(fence_key)
|
||||
|
||||
|
||||
def is_chat_session_processing(chat_session_id: UUID, redis_client: Redis) -> bool:
|
||||
"""
|
||||
Check if the chat session is processing a message.
|
||||
|
||||
Args:
|
||||
chat_session_id: The UUID of the chat session
|
||||
redis_client: The Redis client to use
|
||||
|
||||
Returns:
|
||||
True if the chat session is processing a message, False otherwise
|
||||
"""
|
||||
fence_key = _get_fence_key(chat_session_id)
|
||||
return bool(redis_client.exists(fence_key))
|
||||
@@ -94,6 +94,7 @@ class ChatStateContainer:
|
||||
|
||||
def run_chat_loop_with_state_containers(
|
||||
func: Callable[..., None],
|
||||
completion_callback: Callable[[ChatStateContainer], None],
|
||||
is_connected: Callable[[], bool],
|
||||
emitter: Emitter,
|
||||
state_container: ChatStateContainer,
|
||||
@@ -196,3 +197,12 @@ def run_chat_loop_with_state_containers(
|
||||
# Skip waiting if user disconnected to exit quickly.
|
||||
if is_connected():
|
||||
wait_on_background(thread)
|
||||
try:
|
||||
completion_callback(state_container)
|
||||
except Exception as e:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=PacketException(type="error", exception=e),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -18,12 +18,10 @@ from onyx.background.celery.tasks.kg_processing.kg_indexing import (
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
@@ -48,13 +46,10 @@ from onyx.kg.models import KGException
|
||||
from onyx.kg.setup.kg_default_entity_definitions import (
|
||||
populate_missing_default_entity_types__commit,
|
||||
)
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
|
||||
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
|
||||
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
|
||||
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import CitationInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.custom_tool import (
|
||||
@@ -103,89 +98,6 @@ def create_chat_session_from_request(
|
||||
)
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
llm_override: LLMOverride | None = None,
|
||||
allowed_tool_ids: list[int] | None = None,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
onyxbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
llm_override=llm_override,
|
||||
allowed_tool_ids=allowed_tool_ids,
|
||||
forced_tool_ids=forced_tool_ids,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_history_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -247,31 +159,6 @@ def create_chat_history_chain(
|
||||
return mainline_messages
|
||||
|
||||
|
||||
def combine_message_chain(
|
||||
messages: list[ChatMessage],
|
||||
token_limit: int,
|
||||
msg_limit: int | None = None,
|
||||
) -> str:
|
||||
"""Used for secondary LLM flows that require the chat history,"""
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
if msg_limit is not None:
|
||||
messages = messages[-msg_limit:]
|
||||
|
||||
for message in cast(list[ChatMessage], reversed(messages)):
|
||||
message_token_count = message.token_count
|
||||
|
||||
if total_token_count + message_token_count > token_limit:
|
||||
break
|
||||
|
||||
role = message.message_type.value.upper()
|
||||
message_strs.insert(0, f"{role}:\n{message.message}")
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def reorganize_citations(
|
||||
answer: str, citations: list[CitationInfo]
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
@@ -412,7 +299,7 @@ def create_temporary_persona(
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
@@ -582,6 +469,71 @@ def load_all_chat_files(
|
||||
return files
|
||||
|
||||
|
||||
def convert_chat_history_basic(
|
||||
chat_history: list[ChatMessage],
|
||||
token_counter: Callable[[str], int],
|
||||
max_individual_message_tokens: int | None = None,
|
||||
max_total_tokens: int | None = None,
|
||||
) -> list[ChatMessageSimple]:
|
||||
"""Convert ChatMessage history to ChatMessageSimple format with no tool calls or files included.
|
||||
|
||||
Args:
|
||||
chat_history: List of ChatMessage objects to convert
|
||||
token_counter: Function to count tokens in a message string
|
||||
max_individual_message_tokens: If set, messages exceeding this number of tokens are dropped.
|
||||
If None, no messages are dropped based on individual token count.
|
||||
max_total_tokens: If set, maximum number of tokens allowed for the entire history.
|
||||
If None, the history is not trimmed based on total token count.
|
||||
|
||||
Returns:
|
||||
List of ChatMessageSimple objects
|
||||
"""
|
||||
# Defensive: treat a non-positive total budget as "no history".
|
||||
if max_total_tokens is not None and max_total_tokens <= 0:
|
||||
return []
|
||||
|
||||
# Convert only the core USER/ASSISTANT messages; omit files and tool calls.
|
||||
converted: list[ChatMessageSimple] = []
|
||||
for chat_message in chat_history:
|
||||
if chat_message.message_type not in (MessageType.USER, MessageType.ASSISTANT):
|
||||
continue
|
||||
|
||||
message = chat_message.message or ""
|
||||
token_count = getattr(chat_message, "token_count", None)
|
||||
if token_count is None:
|
||||
token_count = token_counter(message)
|
||||
|
||||
# Drop any single message that would dominate the context window.
|
||||
if (
|
||||
max_individual_message_tokens is not None
|
||||
and token_count > max_individual_message_tokens
|
||||
):
|
||||
continue
|
||||
|
||||
converted.append(
|
||||
ChatMessageSimple(
|
||||
message=message,
|
||||
token_count=token_count,
|
||||
message_type=chat_message.message_type,
|
||||
image_files=None,
|
||||
)
|
||||
)
|
||||
|
||||
if max_total_tokens is None:
|
||||
return converted
|
||||
|
||||
# Enforce a max total budget by keeping a contiguous suffix of the conversation.
|
||||
trimmed_reversed: list[ChatMessageSimple] = []
|
||||
total_tokens = 0
|
||||
for msg in reversed(converted):
|
||||
if total_tokens + msg.token_count > max_total_tokens:
|
||||
break
|
||||
trimmed_reversed.append(msg)
|
||||
total_tokens += msg.token_count
|
||||
|
||||
return list(reversed(trimmed_reversed))
|
||||
|
||||
|
||||
def convert_chat_history(
|
||||
chat_history: list[ChatMessage],
|
||||
files: list[ChatLoadedFile],
|
||||
|
||||
@@ -4,14 +4,15 @@ Dynamic Citation Processor for LLM Responses
|
||||
This module provides a citation processor that can:
|
||||
- Accept citation number to SearchDoc mappings dynamically
|
||||
- Process token streams from LLMs to extract citations
|
||||
- Optionally replace citation markers with formatted markdown links
|
||||
- Emit CitationInfo objects for detected citations (when replacing)
|
||||
- Track all seen citations regardless of replacement mode
|
||||
- Handle citations in three modes: REMOVE, KEEP_MARKERS, or HYPERLINK
|
||||
- Emit CitationInfo objects for detected citations (in HYPERLINK mode)
|
||||
- Track all seen citations regardless of mode
|
||||
- Maintain a list of cited documents in order of first citation
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from enum import Enum
|
||||
from typing import TypeAlias
|
||||
|
||||
from onyx.configs.chat_configs import STOP_STREAM_PAT
|
||||
@@ -23,6 +24,29 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CitationMode(Enum):
|
||||
"""Defines how citations should be handled in the output.
|
||||
|
||||
REMOVE: Citations are completely removed from output text.
|
||||
No CitationInfo objects are emitted.
|
||||
Use case: When you need to remove citations from the output if they are not shared with the user
|
||||
(e.g. in discord bot, public slack bot).
|
||||
|
||||
KEEP_MARKERS: Original citation markers like [1], [2] are preserved unchanged.
|
||||
No CitationInfo objects are emitted.
|
||||
Use case: When you need to track citations in research agent and later process
|
||||
them with collapse_citations() to renumber.
|
||||
|
||||
HYPERLINK: Citations are replaced with markdown links like [[1]](url).
|
||||
CitationInfo objects are emitted for UI tracking.
|
||||
Use case: Final reports shown to users with clickable links.
|
||||
"""
|
||||
|
||||
REMOVE = "remove"
|
||||
KEEP_MARKERS = "keep_markers"
|
||||
HYPERLINK = "hyperlink"
|
||||
|
||||
|
||||
CitationMapping: TypeAlias = dict[int, SearchDoc]
|
||||
|
||||
|
||||
@@ -48,29 +72,37 @@ class DynamicCitationProcessor:
|
||||
|
||||
This processor is designed for multi-turn conversations where the citation
|
||||
number to document mapping is provided externally. It processes streaming
|
||||
tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and based
|
||||
on the `replace_citation_tokens` setting:
|
||||
tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and handles
|
||||
them according to the configured CitationMode:
|
||||
|
||||
When replace_citation_tokens=True (default):
|
||||
CitationMode.HYPERLINK (default):
|
||||
1. Replaces citation markers with formatted markdown links (e.g., [[1]](url))
|
||||
2. Emits CitationInfo objects for tracking
|
||||
3. Maintains the order in which documents were first cited
|
||||
Use case: Final reports shown to users with clickable links.
|
||||
|
||||
When replace_citation_tokens=False:
|
||||
1. Preserves original citation markers in the output text
|
||||
CitationMode.KEEP_MARKERS:
|
||||
1. Preserves original citation markers like [1], [2] unchanged
|
||||
2. Does NOT emit CitationInfo objects
|
||||
3. Still tracks all seen citations via get_seen_citations()
|
||||
Use case: When citations need later processing (e.g., renumbering).
|
||||
|
||||
CitationMode.REMOVE:
|
||||
1. Removes citation markers entirely from the output text
|
||||
2. Does NOT emit CitationInfo objects
|
||||
3. Still tracks all seen citations via get_seen_citations()
|
||||
Use case: Research agent intermediate reports.
|
||||
|
||||
Features:
|
||||
- Accepts citation number → SearchDoc mapping via update_citation_mapping()
|
||||
- Configurable citation replacement behavior at initialization
|
||||
- Always tracks seen citations regardless of replacement mode
|
||||
- Configurable citation mode at initialization
|
||||
- Always tracks seen citations regardless of mode
|
||||
- Holds back tokens that might be partial citations
|
||||
- Maintains list of cited SearchDocs in order of first citation
|
||||
- Handles unicode bracket variants (【】, [])
|
||||
- Skips citation processing inside code blocks
|
||||
|
||||
Example (with citation replacement - default):
|
||||
Example (HYPERLINK mode - default):
|
||||
processor = DynamicCitationProcessor()
|
||||
|
||||
# Set up citation mapping
|
||||
@@ -87,8 +119,8 @@ class DynamicCitationProcessor:
|
||||
# Get cited documents at the end
|
||||
cited_docs = processor.get_cited_documents()
|
||||
|
||||
Example (without citation replacement):
|
||||
processor = DynamicCitationProcessor(replace_citation_tokens=False)
|
||||
Example (KEEP_MARKERS mode):
|
||||
processor = DynamicCitationProcessor(citation_mode=CitationMode.KEEP_MARKERS)
|
||||
processor.update_citation_mapping({1: search_doc1, 2: search_doc2})
|
||||
|
||||
# Process tokens from LLM
|
||||
@@ -99,26 +131,42 @@ class DynamicCitationProcessor:
|
||||
|
||||
# Get all seen citations after processing
|
||||
seen_citations = processor.get_seen_citations() # {1: search_doc1, ...}
|
||||
|
||||
Example (REMOVE mode):
|
||||
processor = DynamicCitationProcessor(citation_mode=CitationMode.REMOVE)
|
||||
processor.update_citation_mapping({1: search_doc1, 2: search_doc2})
|
||||
|
||||
# Process tokens - citations are removed but tracked
|
||||
for token in llm_stream:
|
||||
for result in processor.process_token(token):
|
||||
print(result) # Text without any citation markers
|
||||
|
||||
# Citations are still tracked
|
||||
seen_citations = processor.get_seen_citations()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
replace_citation_tokens: bool = True,
|
||||
citation_mode: CitationMode = CitationMode.HYPERLINK,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
"""
|
||||
Initialize the citation processor.
|
||||
|
||||
Args:
|
||||
replace_citation_tokens: If True (default), citations like [1] are replaced
|
||||
with formatted markdown links like [[1]](url) and CitationInfo objects
|
||||
are emitted. If False, original citation text is preserved in output
|
||||
and no CitationInfo objects are emitted. Regardless of this setting,
|
||||
all seen citations are tracked and available via get_seen_citations().
|
||||
citation_mode: How to handle citations in the output. One of:
|
||||
- CitationMode.HYPERLINK (default): Replace [1] with [[1]](url)
|
||||
and emit CitationInfo objects.
|
||||
- CitationMode.KEEP_MARKERS: Keep original [1] markers unchanged,
|
||||
no CitationInfo objects emitted.
|
||||
- CitationMode.REMOVE: Remove citations entirely from output,
|
||||
no CitationInfo objects emitted.
|
||||
All modes track seen citations via get_seen_citations().
|
||||
stop_stream: Optional stop token pattern to halt processing early.
|
||||
When this pattern is detected in the token stream, processing stops.
|
||||
Defaults to STOP_STREAM_PAT from chat configs.
|
||||
"""
|
||||
|
||||
# Citation mapping from citation number to SearchDoc
|
||||
self.citation_to_doc: CitationMapping = {}
|
||||
self.seen_citations: CitationMapping = {} # citation num -> SearchDoc
|
||||
@@ -128,7 +176,7 @@ class DynamicCitationProcessor:
|
||||
self.curr_segment = "" # tokens held for citation processing
|
||||
self.hold = "" # tokens held for stop token processing
|
||||
self.stop_stream = stop_stream
|
||||
self.replace_citation_tokens = replace_citation_tokens
|
||||
self.citation_mode = citation_mode
|
||||
|
||||
# Citation tracking
|
||||
self.cited_documents_in_order: list[SearchDoc] = (
|
||||
@@ -199,19 +247,21 @@ class DynamicCitationProcessor:
|
||||
5. Handles stop tokens
|
||||
6. Always tracks seen citations in self.seen_citations
|
||||
|
||||
Behavior depends on the `replace_citation_tokens` setting from __init__:
|
||||
- If True: Citations are replaced with [[n]](url) format and CitationInfo
|
||||
Behavior depends on the `citation_mode` setting from __init__:
|
||||
- HYPERLINK: Citations are replaced with [[n]](url) format and CitationInfo
|
||||
objects are yielded before each formatted citation
|
||||
- If False: Original citation text (e.g., [1]) is preserved in output
|
||||
and no CitationInfo objects are yielded
|
||||
- KEEP_MARKERS: Original citation markers like [1] are preserved unchanged,
|
||||
no CitationInfo objects are yielded
|
||||
- REMOVE: Citations are removed entirely from output,
|
||||
no CitationInfo objects are yielded
|
||||
|
||||
Args:
|
||||
token: The next token from the LLM stream, or None to signal end of stream.
|
||||
Pass None to flush any remaining buffered text at end of stream.
|
||||
|
||||
Yields:
|
||||
str: Text chunks to display. Citation format depends on replace_citation_tokens.
|
||||
CitationInfo: Citation metadata (only when replace_citation_tokens=True)
|
||||
str: Text chunks to display. Citation format depends on citation_mode.
|
||||
CitationInfo: Citation metadata (only when citation_mode=HYPERLINK)
|
||||
"""
|
||||
# None -> end of stream, flush remaining segment
|
||||
if token is None:
|
||||
@@ -299,17 +349,17 @@ class DynamicCitationProcessor:
|
||||
if self.non_citation_count > 5:
|
||||
self.recent_cited_documents.clear()
|
||||
|
||||
# Yield text before citation FIRST (preserve order)
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
|
||||
# Process the citation (returns formatted citation text and CitationInfo objects)
|
||||
# Always tracks seen citations regardless of strip_citations flag
|
||||
# Always tracks seen citations regardless of citation_mode
|
||||
citation_text, citation_info_list = self._process_citation(
|
||||
match, has_leading_space, self.replace_citation_tokens
|
||||
match, has_leading_space
|
||||
)
|
||||
|
||||
if self.replace_citation_tokens:
|
||||
if self.citation_mode == CitationMode.HYPERLINK:
|
||||
# HYPERLINK mode: Replace citations with markdown links [[n]](url)
|
||||
# Yield text before citation FIRST (preserve order)
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
# Yield CitationInfo objects BEFORE the citation text
|
||||
# This allows the frontend to receive citation metadata before the token
|
||||
# that contains [[n]](link), enabling immediate rendering
|
||||
@@ -318,10 +368,34 @@ class DynamicCitationProcessor:
|
||||
# Then yield the formatted citation text
|
||||
if citation_text:
|
||||
yield citation_text
|
||||
else:
|
||||
# When not stripping, yield the original citation text unchanged
|
||||
|
||||
elif self.citation_mode == CitationMode.KEEP_MARKERS:
|
||||
# KEEP_MARKERS mode: Preserve original citation markers unchanged
|
||||
# Yield text before citation
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
# Yield the original citation marker as-is
|
||||
yield match.group()
|
||||
|
||||
else: # CitationMode.REMOVE
|
||||
# REMOVE mode: Remove citations entirely from output
|
||||
# This strips citation markers like [1], [2], 【1】 from the output text
|
||||
# When removing citations, we need to handle spacing to avoid issues like:
|
||||
# - "text [1] more" -> "text more" (double space)
|
||||
# - "text [1]." -> "text ." (space before punctuation)
|
||||
if intermatch_str:
|
||||
remaining_text = self.curr_segment[match_span[1] :]
|
||||
# Strip trailing space from intermatch if:
|
||||
# 1. Remaining text starts with space (avoids double space)
|
||||
# 2. Remaining text starts with punctuation (avoids space before punctuation)
|
||||
if intermatch_str[-1].isspace() and remaining_text:
|
||||
first_char = remaining_text[0]
|
||||
# Check if next char is space or common punctuation
|
||||
if first_char.isspace() or first_char in ".,;:!?)]}":
|
||||
intermatch_str = intermatch_str.rstrip()
|
||||
if intermatch_str:
|
||||
yield intermatch_str
|
||||
|
||||
self.non_citation_count = 0
|
||||
|
||||
# Leftover text could be part of next citation
|
||||
@@ -338,7 +412,7 @@ class DynamicCitationProcessor:
|
||||
yield result
|
||||
|
||||
def _process_citation(
|
||||
self, match: re.Match, has_leading_space: bool, replace_tokens: bool = True
|
||||
self, match: re.Match, has_leading_space: bool
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
"""
|
||||
Process a single citation match and return formatted citation text and citation info objects.
|
||||
@@ -349,31 +423,28 @@ class DynamicCitationProcessor:
|
||||
This method always:
|
||||
1. Extracts citation numbers from the match
|
||||
2. Looks up the corresponding SearchDoc from the mapping
|
||||
3. Tracks seen citations in self.seen_citations (regardless of replace_tokens)
|
||||
3. Tracks seen citations in self.seen_citations (regardless of citation_mode)
|
||||
|
||||
When replace_tokens=True (controlled by self.replace_citation_tokens):
|
||||
When citation_mode is HYPERLINK:
|
||||
4. Creates formatted citation text as [[n]](url)
|
||||
5. Creates CitationInfo objects for new citations
|
||||
6. Handles deduplication of recently cited documents
|
||||
|
||||
When replace_tokens=False:
|
||||
4. Returns empty string and empty list (caller yields original match text)
|
||||
When citation_mode is REMOVE or KEEP_MARKERS:
|
||||
4. Returns empty string and empty list (caller handles output based on mode)
|
||||
|
||||
Args:
|
||||
match: Regex match object containing the citation pattern
|
||||
has_leading_space: Whether the text immediately before this citation
|
||||
ends with whitespace. Used to determine if a leading space should
|
||||
be added to the formatted output.
|
||||
replace_tokens: If True, return formatted text and CitationInfo objects.
|
||||
If False, only track seen citations and return empty results.
|
||||
This is passed from self.replace_citation_tokens by the caller.
|
||||
|
||||
Returns:
|
||||
Tuple of (formatted_citation_text, citation_info_list):
|
||||
- formatted_citation_text: Markdown-formatted citation text like
|
||||
"[[1]](https://example.com)" or empty string if replace_tokens=False
|
||||
"[[1]](https://example.com)" or empty string if not in HYPERLINK mode
|
||||
- citation_info_list: List of CitationInfo objects for newly cited
|
||||
documents, or empty list if replace_tokens=False
|
||||
documents, or empty list if not in HYPERLINK mode
|
||||
"""
|
||||
citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】'
|
||||
formatted = (
|
||||
@@ -411,11 +482,11 @@ class DynamicCitationProcessor:
|
||||
doc_id = search_doc.document_id
|
||||
link = search_doc.link or ""
|
||||
|
||||
# Always track seen citations regardless of replace_tokens setting
|
||||
# Always track seen citations regardless of citation_mode setting
|
||||
self.seen_citations[num] = search_doc
|
||||
|
||||
# When not replacing citation tokens, skip the rest of the processing
|
||||
if not replace_tokens:
|
||||
# Only generate formatted citations and CitationInfo in HYPERLINK mode
|
||||
if self.citation_mode != CitationMode.HYPERLINK:
|
||||
continue
|
||||
|
||||
# Format the citation text as [[n]](link)
|
||||
@@ -450,14 +521,14 @@ class DynamicCitationProcessor:
|
||||
"""
|
||||
Get the list of cited SearchDoc objects in the order they were first cited.
|
||||
|
||||
Note: This list is only populated when `replace_citation_tokens=True`.
|
||||
When `replace_citation_tokens=False`, this will return an empty list.
|
||||
Note: This list is only populated when `citation_mode=HYPERLINK`.
|
||||
When using REMOVE or KEEP_MARKERS mode, this will return an empty list.
|
||||
Use get_seen_citations() instead if you need to track citations without
|
||||
replacing them.
|
||||
emitting CitationInfo objects.
|
||||
|
||||
Returns:
|
||||
List of SearchDoc objects in the order they were first cited.
|
||||
Empty list if replace_citation_tokens=False.
|
||||
Empty list if citation_mode is not HYPERLINK.
|
||||
"""
|
||||
return self.cited_documents_in_order
|
||||
|
||||
@@ -465,14 +536,14 @@ class DynamicCitationProcessor:
|
||||
"""
|
||||
Get the list of cited document IDs in the order they were first cited.
|
||||
|
||||
Note: This list is only populated when `replace_citation_tokens=True`.
|
||||
When `replace_citation_tokens=False`, this will return an empty list.
|
||||
Note: This list is only populated when `citation_mode=HYPERLINK`.
|
||||
When using REMOVE or KEEP_MARKERS mode, this will return an empty list.
|
||||
Use get_seen_citations() instead if you need to track citations without
|
||||
replacing them.
|
||||
emitting CitationInfo objects.
|
||||
|
||||
Returns:
|
||||
List of document IDs (strings) in the order they were first cited.
|
||||
Empty list if replace_citation_tokens=False.
|
||||
Empty list if citation_mode is not HYPERLINK.
|
||||
"""
|
||||
return [doc.document_id for doc in self.cited_documents_in_order]
|
||||
|
||||
@@ -481,12 +552,12 @@ class DynamicCitationProcessor:
|
||||
Get all seen citations as a mapping from citation number to SearchDoc.
|
||||
|
||||
This returns all citations that have been encountered during processing,
|
||||
regardless of the `replace_citation_tokens` setting. Citations are tracked
|
||||
regardless of the `citation_mode` setting. Citations are tracked
|
||||
whenever they are parsed, making this useful for cases where you need to
|
||||
know which citations appeared in the text without replacing them.
|
||||
know which citations appeared in the text without emitting CitationInfo objects.
|
||||
|
||||
This is particularly useful when `replace_citation_tokens=False`, as
|
||||
get_cited_documents() will be empty in that case, but get_seen_citations()
|
||||
This is particularly useful when using REMOVE or KEEP_MARKERS mode, as
|
||||
get_cited_documents() will be empty in those cases, but get_seen_citations()
|
||||
will still contain all the citations that were found.
|
||||
|
||||
Returns:
|
||||
@@ -501,13 +572,13 @@ class DynamicCitationProcessor:
|
||||
"""
|
||||
Get the number of unique documents that have been cited.
|
||||
|
||||
Note: This count is only updated when `replace_citation_tokens=True`.
|
||||
When `replace_citation_tokens=False`, this will always return 0.
|
||||
Note: This count is only updated when `citation_mode=HYPERLINK`.
|
||||
When using REMOVE or KEEP_MARKERS mode, this will always return 0.
|
||||
Use len(get_seen_citations()) instead if you need to count citations
|
||||
without replacing them.
|
||||
without emitting CitationInfo objects.
|
||||
|
||||
Returns:
|
||||
Number of unique documents cited. 0 if replace_citation_tokens=False.
|
||||
Number of unique documents cited. 0 if citation_mode is not HYPERLINK.
|
||||
"""
|
||||
return len(self.cited_document_ids)
|
||||
|
||||
@@ -519,9 +590,9 @@ class DynamicCitationProcessor:
|
||||
CitationInfo objects for the same document when it's cited multiple times
|
||||
in close succession. This method clears that tracker.
|
||||
|
||||
This is primarily useful when `replace_citation_tokens=True` to allow
|
||||
This is primarily useful when `citation_mode=HYPERLINK` to allow
|
||||
previously cited documents to emit CitationInfo objects again. Has no
|
||||
effect when `replace_citation_tokens=False`.
|
||||
effect when using REMOVE or KEEP_MARKERS mode.
|
||||
|
||||
The recent citation tracker is also automatically cleared when more than
|
||||
5 non-citation characters are processed between citations.
|
||||
|
||||
@@ -5,9 +5,11 @@ from sqlalchemy.orm import Session
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_utils import create_tool_call_failure_messages
|
||||
from onyx.chat.citation_processor import CitationMapping
|
||||
from onyx.chat.citation_processor import CitationMode
|
||||
from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.llm_step import extract_tool_calls_from_response_text
|
||||
from onyx.chat.llm_step import run_llm_step
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import ExtractedProjectFiles
|
||||
@@ -37,6 +39,7 @@ from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ToolCallInfo
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.images.models import (
|
||||
FinalImageGenerationResponse,
|
||||
@@ -50,6 +53,78 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
fallback_extraction_attempted: bool,
|
||||
tool_defs: list[dict],
|
||||
turn_index: int,
|
||||
) -> tuple[LlmStepResult, bool]:
|
||||
"""Attempt to extract tool calls from response text as a fallback.
|
||||
|
||||
This is a last resort fallback for low quality LLMs or those that don't have
|
||||
tool calling from the serving layer. Also triggers if there's reasoning but
|
||||
no answer and no tool calls.
|
||||
|
||||
Args:
|
||||
llm_step_result: The result from the LLM step
|
||||
tool_choice: The tool choice option used for this step
|
||||
fallback_extraction_attempted: Whether fallback extraction was already attempted
|
||||
tool_defs: List of tool definitions
|
||||
turn_index: The current turn index for placement
|
||||
|
||||
Returns:
|
||||
Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call)
|
||||
"""
|
||||
if fallback_extraction_attempted:
|
||||
return llm_step_result, False
|
||||
|
||||
no_tool_calls = (
|
||||
not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0
|
||||
)
|
||||
reasoning_but_no_answer_or_tools = (
|
||||
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
|
||||
)
|
||||
should_try_fallback = (
|
||||
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
|
||||
) or reasoning_but_no_answer_or_tools
|
||||
|
||||
if not should_try_fallback:
|
||||
return llm_step_result, False
|
||||
|
||||
# Try to extract from answer first, then fall back to reasoning
|
||||
extracted_tool_calls: list[ToolCallKickoff] = []
|
||||
if llm_step_result.answer:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if not extracted_tool_calls and llm_step_result.reasoning:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.reasoning,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=llm_step_result.reasoning,
|
||||
answer=llm_step_result.answer,
|
||||
tool_calls=extracted_tool_calls,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
return llm_step_result, True
|
||||
|
||||
|
||||
# Hardcoded oppinionated value, might breaks down to something like:
|
||||
# Cycle 1: Calls web_search for something
|
||||
# Cycle 2: Calls open_url for some results
|
||||
@@ -125,6 +200,9 @@ def construct_message_history(
|
||||
if history_token_budget < 0:
|
||||
raise ValueError("Not enough tokens available to construct message history")
|
||||
|
||||
if system_prompt:
|
||||
system_prompt.should_cache = True
|
||||
|
||||
# If no history, build minimal context
|
||||
if not simple_chat_history:
|
||||
result = [system_prompt] if system_prompt else []
|
||||
@@ -199,6 +277,7 @@ def construct_message_history(
|
||||
|
||||
for msg in reversed(history_before_last_user):
|
||||
if current_token_count + msg.token_count <= remaining_budget:
|
||||
msg.should_cache = True
|
||||
truncated_history_before.insert(0, msg)
|
||||
current_token_count += msg.token_count
|
||||
else:
|
||||
@@ -293,6 +372,7 @@ def run_llm_loop(
|
||||
forced_tool_id: int | None = None,
|
||||
user_identity: LLMUserIdentity | None = None,
|
||||
chat_session_id: str | None = None,
|
||||
include_citations: bool = True,
|
||||
) -> None:
|
||||
with trace(
|
||||
"run_llm_loop",
|
||||
@@ -310,7 +390,13 @@ def run_llm_loop(
|
||||
initialize_litellm()
|
||||
|
||||
# Initialize citation processor for handling citations dynamically
|
||||
citation_processor = DynamicCitationProcessor()
|
||||
# When include_citations is True, use HYPERLINK mode to format citations as [[1]](url)
|
||||
# When include_citations is False, use REMOVE mode to strip citations from output
|
||||
citation_processor = DynamicCitationProcessor(
|
||||
citation_mode=(
|
||||
CitationMode.HYPERLINK if include_citations else CitationMode.REMOVE
|
||||
)
|
||||
)
|
||||
|
||||
# Add project file citation mappings if project files are present
|
||||
project_citation_mapping: CitationMapping = {}
|
||||
@@ -340,6 +426,7 @@ def run_llm_loop(
|
||||
ran_image_gen: bool = False
|
||||
just_ran_web_search: bool = False
|
||||
has_called_search_tool: bool = False
|
||||
fallback_extraction_attempted: bool = False
|
||||
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
|
||||
|
||||
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
|
||||
@@ -348,6 +435,7 @@ def run_llm_loop(
|
||||
|
||||
reasoning_cycles = 0
|
||||
for llm_cycle_count in range(MAX_LLM_CYCLES):
|
||||
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
|
||||
if forced_tool_id:
|
||||
# Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
|
||||
final_tools = [tool for tool in tools if tool.id == forced_tool_id]
|
||||
@@ -355,7 +443,7 @@ def run_llm_loop(
|
||||
raise ValueError(f"Tool {forced_tool_id} not found in tools")
|
||||
tool_choice = ToolChoiceOptions.REQUIRED
|
||||
forced_tool_id = None
|
||||
elif llm_cycle_count == MAX_LLM_CYCLES - 1 or ran_image_gen:
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
final_tools = []
|
||||
@@ -422,7 +510,7 @@ def run_llm_loop(
|
||||
# This is to prevent it generating things like:
|
||||
# [Cute Cat](attachment://a_cute_cat_sitting_playfully.png)
|
||||
reminder_message_text = IMAGE_GEN_REMINDER
|
||||
elif just_ran_web_search:
|
||||
elif just_ran_web_search and not out_of_cycles:
|
||||
reminder_message_text = OPEN_URL_REMINDER
|
||||
else:
|
||||
# This is the default case, the LLM at this point may answer so it is important
|
||||
@@ -433,6 +521,7 @@ def run_llm_loop(
|
||||
),
|
||||
include_citation_reminder=should_cite_documents
|
||||
or always_cite_documents,
|
||||
is_last_cycle=out_of_cycles,
|
||||
)
|
||||
|
||||
reminder_msg = (
|
||||
@@ -456,10 +545,11 @@ def run_llm_loop(
|
||||
|
||||
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
|
||||
# It also pre-processes the tool calls in preparation for running them
|
||||
tool_defs = [tool.tool_definition() for tool in final_tools]
|
||||
llm_step_result, has_reasoned = run_llm_step(
|
||||
emitter=emitter,
|
||||
history=truncated_message_history,
|
||||
tool_definitions=[tool.tool_definition() for tool in final_tools],
|
||||
tool_definitions=tool_defs,
|
||||
tool_choice=tool_choice,
|
||||
llm=llm,
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
@@ -474,6 +564,19 @@ def run_llm_loop(
|
||||
if has_reasoned:
|
||||
reasoning_cycles += 1
|
||||
|
||||
# Fallback extraction for LLMs that don't support tool calling natively or are lower quality
|
||||
# and might incorrectly output tool calls in other channels
|
||||
llm_step_result, attempted = _try_fallback_tool_extraction(
|
||||
llm_step_result=llm_step_result,
|
||||
tool_choice=tool_choice,
|
||||
fallback_extraction_attempted=fallback_extraction_attempted,
|
||||
tool_defs=tool_defs,
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
)
|
||||
if attempted:
|
||||
# To prevent the case of excessive looping with bad models, we only allow one fallback attempt
|
||||
fallback_extraction_attempted = True
|
||||
|
||||
# Save citation mapping after each LLM step for incremental state updates
|
||||
state_container.set_citation_mapping(citation_processor.citation_to_doc)
|
||||
|
||||
@@ -499,7 +602,7 @@ def run_llm_loop(
|
||||
# in-flight citations
|
||||
# It can be cleaned up but not super trivial or worthwhile right now
|
||||
just_ran_web_search = False
|
||||
tool_responses, citation_mapping = run_tool_calls(
|
||||
parallel_tool_call_results = run_tool_calls(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
@@ -507,8 +610,11 @@ def run_llm_loop(
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
max_concurrent_tools=None,
|
||||
skip_search_query_expansion=has_called_search_tool,
|
||||
)
|
||||
tool_responses = parallel_tool_call_results.tool_responses
|
||||
citation_mapping = parallel_tool_call_results.updated_citation_mapping
|
||||
|
||||
# Failure case, give something reasonable to the LLM to try again
|
||||
if tool_calls and not tool_responses:
|
||||
@@ -563,6 +669,12 @@ def run_llm_loop(
|
||||
):
|
||||
generated_images = tool_response.rich_response.generated_images
|
||||
|
||||
saved_response = (
|
||||
tool_response.rich_response
|
||||
if isinstance(tool_response.rich_response, str)
|
||||
else tool_response.llm_facing_response
|
||||
)
|
||||
|
||||
tool_call_info = ToolCallInfo(
|
||||
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
|
||||
turn_index=llm_cycle_count + reasoning_cycles,
|
||||
@@ -572,7 +684,7 @@ def run_llm_loop(
|
||||
tool_id=tool.id,
|
||||
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
|
||||
tool_call_arguments=tool_call.tool_args,
|
||||
tool_call_response=tool_response.llm_facing_response,
|
||||
tool_call_response=saved_response,
|
||||
search_docs=search_docs,
|
||||
generated_images=generated_images,
|
||||
)
|
||||
@@ -628,7 +740,12 @@ def run_llm_loop(
|
||||
should_cite_documents = True
|
||||
|
||||
if not llm_step_result or not llm_step_result.answer:
|
||||
raise RuntimeError("LLM did not return an answer.")
|
||||
raise RuntimeError(
|
||||
"The LLM did not return an answer. "
|
||||
"Typically this is an issue with LLMs that do not support tool calling natively, "
|
||||
"or the model serving API is not configured correctly. "
|
||||
"This may also happen with models that are lower quality outputting invalid tool calls."
|
||||
)
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
@@ -17,6 +19,7 @@ from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.interfaces import LanguageModelInput
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.model_response import Delta
|
||||
@@ -31,6 +34,7 @@ from onyx.llm.models import TextContentPart
|
||||
from onyx.llm.models import ToolCall
|
||||
from onyx.llm.models import ToolMessage
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.prompt_cache.processor import process_with_prompt_cache
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -45,7 +49,7 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -134,12 +138,11 @@ def _format_message_history_for_logging(
|
||||
|
||||
separator = "================================================"
|
||||
|
||||
# Handle string input
|
||||
if isinstance(message_history, str):
|
||||
formatted_lines.append("Message [string]:")
|
||||
formatted_lines.append(separator)
|
||||
formatted_lines.append(f"{message_history}")
|
||||
return "\n".join(formatted_lines)
|
||||
# Handle single ChatCompletionMessage - wrap in list for uniform processing
|
||||
if isinstance(
|
||||
message_history, (SystemMessage, UserMessage, AssistantMessage, ToolMessage)
|
||||
):
|
||||
message_history = [message_history]
|
||||
|
||||
# Handle sequence of messages
|
||||
for i, msg in enumerate(message_history):
|
||||
@@ -209,7 +212,8 @@ def _update_tool_call_with_delta(
|
||||
|
||||
if index not in tool_calls_in_progress:
|
||||
tool_calls_in_progress[index] = {
|
||||
"id": None,
|
||||
# Fallback ID in case the provider never sends one via deltas.
|
||||
"id": f"fallback_{uuid.uuid4().hex}",
|
||||
"name": None,
|
||||
"arguments": "",
|
||||
}
|
||||
@@ -275,8 +279,147 @@ def _extract_tool_call_kickoffs(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def extract_tool_calls_from_response_text(
|
||||
response_text: str | None,
|
||||
tool_definitions: list[dict],
|
||||
placement: Placement,
|
||||
) -> list[ToolCallKickoff]:
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
tool_definitions: List of tool definitions to match against
|
||||
placement: Placement information for the tool calls
|
||||
|
||||
Returns:
|
||||
List of ToolCallKickoff objects for any matched tool calls
|
||||
"""
|
||||
if not response_text or not tool_definitions:
|
||||
return []
|
||||
|
||||
# Build a map of tool names to their definitions
|
||||
tool_name_to_def: dict[str, dict] = {}
|
||||
for tool_def in tool_definitions:
|
||||
if tool_def.get("type") == "function" and "function" in tool_def:
|
||||
func_def = tool_def["function"]
|
||||
tool_name = func_def.get("name")
|
||||
if tool_name:
|
||||
tool_name_to_def[tool_name] = func_def
|
||||
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
tab_index = 0
|
||||
|
||||
for json_obj in json_objects:
|
||||
matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def)
|
||||
if matched_tool_call:
|
||||
tool_name, tool_args = matched_tool_call
|
||||
tool_calls.append(
|
||||
ToolCallKickoff(
|
||||
tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}",
|
||||
tool_name=tool_name,
|
||||
tool_args=tool_args,
|
||||
placement=Placement(
|
||||
turn_index=placement.turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=placement.sub_turn_index,
|
||||
),
|
||||
)
|
||||
)
|
||||
tab_index += 1
|
||||
|
||||
logger.info(
|
||||
f"Extracted {len(tool_calls)} tool call(s) from response text as fallback"
|
||||
)
|
||||
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
"""Try to match a JSON object to a tool definition.
|
||||
|
||||
Supports several formats:
|
||||
1. Direct tool call format: {"name": "tool_name", "arguments": {...}}
|
||||
2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}}
|
||||
3. Tool name as key: {"tool_name": {...arguments...}}
|
||||
4. Arguments matching a tool's parameter schema
|
||||
|
||||
Args:
|
||||
json_obj: The JSON object to match
|
||||
tool_name_to_def: Map of tool names to their function definitions
|
||||
|
||||
Returns:
|
||||
Tuple of (tool_name, tool_args) if matched, None otherwise
|
||||
"""
|
||||
# Format 1: Direct tool call format {"name": "...", "arguments": {...}}
|
||||
if "name" in json_obj and json_obj["name"] in tool_name_to_def:
|
||||
tool_name = json_obj["name"]
|
||||
arguments = json_obj.get("arguments", json_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 2: Function call format {"function": {"name": "...", "arguments": {...}}}
|
||||
if "function" in json_obj and isinstance(json_obj["function"], dict):
|
||||
func_obj = json_obj["function"]
|
||||
if "name" in func_obj and func_obj["name"] in tool_name_to_def:
|
||||
tool_name = func_obj["name"]
|
||||
arguments = func_obj.get("arguments", func_obj.get("parameters", {}))
|
||||
if isinstance(arguments, str):
|
||||
try:
|
||||
arguments = json.loads(arguments)
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 3: Tool name as key {"tool_name": {...arguments...}}
|
||||
for tool_name in tool_name_to_def:
|
||||
if tool_name in json_obj:
|
||||
arguments = json_obj[tool_name]
|
||||
if isinstance(arguments, dict):
|
||||
return (tool_name, arguments)
|
||||
|
||||
# Format 4: Check if the JSON object matches a tool's parameter schema
|
||||
for tool_name, func_def in tool_name_to_def.items():
|
||||
params = func_def.get("parameters", {})
|
||||
properties = params.get("properties", {})
|
||||
required = params.get("required", [])
|
||||
|
||||
if not properties:
|
||||
continue
|
||||
|
||||
# Check if all required parameters are present (empty required = all optional)
|
||||
if all(req in json_obj for req in required):
|
||||
# Check if any of the tool's properties are in the JSON object
|
||||
matching_props = [prop for prop in properties if prop in json_obj]
|
||||
if matching_props:
|
||||
# Filter to only include known properties
|
||||
filtered_args = {k: v for k, v in json_obj.items() if k in properties}
|
||||
return (tool_name, filtered_args)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def translate_history_to_llm_format(
|
||||
history: list[ChatMessageSimple],
|
||||
llm_config: LLMConfig,
|
||||
) -> LanguageModelInput:
|
||||
"""Convert a list of ChatMessageSimple to LanguageModelInput format.
|
||||
|
||||
@@ -284,8 +427,23 @@ def translate_history_to_llm_format(
|
||||
handling different message types and image files for multimodal support.
|
||||
"""
|
||||
messages: list[ChatCompletionMessage] = []
|
||||
last_cacheable_msg_idx = -1
|
||||
all_previous_msgs_cacheable = True
|
||||
|
||||
for idx, msg in enumerate(history):
|
||||
# if the message is being added to the history
|
||||
if msg.message_type in [
|
||||
MessageType.SYSTEM,
|
||||
MessageType.USER,
|
||||
MessageType.ASSISTANT,
|
||||
MessageType.TOOL_CALL_RESPONSE,
|
||||
]:
|
||||
all_previous_msgs_cacheable = (
|
||||
all_previous_msgs_cacheable and msg.should_cache
|
||||
)
|
||||
if all_previous_msgs_cacheable:
|
||||
last_cacheable_msg_idx = idx
|
||||
|
||||
for msg in history:
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
system_msg = SystemMessage(
|
||||
role="system",
|
||||
@@ -395,7 +553,7 @@ def translate_history_to_llm_format(
|
||||
assistant_msg_with_tool = AssistantMessage(
|
||||
role="assistant",
|
||||
content=None, # The tool call is parsed, doesn't need to be duplicated in the content
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
messages.append(assistant_msg_with_tool)
|
||||
|
||||
@@ -417,6 +575,18 @@ def translate_history_to_llm_format(
|
||||
f"Unknown message type {msg.message_type} in history. Skipping message."
|
||||
)
|
||||
|
||||
# prompt caching: rely on should_cache in ChatMessageSimple to
|
||||
# pick the split point for the cacheable prefix and suffix
|
||||
if last_cacheable_msg_idx != -1:
|
||||
processed_messages, _ = process_with_prompt_cache(
|
||||
llm_config=llm_config,
|
||||
cacheable_prefix=messages[: last_cacheable_msg_idx + 1],
|
||||
suffix=messages[last_cacheable_msg_idx + 1 :],
|
||||
continuation=False,
|
||||
)
|
||||
assert isinstance(processed_messages, list) # for mypy
|
||||
messages = processed_messages
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
@@ -429,6 +599,10 @@ def _increment_turns(
|
||||
return turn_index, sub_turn_index + 1
|
||||
|
||||
|
||||
def _delta_has_action(delta: Delta) -> bool:
|
||||
return bool(delta.content or delta.reasoning_content or delta.tool_calls)
|
||||
|
||||
|
||||
def run_llm_step_pkt_generator(
|
||||
history: list[ChatMessageSimple],
|
||||
tool_definitions: list[dict],
|
||||
@@ -499,7 +673,7 @@ def run_llm_step_pkt_generator(
|
||||
tab_index = placement.tab_index
|
||||
sub_turn_index = placement.sub_turn_index
|
||||
|
||||
llm_msg_history = translate_history_to_llm_format(history)
|
||||
llm_msg_history = translate_history_to_llm_format(history, llm.config)
|
||||
has_reasoned = 0
|
||||
|
||||
# Uncomment the line below to log the entire message history to the console
|
||||
@@ -526,6 +700,8 @@ def run_llm_step_pkt_generator(
|
||||
span_generation.span_data.input = cast(
|
||||
Sequence[Mapping[str, Any]], llm_msg_history
|
||||
)
|
||||
stream_start_time = time.monotonic()
|
||||
first_action_recorded = False
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
@@ -546,6 +722,23 @@ def run_llm_step_pkt_generator(
|
||||
# Note: LLM cost tracking is now handled in multi_llm.py
|
||||
delta = packet.choice.delta
|
||||
|
||||
# Weird behavior from some model providers, just log and ignore for now
|
||||
if (
|
||||
delta.content is None
|
||||
and delta.reasoning_content is None
|
||||
and delta.tool_calls is None
|
||||
):
|
||||
logger.warning(
|
||||
f"LLM packet is empty (no contents, reasoning or tool calls). Skipping: {packet}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not first_action_recorded and _delta_has_action(delta):
|
||||
span_generation.span_data.time_to_first_action_seconds = (
|
||||
time.monotonic() - stream_start_time
|
||||
)
|
||||
first_action_recorded = True
|
||||
|
||||
if custom_token_processor:
|
||||
# The custom token processor can modify the deltas for specific custom logic
|
||||
# It can also return a state so that it can handle aggregated delta logic etc.
|
||||
@@ -703,6 +896,15 @@ def run_llm_step_pkt_generator(
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
if (
|
||||
not first_action_recorded
|
||||
and flush_delta is not None
|
||||
and _delta_has_action(flush_delta)
|
||||
):
|
||||
span_generation.span_data.time_to_first_action_seconds = (
|
||||
time.monotonic() - stream_start_time
|
||||
)
|
||||
first_action_recorded = True
|
||||
if flush_delta and flush_delta.tool_calls:
|
||||
for tool_call_delta in flush_delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
@@ -790,14 +992,14 @@ def run_llm_step_pkt_generator(
|
||||
logger.debug(f"Accumulated reasoning: {accumulated_reasoning}")
|
||||
logger.debug(f"Accumulated answer: {accumulated_answer}")
|
||||
|
||||
if tool_calls:
|
||||
tool_calls_str = "\n".join(
|
||||
f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
|
||||
for tc in tool_calls
|
||||
)
|
||||
logger.debug(f"Tool calls:\n{tool_calls_str}")
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
if tool_calls:
|
||||
tool_calls_str = "\n".join(
|
||||
f" - {tc.tool_name}: {json.dumps(tc.tool_args, indent=4)}"
|
||||
for tc in tool_calls
|
||||
)
|
||||
logger.debug(f"Tool calls:\n{tool_calls_str}")
|
||||
else:
|
||||
logger.debug("Tool calls: []")
|
||||
|
||||
return (
|
||||
LlmStepResult(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
@@ -8,10 +7,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
@@ -24,25 +20,6 @@ from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(BaseModel):
|
||||
top_documents: list[SearchDoc]
|
||||
rephrased_query: str | None = None
|
||||
predicted_flow: QueryFlow | None
|
||||
predicted_search: SearchType | None
|
||||
applied_source_filters: list[DocumentSource] | None
|
||||
applied_time_cutoff: datetime | None
|
||||
recency_bias_multiplier: float
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
initial_dict["applied_time_cutoff"] = (
|
||||
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
|
||||
)
|
||||
|
||||
return initial_dict
|
||||
|
||||
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
@@ -70,22 +47,11 @@ class UserKnowledgeFilePacket(BaseModel):
|
||||
user_files: list[FileDescriptor]
|
||||
|
||||
|
||||
class LLMRelevanceFilterResponse(BaseModel):
|
||||
llm_selected_doc_indices: list[int]
|
||||
|
||||
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
@@ -116,12 +82,6 @@ class OnyxAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
@@ -158,7 +118,6 @@ class PersonaOverrideConfig(BaseModel):
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
@@ -264,6 +223,12 @@ class ChatMessageSimple(BaseModel):
|
||||
image_files: list[ChatLoadedFile] | None = None
|
||||
# Only for TOOL_CALL_RESPONSE type messages
|
||||
tool_call_id: str | None = None
|
||||
# The last message for which this is true
|
||||
# AND is true for all previous messages
|
||||
# (counting from the start of the history)
|
||||
# represents the end of the cacheable prefix
|
||||
# used for prompt caching
|
||||
should_cache: bool = False
|
||||
|
||||
|
||||
class ProjectFileMetadata(BaseModel):
|
||||
|
||||
@@ -5,10 +5,13 @@ An overview can be found in the README.md file in this directory.
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.chat_processing_checker import set_processing_status
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.chat_state import run_chat_loop_with_state_containers
|
||||
from onyx.chat.chat_utils import convert_chat_history
|
||||
@@ -35,9 +38,10 @@ from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import CitationDocInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.chat import create_new_chat_message
|
||||
@@ -45,6 +49,9 @@ from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.memory import get_memories
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import get_user_files_from_project
|
||||
@@ -62,6 +69,7 @@ from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from onyx.server.query_and_chat.models import OptionalSearchSetting
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
|
||||
@@ -84,12 +92,20 @@ logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
class ToolCallException(Exception):
|
||||
"""Exception raised for errors during tool calls."""
|
||||
def _should_enable_slack_search(
|
||||
persona: Persona,
|
||||
filters: BaseFilters | None,
|
||||
) -> bool:
|
||||
"""Determine if Slack search should be enabled.
|
||||
|
||||
def __init__(self, message: str, tool_name: str | None = None):
|
||||
super().__init__(message)
|
||||
self.tool_name = tool_name
|
||||
Returns True if:
|
||||
- Source type filter exists and includes Slack, OR
|
||||
- Default persona with no source type filter
|
||||
"""
|
||||
source_types = filters.source_type if filters else None
|
||||
return (source_types is not None and DocumentSource.SLACK in source_types) or (
|
||||
persona.id == DEFAULT_PERSONA_ID and source_types is None
|
||||
)
|
||||
|
||||
|
||||
def _extract_project_file_texts_and_images(
|
||||
@@ -280,6 +296,7 @@ def handle_stream_message_objects(
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
# Additional context that should be included in the chat history, for example:
|
||||
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
|
||||
@@ -294,6 +311,8 @@ def handle_stream_message_objects(
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
redis_client: Redis | None = None
|
||||
|
||||
user_id = user.id if user is not None else None
|
||||
llm_user_identifier = (
|
||||
@@ -339,6 +358,23 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=(
|
||||
user.email
|
||||
if user and not getattr(user, "is_anonymous", False)
|
||||
else tenant_id
|
||||
),
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
},
|
||||
)
|
||||
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
@@ -380,7 +416,10 @@ def handle_stream_message_objects(
|
||||
if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE:
|
||||
# Auto-place after the latest message in the chain
|
||||
parent_message = chat_history[-1] if chat_history else root_message
|
||||
elif new_msg_req.parent_message_id is None:
|
||||
elif (
|
||||
new_msg_req.parent_message_id is None
|
||||
or new_msg_req.parent_message_id == root_message.id
|
||||
):
|
||||
# None = regeneration from root
|
||||
parent_message = root_message
|
||||
# Truncate history since we're starting from root
|
||||
@@ -480,11 +519,15 @@ def handle_stream_message_objects(
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
slack_context=slack_context,
|
||||
enable_slack_search=_should_enable_slack_search(
|
||||
persona, new_msg_req.internal_search_filters
|
||||
),
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session.id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
),
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
search_usage_forcing_setting=project_search_config.search_usage,
|
||||
@@ -536,10 +579,27 @@ def handle_stream_message_objects(
|
||||
def check_is_connected() -> bool:
|
||||
return check_stop_signal(chat_session.id, redis_client)
|
||||
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=True,
|
||||
)
|
||||
|
||||
# Use external state container if provided, otherwise create internal one
|
||||
# External container allows non-streaming callers to access accumulated state
|
||||
state_container = external_state_container or ChatStateContainer()
|
||||
|
||||
def llm_loop_completion_callback(
|
||||
state_container: ChatStateContainer,
|
||||
) -> None:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_container,
|
||||
db_session=db_session,
|
||||
chat_session_id=str(chat_session.id),
|
||||
is_connected=check_is_connected,
|
||||
assistant_message=assistant_response,
|
||||
)
|
||||
|
||||
# Run the LLM loop with explicit wrapper for stop signal handling
|
||||
# The wrapper runs run_llm_loop in a background thread and polls every 300ms
|
||||
# for stop signals. run_llm_loop itself doesn't know about stopping.
|
||||
@@ -555,6 +615,7 @@ def handle_stream_message_objects(
|
||||
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_deep_research_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected,
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -571,6 +632,7 @@ def handle_stream_message_objects(
|
||||
else:
|
||||
yield from run_chat_loop_with_state_containers(
|
||||
run_llm_loop,
|
||||
llm_loop_completion_callback,
|
||||
is_connected=check_is_connected, # Not passed through to run_llm_loop
|
||||
emitter=emitter,
|
||||
state_container=state_container,
|
||||
@@ -586,53 +648,9 @@ def handle_stream_message_objects(
|
||||
forced_tool_id=forced_tool_id,
|
||||
user_identity=user_identity,
|
||||
chat_session_id=str(chat_session.id),
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
|
||||
# Determine if stopped by user
|
||||
completed_normally = check_is_connected()
|
||||
if not completed_normally:
|
||||
logger.debug(f"Chat session {chat_session.id} stopped by user")
|
||||
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... The generation was stopped by the user here."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_response,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
@@ -650,15 +668,7 @@ def handle_stream_message_objects(
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
|
||||
if isinstance(e, ToolCallException):
|
||||
yield StreamingError(
|
||||
error=error_msg,
|
||||
stack_trace=stack_trace,
|
||||
error_code="TOOL_CALL_FAILED",
|
||||
is_retryable=True,
|
||||
details={"tool_name": e.tool_name} if e.tool_name else None,
|
||||
)
|
||||
elif llm:
|
||||
if llm:
|
||||
client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg(
|
||||
e, llm
|
||||
)
|
||||
@@ -690,7 +700,67 @@ def handle_stream_message_objects(
|
||||
)
|
||||
|
||||
db_session.rollback()
|
||||
return
|
||||
finally:
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
chat_session_id=chat_session.id,
|
||||
redis_client=redis_client,
|
||||
value=False,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
db_session: Session,
|
||||
chat_session_id: str,
|
||||
assistant_message: ChatMessage,
|
||||
) -> None:
|
||||
# Determine if stopped by user
|
||||
completed_normally = is_connected()
|
||||
# Build final answer based on completion status
|
||||
if completed_normally:
|
||||
if state_container.answer_tokens is None:
|
||||
raise RuntimeError(
|
||||
"LLM run completed normally but did not return an answer."
|
||||
)
|
||||
final_answer = state_container.answer_tokens
|
||||
else:
|
||||
# Stopped by user - append stop message
|
||||
logger.debug(f"Chat session {chat_session_id} stopped by user")
|
||||
if state_container.answer_tokens:
|
||||
final_answer = (
|
||||
state_container.answer_tokens
|
||||
+ " ... \n\nGeneration was stopped by the user."
|
||||
)
|
||||
else:
|
||||
final_answer = "The generation was stopped by the user."
|
||||
|
||||
# Build citation_docs_info from accumulated citations in state container
|
||||
citation_docs_info: list[CitationDocInfo] = []
|
||||
seen_citation_nums: set[int] = set()
|
||||
for citation_num, search_doc in state_container.citation_to_doc.items():
|
||||
if citation_num not in seen_citation_nums:
|
||||
seen_citation_nums.add(citation_num)
|
||||
citation_docs_info.append(
|
||||
CitationDocInfo(
|
||||
search_doc=search_doc,
|
||||
citation_number=citation_num,
|
||||
)
|
||||
)
|
||||
|
||||
save_chat_turn(
|
||||
message_text=final_answer,
|
||||
reasoning_tokens=state_container.reasoning_tokens,
|
||||
citation_docs_info=citation_docs_info,
|
||||
tool_calls=state_container.tool_calls,
|
||||
db_session=db_session,
|
||||
assistant_message=assistant_message,
|
||||
is_clarification=state_container.is_clarification,
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_message_objects(
|
||||
@@ -739,6 +809,8 @@ def stream_chat_message_objects(
|
||||
deep_research=new_msg_req.deep_research,
|
||||
parent_message_id=new_msg_req.parent_message_id,
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
origin=new_msg_req.origin,
|
||||
include_citations=new_msg_req.include_citations,
|
||||
)
|
||||
return handle_stream_message_objects(
|
||||
new_msg_req=translated_new_msg_req,
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.prompts.chat_prompts import CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import USER_INFO_HEADER
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
@@ -17,15 +18,18 @@ from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import MEMORY_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import OPEN_URLS_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@@ -115,8 +119,11 @@ def calculate_reserved_tokens(
|
||||
def build_reminder_message(
|
||||
reminder_text: str | None,
|
||||
include_citation_reminder: bool,
|
||||
is_last_cycle: bool,
|
||||
) -> str | None:
|
||||
reminder = reminder_text.strip() if reminder_text else ""
|
||||
if is_last_cycle:
|
||||
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
|
||||
if include_citation_reminder:
|
||||
reminder += "\n\n" + CITATION_REMINDER
|
||||
reminder = reminder.strip()
|
||||
@@ -169,10 +176,13 @@ def build_system_prompt(
|
||||
TOOL_SECTION_HEADER
|
||||
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
+ INTERNAL_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
return system_prompt
|
||||
|
||||
@@ -186,6 +196,7 @@ def build_system_prompt(
|
||||
has_generate_image = any(
|
||||
isinstance(tool, ImageGenerationTool) for tool in tools
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
@@ -195,7 +206,16 @@ def build_system_prompt(
|
||||
system_prompt += INTERNAL_SEARCH_GUIDANCE
|
||||
|
||||
if has_web_search or include_all_guidance:
|
||||
system_prompt += WEB_SEARCH_GUIDANCE
|
||||
site_disabled_guidance = ""
|
||||
if has_web_search:
|
||||
web_search_tool = next(
|
||||
(t for t in tools if isinstance(t, WebSearchTool)), None
|
||||
)
|
||||
if web_search_tool and not web_search_tool.supports_site_filter:
|
||||
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
system_prompt += WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=site_disabled_guidance
|
||||
)
|
||||
|
||||
if has_open_urls or include_all_guidance:
|
||||
system_prompt += OPEN_URLS_GUIDANCE
|
||||
@@ -206,4 +226,7 @@ def build_system_prompt(
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -22,6 +22,14 @@ APP_PORT = 8080
|
||||
# prefix from requests directed towards the API server. In these cases, set this to `/api`
|
||||
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
|
||||
|
||||
# Certain services need to make HTTP requests to the API server, such as the MCP server and Discord bot
|
||||
API_SERVER_PROTOCOL = os.environ.get("API_SERVER_PROTOCOL", "http")
|
||||
API_SERVER_HOST = os.environ.get("API_SERVER_HOST", "127.0.0.1")
|
||||
# This override allows self-hosting the MCP server with Onyx Cloud backend.
|
||||
API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS = os.environ.get(
|
||||
"API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS"
|
||||
)
|
||||
|
||||
# Whether to send user metadata (user_id/email and session_id) to the LLM provider.
|
||||
# Disabled by default.
|
||||
SEND_USER_METADATA_TO_LLM_PROVIDER = (
|
||||
@@ -120,6 +128,14 @@ VALID_EMAIL_DOMAINS = (
|
||||
if _VALID_EMAIL_DOMAINS_STR
|
||||
else []
|
||||
)
|
||||
|
||||
# Disposable email blocking - blocks temporary/throwaway email addresses
|
||||
# Set to empty string to disable disposable email blocking
|
||||
DISPOSABLE_EMAIL_DOMAINS_URL = os.environ.get(
|
||||
"DISPOSABLE_EMAIL_DOMAINS_URL",
|
||||
"https://disposable.github.io/disposable-email-domains/domains.json",
|
||||
)
|
||||
|
||||
# OAuth Login Flow
|
||||
# Used for both Google OAuth2 and OIDC flows
|
||||
OAUTH_CLIENT_ID = (
|
||||
@@ -192,6 +208,10 @@ OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 920
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
|
||||
ENABLE_OPENSEARCH_FOR_ONYX = (
|
||||
os.environ.get("ENABLE_OPENSEARCH_FOR_ONYX", "").lower() == "true"
|
||||
)
|
||||
|
||||
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
|
||||
# NOTE: this is used if and only if the vespa config server is accessible via a
|
||||
# different host than the main vespa application
|
||||
@@ -556,6 +576,7 @@ JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SLIM_PAGE_SIZE = int(os.environ.get("JIRA_SLIM_PAGE_SIZE", 500))
|
||||
|
||||
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
|
||||
|
||||
@@ -667,10 +688,6 @@ INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
|
||||
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 8
|
||||
)
|
||||
|
||||
# Maximum number of user file connector credential pairs to index in a single batch
|
||||
# Setting this number too high may overload the indexing process
|
||||
USER_FILE_INDEXING_LIMIT = int(os.environ.get("USER_FILE_INDEXING_LIMIT") or 100)
|
||||
|
||||
# Maximum file size in a document to be indexed
|
||||
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
|
||||
MAX_FILE_SIZE_BYTES = int(
|
||||
@@ -742,7 +759,27 @@ BRAINTRUST_PROJECT = os.environ.get("BRAINTRUST_PROJECT", "Onyx")
|
||||
# Braintrust API key - if provided, Braintrust tracing will be enabled
|
||||
BRAINTRUST_API_KEY = os.environ.get("BRAINTRUST_API_KEY") or ""
|
||||
# Maximum concurrency for Braintrust evaluations
|
||||
BRAINTRUST_MAX_CONCURRENCY = int(os.environ.get("BRAINTRUST_MAX_CONCURRENCY") or 5)
|
||||
# None means unlimited concurrency, otherwise specify a number
|
||||
_braintrust_concurrency = os.environ.get("BRAINTRUST_MAX_CONCURRENCY")
|
||||
BRAINTRUST_MAX_CONCURRENCY = (
|
||||
int(_braintrust_concurrency) if _braintrust_concurrency else None
|
||||
)
|
||||
|
||||
#####
|
||||
# Scheduled Evals Configuration
|
||||
#####
|
||||
# Comma-separated list of Braintrust dataset names to run on schedule
|
||||
SCHEDULED_EVAL_DATASET_NAMES = [
|
||||
name.strip()
|
||||
for name in os.environ.get("SCHEDULED_EVAL_DATASET_NAMES", "").split(",")
|
||||
if name.strip()
|
||||
]
|
||||
# Email address to use for search permissions during scheduled evals
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL = os.environ.get(
|
||||
"SCHEDULED_EVAL_PERMISSIONS_EMAIL", "roshan@onyx.app"
|
||||
)
|
||||
# Braintrust project name to use for scheduled evals
|
||||
SCHEDULED_EVAL_PROJECT = os.environ.get("SCHEDULED_EVAL_PROJECT", "st-dev")
|
||||
|
||||
#####
|
||||
# Langfuse Configuration
|
||||
@@ -800,6 +837,11 @@ ENTERPRISE_EDITION_ENABLED = (
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
|
||||
)
|
||||
|
||||
#####
|
||||
# Image Generation Configuration (DEPRECATED)
|
||||
# These environment variables will be deprecated soon.
|
||||
# To configure image generation, please visit the Image Generation page in the Admin Panel.
|
||||
#####
|
||||
# Azure Image Configurations
|
||||
AZURE_IMAGE_API_VERSION = os.environ.get("AZURE_IMAGE_API_VERSION") or os.environ.get(
|
||||
"AZURE_DALLE_API_VERSION"
|
||||
@@ -816,6 +858,7 @@ AZURE_IMAGE_DEPLOYMENT_NAME = os.environ.get(
|
||||
|
||||
# configurable image model
|
||||
IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")
|
||||
IMAGE_MODEL_PROVIDER = os.environ.get("IMAGE_MODEL_PROVIDER", "openai")
|
||||
|
||||
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
|
||||
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
@@ -962,3 +1005,14 @@ COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
|
||||
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
|
||||
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
|
||||
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")
|
||||
|
||||
INSTANCE_TYPE = (
|
||||
"managed"
|
||||
if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
|
||||
else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
|
||||
)
|
||||
|
||||
|
||||
## Discord Bot Configuration
|
||||
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
|
||||
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")
|
||||
|
||||
@@ -12,9 +12,6 @@ NUM_POSTPROCESSED_RESULTS = 20
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = int(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 25)
|
||||
|
||||
# Maximum percentage of the context window to fill with selected sections
|
||||
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
|
||||
|
||||
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
|
||||
# Capped in Vespa at 0.5
|
||||
DOC_TIME_DECAY = float(
|
||||
@@ -27,11 +24,6 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
|
||||
# Currently only applies to search flow not chat
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
|
||||
DISABLE_LLM_QUERY_REPHRASE = (
|
||||
os.environ.get("DISABLE_LLM_QUERY_REPHRASE", "").lower() == "true"
|
||||
)
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
|
||||
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
|
||||
@@ -46,34 +38,6 @@ TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
|
||||
)
|
||||
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
# TODO these are not used, should probably reintroduce these
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
LANGUAGE_HINT = "\n" + (
|
||||
os.environ.get("LANGUAGE_HINT")
|
||||
or "IMPORTANT: Respond in the same language as my query!"
|
||||
)
|
||||
LANGUAGE_CHAT_NAMING_HINT = (
|
||||
os.environ.get("LANGUAGE_CHAT_NAMING_HINT")
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_DOC_RELEVANCE = (
|
||||
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
@@ -86,9 +50,6 @@ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true"
|
||||
NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10)
|
||||
NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50)
|
||||
|
||||
# Enable in-house model for detecting connector-based filtering in queries
|
||||
ENABLE_CONNECTOR_CLASSIFIER = os.environ.get("ENABLE_CONNECTOR_CLASSIFIER", False)
|
||||
|
||||
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)
|
||||
|
||||
# Whether or not to use the semantic & keyword search expansions for Basic Search
|
||||
@@ -96,5 +57,3 @@ USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = (
|
||||
os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
USE_DIV_CON_AGENT = os.environ.get("USE_DIV_CON_AGENT", "false").lower() == "true"
|
||||
|
||||
@@ -7,6 +7,7 @@ from enum import Enum
|
||||
|
||||
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
|
||||
ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb"
|
||||
ONYX_UTM_SOURCE = "onyx_app"
|
||||
SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
@@ -22,6 +23,9 @@ PUBLIC_DOC_PAT = "PUBLIC"
|
||||
ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
|
||||
# Tag for endpoints that should be included in the public API documentation
|
||||
PUBLIC_API_TAGS: list[str | Enum] = ["public"]
|
||||
|
||||
# Cookies
|
||||
FASTAPI_USERS_AUTH_COOKIE_NAME = (
|
||||
"fastapiusersauth" # Currently a constant, but logic allows for configuration
|
||||
@@ -89,6 +93,7 @@ SSL_CERT_FILE = "bundle.pem"
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
|
||||
DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service"
|
||||
|
||||
# Key-Value store keys
|
||||
KV_REINDEX_KEY = "needs_reindexing"
|
||||
@@ -146,9 +151,6 @@ CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
# Doc ID migration can be long-running; use a longer TTL and renew periodically
|
||||
CELERY_USER_FILE_DOCID_MIGRATION_LOCK_TIMEOUT = 10 * 60 # 10 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
@@ -237,6 +239,8 @@ class NotificationType(str, Enum):
|
||||
REINDEX = "reindex"
|
||||
PERSONA_SHARED = "persona_shared"
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -337,6 +341,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
@@ -365,9 +370,6 @@ class OnyxCeleryQueues:
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
|
||||
CSV_GENERATION = "csv_generation"
|
||||
|
||||
# Indexing queue
|
||||
USER_FILES_INDEXING = "user_files_indexing"
|
||||
|
||||
# User file processing queue
|
||||
USER_FILE_PROCESSING = "user_file_processing"
|
||||
USER_FILE_PROJECT_SYNC = "user_file_project_sync"
|
||||
@@ -426,7 +428,9 @@ class OnyxRedisLocks:
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"
|
||||
USER_FILE_DOCID_MIGRATION_LOCK = "da_lock:user_file_docid_migration"
|
||||
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
@@ -533,7 +537,6 @@ class OnyxCeleryTask:
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
USER_FILE_DOCID_MIGRATION = "user_file_docid_migration"
|
||||
|
||||
# chat retention
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
@@ -542,6 +545,7 @@ class OnyxCeleryTask:
|
||||
GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task"
|
||||
|
||||
EVAL_RUN_TASK = "eval_run_task"
|
||||
SCHEDULED_EVAL_TASK = "scheduled_eval_task"
|
||||
|
||||
EXPORT_QUERY_HISTORY_TASK = "export_query_history_task"
|
||||
EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task"
|
||||
|
||||
@@ -128,3 +128,17 @@ if _LITELLM_EXTRA_BODY_RAW:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
#####
|
||||
# Prompt Caching Configs
|
||||
#####
|
||||
# Enable prompt caching framework
|
||||
ENABLE_PROMPT_CACHING = (
|
||||
os.environ.get("ENABLE_PROMPT_CACHING", "true").lower() != "false"
|
||||
)
|
||||
|
||||
# Cache TTL multiplier - store caches slightly longer than provider TTL
|
||||
# This allows for some clock skew and ensures we don't lose cache metadata prematurely
|
||||
PROMPT_CACHE_REDIS_TTL_MULTIPLIER = float(
|
||||
os.environ.get("PROMPT_CACHE_REDIS_TTL_MULTIPLIER") or 1.2
|
||||
)
|
||||
|
||||
@@ -4,8 +4,6 @@ import os
|
||||
# Onyx Slack Bot Configs
|
||||
#####
|
||||
ONYX_BOT_NUM_RETRIES = int(os.environ.get("ONYX_BOT_NUM_RETRIES", "5"))
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
ONYX_BOT_NUM_DOCS_TO_DISPLAY = int(os.environ.get("ONYX_BOT_NUM_DOCS_TO_DISPLAY", "5"))
|
||||
# If the LLM fails to answer, Onyx can still show the "Reference Documents"
|
||||
@@ -47,10 +45,6 @@ ONYX_BOT_MAX_WAIT_TIME = int(os.environ.get("ONYX_BOT_MAX_WAIT_TIME") or 180)
|
||||
# Time (in minutes) after which a Slack message is sent to the user to remind him to give feedback.
|
||||
# Set to 0 to disable it (default)
|
||||
ONYX_BOT_FEEDBACK_REMINDER = int(os.environ.get("ONYX_BOT_FEEDBACK_REMINDER") or 0)
|
||||
# Set to True to rephrase the Slack users messages
|
||||
ONYX_BOT_REPHRASE_MESSAGE = (
|
||||
os.environ.get("ONYX_BOT_REPHRASE_MESSAGE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of
|
||||
# responses OnyxBot can send in a given time period.
|
||||
|
||||
@@ -93,7 +93,7 @@ if __name__ == "__main__":
|
||||
#### Docs Changes
|
||||
|
||||
Create the new connector page (with guiding images!) with how to get the connector credentials and how to set up the
|
||||
connector in Onyx. Then create a Pull Request in https://github.com/onyx-dot-app/onyx-docs.
|
||||
connector in Onyx. Then create a Pull Request in [https://github.com/onyx-dot-app/documentation](https://github.com/onyx-dot-app/documentation).
|
||||
|
||||
### Before opening PR
|
||||
|
||||
|
||||
@@ -25,11 +25,17 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.workspace_id = asana_workspace_id.strip()
|
||||
if asana_project_ids:
|
||||
project_ids = [
|
||||
project_id.strip()
|
||||
for project_id in asana_project_ids.split(",")
|
||||
if project_id.strip()
|
||||
]
|
||||
self.project_ids_to_index = project_ids or None
|
||||
else:
|
||||
self.project_ids_to_index = None
|
||||
self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
|
||||
@@ -901,13 +901,16 @@ class OnyxConfluence:
|
||||
space_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
This is a confluence server specific method that can be used to
|
||||
This is a confluence server/data center specific method that can be used to
|
||||
fetch the permissions of a space.
|
||||
This is better logging than calling the get_space_permissions method
|
||||
because it returns a jsonrpc response.
|
||||
TODO: Make this call these endpoints for newer confluence versions:
|
||||
- /rest/api/space/{spaceKey}/permissions
|
||||
- /rest/api/space/{spaceKey}/permissions/anonymous
|
||||
|
||||
NOTE: This uses the JSON-RPC API which is the ONLY way to get space permissions
|
||||
on Confluence Server/Data Center. The REST API equivalent (expand=permissions)
|
||||
is Cloud-only and not available on Data Center as of version 8.9.x.
|
||||
|
||||
If this fails with 401 Unauthorized, the customer needs to enable JSON-RPC:
|
||||
Confluence Admin -> General Configuration -> Further Configuration
|
||||
-> Enable "Remote API (XML-RPC & SOAP)"
|
||||
"""
|
||||
url = "rpc/json-rpc/confluenceservice-v2"
|
||||
data = {
|
||||
@@ -916,7 +919,18 @@ class OnyxConfluence:
|
||||
"id": 7,
|
||||
"params": [space_key],
|
||||
}
|
||||
response = self.post(url, data=data)
|
||||
try:
|
||||
response = self.post(url, data=data)
|
||||
except HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 401:
|
||||
raise HTTPError(
|
||||
"Unauthorized (401) when calling JSON-RPC API for space permissions. "
|
||||
"This is likely because the Remote API is disabled. "
|
||||
"To fix: Confluence Admin -> General Configuration -> Further Configuration "
|
||||
"-> Enable 'Remote API (XML-RPC & SOAP)'",
|
||||
response=e.response,
|
||||
) from e
|
||||
raise
|
||||
logger.debug(f"jsonrpc response: {response}")
|
||||
if not response.get("result"):
|
||||
logger.warning(
|
||||
@@ -961,14 +975,20 @@ def get_user_email_from_username__server(
|
||||
try:
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code if e.response is not None else "N/A"
|
||||
logger.warning(
|
||||
f"Failed to get confluence email for {user_name}: "
|
||||
f"HTTP {status_code} - {e}"
|
||||
)
|
||||
# For now, we'll just return None and log a warning. This means
|
||||
# we will keep retrying to get the email every group sync.
|
||||
email = None
|
||||
# We may want to just return a string that indicates failure so we dont
|
||||
# keep retrying
|
||||
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to get confluence email for {user_name}: {type(e).__name__} - {e}"
|
||||
)
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -97,10 +97,17 @@ def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
|
||||
def get_experts_stores_representations(
|
||||
experts: list[BasicExpertInfo] | None,
|
||||
) -> list[str] | None:
|
||||
"""Gets string representations of experts supplied.
|
||||
|
||||
If an expert cannot be represented as a string, it is omitted from the
|
||||
result.
|
||||
"""
|
||||
if not experts:
|
||||
return None
|
||||
|
||||
reps = [basic_expert_info_representation(owner) for owner in experts]
|
||||
reps: list[str | None] = [
|
||||
basic_expert_info_representation(owner) for owner in experts
|
||||
]
|
||||
return [owner for owner in reps if owner is not None]
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing_extensions import override
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from onyx.configs.app_configs import JIRA_SLIM_PAGE_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
is_atlassian_date_error,
|
||||
@@ -57,7 +58,6 @@ logger = setup_logger()
|
||||
ONE_HOUR = 3600
|
||||
|
||||
_MAX_RESULTS_FETCH_IDS = 5000 # 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
|
||||
# Constants for Jira field names
|
||||
@@ -683,7 +683,7 @@ class JiraConnector(
|
||||
jira_client=self.jira_client,
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
max_results=JIRA_SLIM_PAGE_SIZE,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
nextPageToken=checkpoint.cursor,
|
||||
@@ -703,11 +703,11 @@ class JiraConnector(
|
||||
)
|
||||
)
|
||||
current_offset += 1
|
||||
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
self.update_checkpoint_for_next_run(
|
||||
checkpoint, current_offset, prev_offset, _JIRA_SLIM_PAGE_SIZE
|
||||
checkpoint, current_offset, prev_offset, JIRA_SLIM_PAGE_SIZE
|
||||
)
|
||||
prev_offset = current_offset
|
||||
|
||||
|
||||
@@ -161,6 +161,8 @@ class DocumentBase(BaseModel):
|
||||
sections: list[TextSection | ImageSection]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
# TODO(andrei): Ideally we could improve this to where each value is just a
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
# UTC time
|
||||
@@ -202,13 +204,7 @@ class DocumentBase(BaseModel):
|
||||
if not self.metadata:
|
||||
return None
|
||||
# Combined string for the key/value for easy filtering
|
||||
attributes: list[str] = []
|
||||
for k, v in self.metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
return convert_metadata_dict_to_list_of_strings(self.metadata)
|
||||
|
||||
def __sizeof__(self) -> int:
|
||||
size = sys.getsizeof(self.id)
|
||||
@@ -240,6 +236,66 @@ class DocumentBase(BaseModel):
|
||||
return " ".join([section.text for section in self.sections if section.text])
|
||||
|
||||
|
||||
def convert_metadata_dict_to_list_of_strings(
|
||||
metadata: dict[str, str | list[str]],
|
||||
) -> list[str]:
|
||||
"""Converts a metadata dict to a list of strings.
|
||||
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
Returns:
|
||||
A list of strings where each string is a key-value pair separated by the
|
||||
INDEX_SEPARATOR.
|
||||
"""
|
||||
attributes: list[str] = []
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, list):
|
||||
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
|
||||
else:
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
|
||||
|
||||
def convert_metadata_list_of_strings_to_dict(
|
||||
metadata_list: list[str],
|
||||
) -> dict[str, str | list[str]]:
|
||||
"""
|
||||
Converts a list of strings to a metadata dict. The inverse of
|
||||
convert_metadata_dict_to_list_of_strings.
|
||||
|
||||
Assumes the input strings are formatted as in the output of
|
||||
convert_metadata_dict_to_list_of_strings.
|
||||
|
||||
The schema of the output metadata dict is suboptimal yet bound to legacy
|
||||
code. Ideally each key would just point to a list of strings, where each
|
||||
list might contain just one element.
|
||||
|
||||
Args:
|
||||
metadata_list: The list of strings to convert to a metadata dict.
|
||||
|
||||
Returns:
|
||||
A metadata dict where values can be either a string or a list of
|
||||
strings.
|
||||
"""
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
for item in metadata_list:
|
||||
key, value = item.split(INDEX_SEPARATOR, 1)
|
||||
if key in metadata:
|
||||
# We have already seen this key therefore it must point to a list.
|
||||
if isinstance(metadata[key], list):
|
||||
cast(list[str], metadata[key]).append(value)
|
||||
else:
|
||||
metadata[key] = [cast(str, metadata[key]), value]
|
||||
else:
|
||||
metadata[key] = value
|
||||
return metadata
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import io
|
||||
import ipaddress
|
||||
import random
|
||||
import socket
|
||||
@@ -36,10 +35,11 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from onyx.utils.web_content import extract_pdf_text
|
||||
from onyx.utils.web_content import is_pdf_resource
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -116,16 +116,6 @@ DEFAULT_HEADERS = {
|
||||
"Sec-CH-UA-Platform": '"macOS"',
|
||||
}
|
||||
|
||||
# Common PDF MIME types
|
||||
PDF_MIME_TYPES = [
|
||||
"application/pdf",
|
||||
"application/x-pdf",
|
||||
"application/acrobat",
|
||||
"application/vnd.pdf",
|
||||
"text/pdf",
|
||||
"text/x-pdf",
|
||||
]
|
||||
|
||||
|
||||
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
|
||||
# Given a base site, index everything under that path
|
||||
@@ -268,12 +258,6 @@ def get_internal_links(
|
||||
return internal_links
|
||||
|
||||
|
||||
def is_pdf_content(response: requests.Response) -> bool:
|
||||
"""Check if the response contains PDF content based on content-type header"""
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
return any(pdf_type in content_type for pdf_type in PDF_MIME_TYPES)
|
||||
|
||||
|
||||
def start_playwright() -> Tuple[Playwright, BrowserContext]:
|
||||
playwright = sync_playwright().start()
|
||||
|
||||
@@ -529,14 +513,13 @@ class WebConnector(LoadConnector):
|
||||
head_response = requests.head(
|
||||
initial_url, headers=DEFAULT_HEADERS, allow_redirects=True
|
||||
)
|
||||
is_pdf = is_pdf_content(head_response)
|
||||
content_type = head_response.headers.get("content-type")
|
||||
is_pdf = is_pdf_resource(initial_url, content_type)
|
||||
|
||||
if is_pdf or initial_url.lower().endswith(".pdf"):
|
||||
if is_pdf:
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(initial_url, headers=DEFAULT_HEADERS)
|
||||
page_text, metadata, images = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
page_text, metadata = extract_pdf_text(response.content)
|
||||
last_modified = response.headers.get("Last-Modified")
|
||||
|
||||
result.doc = Document(
|
||||
|
||||
@@ -13,13 +13,6 @@ class RecencyBiasSetting(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class OptionalSearchSetting(str, Enum):
|
||||
ALWAYS = "always"
|
||||
NEVER = "never"
|
||||
# Determine whether to run search based on history and latest query
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""
|
||||
The type of first-pass query to use for hybrid search.
|
||||
@@ -36,15 +29,3 @@ class SearchType(str, Enum):
|
||||
KEYWORD = "keyword"
|
||||
SEMANTIC = "semantic"
|
||||
INTERNET = "internet"
|
||||
|
||||
|
||||
class LLMEvaluationType(str, Enum):
|
||||
AGENTIC = "agentic" # applies agentic evaluation
|
||||
BASIC = "basic" # applies boolean evaluation
|
||||
SKIP = "skip" # skips evaluation
|
||||
UNSPECIFIED = "unspecified" # reverts to default
|
||||
|
||||
|
||||
class QueryFlow(str, Enum):
|
||||
SEARCH = "search"
|
||||
QUESTION_ANSWER = "question-answer"
|
||||
|
||||
@@ -31,7 +31,6 @@ from onyx.context.search.federated.slack_search_utils import is_recency_query
|
||||
from onyx.context.search.federated.slack_search_utils import should_include_message
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -425,7 +424,6 @@ class SlackQueryResult(BaseModel):
|
||||
|
||||
def query_slack(
|
||||
query_string: str,
|
||||
original_query: SearchQuery,
|
||||
access_token: str,
|
||||
limit: int | None = None,
|
||||
allowed_private_channel: str | None = None,
|
||||
@@ -456,7 +454,7 @@ def query_slack(
|
||||
logger.info(f"Final query to slack: {final_query}")
|
||||
|
||||
# Detect if query asks for most recent results
|
||||
sort_by_time = is_recency_query(original_query.query)
|
||||
sort_by_time = is_recency_query(query_string)
|
||||
|
||||
slack_client = WebClient(token=access_token)
|
||||
try:
|
||||
@@ -536,8 +534,7 @@ def query_slack(
|
||||
)
|
||||
document_id = f"{channel_id}_{message_id}"
|
||||
|
||||
# compute recency bias (parallels vespa calculation) and metadata
|
||||
decay_factor = DOC_TIME_DECAY * original_query.recency_bias_multiplier
|
||||
decay_factor = DOC_TIME_DECAY
|
||||
doc_time = datetime.fromtimestamp(float(message_id))
|
||||
doc_age_years = (datetime.now() - doc_time).total_seconds() / (
|
||||
365 * 24 * 60 * 60
|
||||
@@ -1002,7 +999,6 @@ def slack_retrieval(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
query_limit,
|
||||
allowed_private_channel,
|
||||
@@ -1045,7 +1041,6 @@ def slack_retrieval(
|
||||
query_slack,
|
||||
(
|
||||
query_string,
|
||||
query,
|
||||
access_token,
|
||||
query_limit,
|
||||
allowed_private_channel,
|
||||
@@ -1225,7 +1220,6 @@ def slack_retrieval(
|
||||
source_type=DocumentSource.SLACK,
|
||||
title=chunk.title_prefix,
|
||||
boost=0,
|
||||
recency_bias=docid_to_message[document_id].recency_bias,
|
||||
score=convert_slack_score(docid_to_message[document_id].slack_score),
|
||||
hidden=False,
|
||||
is_relevant=None,
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.context.search.federated.models import ChannelMetadata
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.federated_connectors.slack.models import SlackEntities
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import UserMessage
|
||||
from onyx.llm.utils import llm_response_to_string
|
||||
from onyx.onyxbot.slack.models import ChannelType
|
||||
from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT
|
||||
@@ -190,7 +191,7 @@ def extract_date_range_from_query(
|
||||
|
||||
try:
|
||||
prompt = SLACK_DATE_EXTRACTION_PROMPT.format(query=query)
|
||||
response = llm_response_to_string(llm.invoke(prompt))
|
||||
response = llm_response_to_string(llm.invoke(UserMessage(content=prompt)))
|
||||
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
@@ -566,6 +567,23 @@ def extract_content_words_from_recency_query(
|
||||
return content_words_filtered[:MAX_CONTENT_WORDS]
|
||||
|
||||
|
||||
def _is_valid_keyword_query(line: str) -> bool:
|
||||
"""Check if a line looks like a valid keyword query vs explanatory text.
|
||||
|
||||
Returns False for lines that appear to be LLM explanations rather than keywords.
|
||||
"""
|
||||
# Reject lines that start with parentheses (explanatory notes)
|
||||
if line.startswith("("):
|
||||
return False
|
||||
|
||||
# Reject lines that are too long (likely sentences, not keywords)
|
||||
# Keywords should be short - reject if > 50 chars or > 6 words
|
||||
if len(line) > 50 or len(line.split()) > 6:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
"""Use LLM to expand query into multiple search variations.
|
||||
|
||||
@@ -576,8 +594,10 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
Returns:
|
||||
List of rephrased query strings (up to MAX_SLACK_QUERY_EXPANSIONS)
|
||||
"""
|
||||
prompt = SLACK_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
|
||||
prompt = UserMessage(
|
||||
content=SLACK_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -586,10 +606,18 @@ def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
|
||||
response_clean = _parse_llm_code_block_response(response)
|
||||
|
||||
# Split into lines and filter out empty lines
|
||||
rephrased_queries = [
|
||||
raw_queries = [
|
||||
line.strip() for line in response_clean.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Filter out lines that look like explanatory text rather than keywords
|
||||
rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]
|
||||
|
||||
# Log if we filtered out garbage
|
||||
if len(raw_queries) != len(rephrased_queries):
|
||||
filtered_out = set(raw_queries) - set(rephrased_queries)
|
||||
logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")
|
||||
|
||||
# If no queries generated, use empty query
|
||||
if not rephrased_queries:
|
||||
logger.debug("No content keywords extracted from query expansion")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user