mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-16 21:22:41 +00:00
Compare commits
85 Commits
feat/file-
...
v2.12.7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d676634f1c | ||
|
|
c49e9a93e8 | ||
|
|
12b2ab2459 | ||
|
|
f8005eb90c | ||
|
|
152a950710 | ||
|
|
52eaf3d706 | ||
|
|
96b7cc1711 | ||
|
|
47d3c511d4 | ||
|
|
414944ac47 | ||
|
|
dd13f730da | ||
|
|
bb71a91689 | ||
|
|
af0721e063 | ||
|
|
567651a812 | ||
|
|
7589767bb9 | ||
|
|
589d613f1e | ||
|
|
b17d7e0033 | ||
|
|
131d418771 | ||
|
|
0be04391b3 | ||
|
|
20351d9998 | ||
|
|
22152ad871 | ||
|
|
7caf197f98 | ||
|
|
140bc82b36 | ||
|
|
e7ecbfafd1 | ||
|
|
2c2af369f5 | ||
|
|
2032b76fbf | ||
|
|
055b30b00e | ||
|
|
360a4cf591 | ||
|
|
3d3cab9f91 | ||
|
|
6120d012ba | ||
|
|
3e7e2e93f2 | ||
|
|
ccf482fa3b | ||
|
|
fd45a612da | ||
|
|
c444d8883b | ||
|
|
9947837f9f | ||
|
|
bc324a8070 | ||
|
|
26f648c24a | ||
|
|
638f20f5f3 | ||
|
|
f6ee57f523 | ||
|
|
aae6fc7aac | ||
|
|
5d7a664250 | ||
|
|
e7386490bf | ||
|
|
106e10a143 | ||
|
|
513f430a1b | ||
|
|
696d73822f | ||
|
|
bfcc5a20a2 | ||
|
|
efe3613354 | ||
|
|
62405bdc42 | ||
|
|
8f505dc45f | ||
|
|
75f0db4fe5 | ||
|
|
f0a5c579a3 | ||
|
|
293bf30847 | ||
|
|
8774ca3b0f | ||
|
|
016a73f85f | ||
|
|
2eddb4e23e | ||
|
|
0a61660a59 | ||
|
|
a10599e76e | ||
|
|
b3d3f7af76 | ||
|
|
03d919c918 | ||
|
|
71d2ae563a | ||
|
|
19f9c7357c | ||
|
|
f8fa5b243c | ||
|
|
5f845c208f | ||
|
|
d8595f8de0 | ||
|
|
5b00d1ef9c | ||
|
|
41b6ed92a9 | ||
|
|
07f35336ad | ||
|
|
4728bb87c7 | ||
|
|
adfa2f30af | ||
|
|
9dac4165fb | ||
|
|
7d2ede5efc | ||
|
|
4592f6885f | ||
|
|
9dc14fad79 | ||
|
|
ff6e471cfb | ||
|
|
09b9443405 | ||
|
|
14cd6d08e8 | ||
|
|
5ee16697ce | ||
|
|
b794f7e10d | ||
|
|
bb3275bb75 | ||
|
|
7644e225a5 | ||
|
|
811600b84a | ||
|
|
40ce8615ff | ||
|
|
0cee3f6960 | ||
|
|
8883e5608f | ||
|
|
7c2f3ded44 | ||
|
|
aa094ce1f0 |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
32
.github/workflows/deployment.yml
vendored
32
.github/workflows/deployment.yml
vendored
@@ -249,7 +249,7 @@ jobs:
|
||||
xdg-utils
|
||||
|
||||
- name: setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6.1.0
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6.2.0
|
||||
with:
|
||||
node-version: 24
|
||||
package-manager-cache: false
|
||||
@@ -409,7 +409,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -482,7 +482,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -542,7 +542,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -620,7 +620,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -701,7 +701,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -769,7 +769,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -844,7 +844,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -916,7 +916,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -975,7 +975,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -1053,7 +1053,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -1126,7 +1126,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -1187,7 +1187,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -1267,7 +1267,7 @@ jobs:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -1346,7 +1346,7 @@ jobs:
|
||||
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
@@ -1409,7 +1409,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
2
.github/workflows/docker-tag-beta.yml
vendored
2
.github/workflows/docker-tag-beta.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
2
.github/workflows/docker-tag-latest.yml
vendored
2
.github/workflows/docker-tag-latest.yml
vendored
@@ -24,7 +24,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
151
.github/workflows/nightly-scan-licenses.yml
vendored
151
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -1,151 +0,0 @@
|
||||
# Scan for problematic software licenses
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
|
||||
name: 'Nightly - Scan licenses'
|
||||
on:
|
||||
# schedule:
|
||||
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
scan-licenses:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
actions: read
|
||||
contents: read
|
||||
security-events: write
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
cache-dependency-path: |
|
||||
backend/requirements/default.txt
|
||||
backend/requirements/dev.txt
|
||||
backend/requirements/model_server.txt
|
||||
|
||||
- name: Get explicit and transitive dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
pip freeze > requirements-all.txt
|
||||
|
||||
- name: Check python
|
||||
id: license_check_report
|
||||
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
|
||||
with:
|
||||
requirements: 'requirements-all.txt'
|
||||
fail: 'Copyleft'
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: always()
|
||||
env:
|
||||
REPORT: ${{ steps.license_check_report.outputs.report }}
|
||||
run: echo "$REPORT"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
# with:
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
|
||||
timeout-minutes: 45
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR body and check whether the helper checkbox is checked.
|
||||
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
5
.github/workflows/pr-database-tests.yml
vendored
5
.github/workflows/pr-database-tests.yml
vendored
@@ -40,13 +40,16 @@ jobs:
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
env:
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
8
.github/workflows/pr-desktop-build.yml
vendored
8
.github/workflows/pr-desktop-build.yml
vendored
@@ -45,12 +45,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238
|
||||
with:
|
||||
node-version: 24
|
||||
cache: "npm" # zizmor: ignore[cache-poisoning]
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
targets: ${{ matrix.target }}
|
||||
|
||||
- name: Cache Cargo registry and build
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # zizmor: ignore[cache-poisoning]
|
||||
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/bin/
|
||||
@@ -105,7 +105,7 @@ jobs:
|
||||
|
||||
- name: Upload build artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02
|
||||
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
|
||||
with:
|
||||
name: desktop-build-${{ matrix.platform }}-${{ github.run_id }}
|
||||
path: |
|
||||
|
||||
@@ -110,7 +110,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
3
.github/workflows/pr-helm-chart-testing.yml
vendored
3
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -41,8 +41,7 @@ jobs:
|
||||
version: v3.19.0
|
||||
|
||||
- name: Set up chart-testing
|
||||
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
|
||||
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
|
||||
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
|
||||
with:
|
||||
uv_version: "0.9.9"
|
||||
|
||||
|
||||
13
.github/workflows/pr-integration-tests.yml
vendored
13
.github/workflows/pr-integration-tests.yml
vendored
@@ -109,7 +109,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -169,7 +169,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -214,7 +214,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -287,7 +287,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -302,6 +302,8 @@ jobs:
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -466,7 +468,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -478,6 +480,7 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
LICENSE_ENFORCEMENT_ENABLED=false \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
|
||||
2
.github/workflows/pr-jest-tests.yml
vendored
2
.github/workflows/pr-jest-tests.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
|
||||
@@ -101,7 +101,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -161,7 +161,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -220,7 +220,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -279,7 +279,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
60
.github/workflows/pr-playwright-tests.yml
vendored
60
.github/workflows/pr-playwright-tests.yml
vendored
@@ -22,6 +22,9 @@ env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
|
||||
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
|
||||
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
|
||||
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
|
||||
|
||||
# for federated slack tests
|
||||
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
|
||||
@@ -90,7 +93,7 @@ jobs:
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -151,7 +154,7 @@ jobs:
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -212,7 +215,7 @@ jobs:
|
||||
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -259,7 +262,7 @@ jobs:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
@@ -291,6 +294,8 @@ jobs:
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
EXA_API_KEY=${EXA_API_KEY_VALUE}
|
||||
@@ -305,7 +310,7 @@ jobs:
|
||||
# otherwise, we hit the "Unauthenticated users" limit
|
||||
# https://docs.docker.com/docker-hub/usage/
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -465,48 +470,3 @@ jobs:
|
||||
- name: Check job status
|
||||
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
|
||||
run: exit 1
|
||||
|
||||
# NOTE: Chromatic UI diff testing is currently disabled.
|
||||
# We are using Playwright for local and CI testing without visual regression checks.
|
||||
# Chromatic may be reintroduced in the future for UI diff testing if needed.
|
||||
|
||||
# chromatic-tests:
|
||||
# name: Chromatic Tests
|
||||
|
||||
# needs: playwright-tests
|
||||
# runs-on:
|
||||
# [
|
||||
# runs-on,
|
||||
# runner=32cpu-linux-x64,
|
||||
# disk=large,
|
||||
# "run-id=${{ github.run_id }}",
|
||||
# ]
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
|
||||
# - name: Setup node
|
||||
# uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
|
||||
# with:
|
||||
# node-version: 22
|
||||
|
||||
# - name: Install node dependencies
|
||||
# working-directory: ./web
|
||||
# run: npm ci
|
||||
|
||||
# - name: Download Playwright test results
|
||||
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
|
||||
# with:
|
||||
# name: test-results
|
||||
# path: ./web/test-results
|
||||
|
||||
# - name: Run Chromatic
|
||||
# uses: chromaui/action@latest
|
||||
# with:
|
||||
# playwright: true
|
||||
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
|
||||
# workingDir: ./web
|
||||
# env:
|
||||
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -42,6 +42,9 @@ jobs:
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
env:
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -64,7 +64,7 @@ jobs:
|
||||
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
2
.github/workflows/pr-python-tests.yml
vendored
2
.github/workflows/pr-python-tests.yml
vendored
@@ -27,6 +27,8 @@ jobs:
|
||||
PYTHONPATH: ./backend
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
DISABLE_TELEMETRY: "true"
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
4
.github/workflows/pr-quality-checks.yml
vendored
4
.github/workflows/pr-quality-checks.yml
vendored
@@ -24,13 +24,13 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
|
||||
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
- name: Setup Terraform
|
||||
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6
|
||||
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
|
||||
with: # zizmor: ignore[cache-poisoning]
|
||||
node-version: 22
|
||||
cache: "npm"
|
||||
|
||||
5
LICENSE
5
LICENSE
@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
|
||||
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
|
||||
- backend/ee/LICENSE
|
||||
- web/src/app/ee/LICENSE
|
||||
- web/src/ee/LICENSE
|
||||
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
|
||||
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
|
||||
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /usr/etc/supervisord.conf
|
||||
COPY --chown=onyx:onyx ./static /app/static
|
||||
COPY --chown=onyx:onyx ./keys /app/keys
|
||||
|
||||
# Escape hatch scripts
|
||||
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
The DanswerAI Enterprise license (the “Enterprise License”)
|
||||
The Onyx Enterprise License (the "Enterprise License")
|
||||
Copyright (c) 2023-present DanswerAI, Inc.
|
||||
|
||||
With regard to the Onyx Software:
|
||||
|
||||
This software and associated documentation files (the "Software") may only be
|
||||
used in production, if you (and any entity that you represent) have agreed to,
|
||||
and are in compliance with, the DanswerAI Subscription Terms of Service, available
|
||||
at https://onyx.app/terms (the “Enterprise Terms”), or other
|
||||
and are in compliance with, the Onyx Subscription Terms of Service, available
|
||||
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
|
||||
agreement governing the use of the Software, as agreed by you and DanswerAI,
|
||||
and otherwise have a valid Onyx Enterprise license for the
|
||||
and otherwise have a valid Onyx Enterprise License for the
|
||||
correct number of user seats. Subject to the foregoing sentence, you are free to
|
||||
modify this Software and publish patches to the Software. You agree that DanswerAI
|
||||
and/or its licensors (as applicable) retain all right, title and interest in and
|
||||
to all such modifications and/or patches, and all such modifications and/or
|
||||
patches may only be used, copied, modified, displayed, distributed, or otherwise
|
||||
exploited with a valid Onyx Enterprise license for the correct
|
||||
exploited with a valid Onyx Enterprise License for the correct
|
||||
number of user seats. Notwithstanding the foregoing, you may copy and modify
|
||||
the Software for development and testing purposes, without requiring a
|
||||
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain
|
||||
|
||||
@@ -134,7 +134,7 @@ GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
# License enforcement - when True, blocks API access for gated/expired licenses
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
|
||||
)
|
||||
|
||||
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
|
||||
|
||||
@@ -263,9 +263,15 @@ def refresh_license_cache(
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_record.license_data)
|
||||
# Derive source from payload: manual licenses lack stripe_customer_id
|
||||
source: LicenseSource = (
|
||||
LicenseSource.AUTO_FETCH
|
||||
if payload.stripe_customer_id
|
||||
else LicenseSource.MANUAL_UPLOAD
|
||||
)
|
||||
return update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
source=source,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.server.analytics.api import router as analytics_router
|
||||
from ee.onyx.server.auth_check import check_ee_router_auth
|
||||
from ee.onyx.server.billing.api import router as billing_router
|
||||
@@ -151,12 +150,9 @@ def get_application() -> FastAPI:
|
||||
# License management
|
||||
include_router_with_global_prefix_prepended(application, license_router)
|
||||
|
||||
# Unified billing API - available when license system is enabled
|
||||
# Works for both self-hosted and cloud deployments
|
||||
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
|
||||
# primary billing API and /tenants/* billing endpoints can be removed
|
||||
if LICENSE_ENFORCEMENT_ENABLED:
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
# Unified billing API - always registered in EE.
|
||||
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
|
||||
include_router_with_global_prefix_prepended(application, billing_router)
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Tenant management
|
||||
|
||||
@@ -109,7 +109,9 @@ async def _make_billing_request(
|
||||
headers = _get_headers(license_data)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=_REQUEST_TIMEOUT, follow_redirects=True
|
||||
) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
"""EE Settings API - provides license-aware settings override."""
|
||||
|
||||
from redis.exceptions import RedisError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import get_cached_license_metadata
|
||||
from ee.onyx.db.license import refresh_license_cache
|
||||
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.models import Settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -40,6 +44,14 @@ def check_ee_features_enabled() -> bool:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss — warm from DB so cold-start doesn't block EE features
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(f"Failed to load license from DB: {db_error}")
|
||||
|
||||
if metadata and metadata.status != _BLOCKING_STATUS:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
return True
|
||||
@@ -81,6 +93,18 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
metadata = get_cached_license_metadata(tenant_id)
|
||||
if not metadata:
|
||||
# Cache miss (e.g. after TTL expiry). Fall back to DB so
|
||||
# the /settings request doesn't falsely return GATED_ACCESS
|
||||
# while the cache is cold.
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
metadata = refresh_license_cache(db_session, tenant_id)
|
||||
except SQLAlchemyError as db_error:
|
||||
logger.warning(
|
||||
f"Failed to load license from DB for settings: {db_error}"
|
||||
)
|
||||
|
||||
if metadata:
|
||||
if metadata.status == _BLOCKING_STATUS:
|
||||
settings.application_status = metadata.status
|
||||
@@ -89,7 +113,11 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
|
||||
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
|
||||
settings.ee_features_enabled = True
|
||||
else:
|
||||
# No license = community edition, disable EE features
|
||||
# No license found in cache or DB.
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
# Legacy EE flag is set → prior EE usage (e.g. permission
|
||||
# syncing) means indexed data may need protection.
|
||||
settings.application_status = _BLOCKING_STATUS
|
||||
settings.ee_features_enabled = False
|
||||
except RedisError as e:
|
||||
logger.warning(f"Failed to check license metadata for settings: {e}")
|
||||
|
||||
@@ -177,7 +177,7 @@ async def forward_to_control_plane(
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
|
||||
@@ -12,12 +12,14 @@ from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from ee.onyx.db.user_group import update_user_curator_relationship
|
||||
from ee.onyx.db.user_group import update_user_group
|
||||
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
|
||||
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroup
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
@@ -45,6 +47,23 @@ def list_user_groups(
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
|
||||
@@ -76,6 +76,18 @@ class UserGroup(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
)
|
||||
|
||||
|
||||
class UserGroupCreate(BaseModel):
|
||||
name: str
|
||||
user_ids: list[UUID]
|
||||
|
||||
@@ -60,6 +60,7 @@ from sqlalchemy import nulls_last
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.disposable_email_validator import is_disposable_email
|
||||
@@ -110,6 +111,7 @@ from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
@@ -272,6 +274,22 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
|
||||
"""Raise HTTPException(402) if adding users would exceed the seat limit.
|
||||
|
||||
No-op for multi-tenant or CE deployments.
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=seats_needed)
|
||||
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -401,6 +419,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
|
||||
# Check seat availability for new users (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
existing = get_user_by_email(user_create.email, sync_db)
|
||||
if existing is None:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
user_created = False
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request)
|
||||
@@ -610,6 +634,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
except exceptions.UserNotExists:
|
||||
# Check seat availability before creating (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
"email": account_email,
|
||||
|
||||
@@ -217,9 +217,11 @@ if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
{
|
||||
"name": "check-for-documents-for-opensearch-migration",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
|
||||
# Try to enqueue an invocation of this task with this frequency.
|
||||
"schedule": timedelta(seconds=120), # 2 minutes
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
# If the task was not dequeued in this time, revoke it.
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
@@ -227,10 +229,18 @@ if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "migrate-documents-from-vespa-to-opensearch",
|
||||
"task": OnyxCeleryTask.MIGRATE_DOCUMENT_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
"task": OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
# Try to enqueue an invocation of this task with this frequency.
|
||||
# NOTE: If MIGRATION_TASK_SOFT_TIME_LIMIT_S is greater than this
|
||||
# value and the task is maximally busy, we can expect to see some
|
||||
# enqueued tasks be revoked over time. This is ok; by erring on the
|
||||
# side of "there will probably always be at least one task of this
|
||||
# type in the queue", we are minimizing this task's idleness while
|
||||
# still giving chances for other tasks to execute.
|
||||
"schedule": timedelta(seconds=120), # 2 minutes
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
# If the task was not dequeued in this time, revoke it.
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
# Tasks are expected to cease execution and do cleanup after the soft time
|
||||
# limit. In principle they are also forceably terminated after the hard time
|
||||
# limit, in practice this does not happen since we use threadpools for Celery
|
||||
# task execution, and we simple hope that the total task time plus cleanup does
|
||||
# not exceed this. Therefore tasks should regularly check their timeout and lock
|
||||
# status. The lock timeout is the maximum time the lock manager (Redis in this
|
||||
# case) will enforce the lock, independent of what is happening in the task. To
|
||||
# reduce the chances that a task is still doing work while a lock has expired,
|
||||
# make the lock timeout well above the task timeouts. In practice we should
|
||||
# never see locks be held for this long anyway because a task should release the
|
||||
# lock after its cleanup which happens at most after its soft timeout.
|
||||
|
||||
# Constants corresponding to migrate_documents_from_vespa_to_opensearch_task.
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes.
|
||||
MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes.
|
||||
# The maximum time the lock can be held for. Will automatically be released
|
||||
# after this time.
|
||||
MIGRATION_TASK_LOCK_TIMEOUT_S = 60 * 7 # 7 minutes.
|
||||
assert (
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S < MIGRATION_TASK_TIME_LIMIT_S
|
||||
), "The soft time limit must be less than the time limit."
|
||||
assert (
|
||||
MIGRATION_TASK_TIME_LIMIT_S < MIGRATION_TASK_LOCK_TIMEOUT_S
|
||||
), "The time limit must be less than the lock timeout."
|
||||
# Time to wait to acquire the lock.
|
||||
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S = 60 * 2 # 2 minutes.
|
||||
|
||||
# Constants corresponding to check_for_documents_for_opensearch_migration_task.
|
||||
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S = 60 # 60 seconds / 1 minute.
|
||||
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S = 90 # 90 seconds.
|
||||
# The maximum time the lock can be held for. Will automatically be released
|
||||
# after this time.
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S = 120 # 120 seconds / 2 minutes.
|
||||
assert (
|
||||
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S < CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S
|
||||
), "The soft time limit must be less than the time limit."
|
||||
assert (
|
||||
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S < CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S
|
||||
), "The time limit must be less than the lock timeout."
|
||||
# Time to wait to acquire the lock.
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
|
||||
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Celery tasks for migrating documents from Vespa to OpenSearch."""
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -10,6 +11,30 @@ from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_LOCK_TIMEOUT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
MIGRATION_TASK_TIME_LIMIT_S,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
transform_vespa_chunks_to_opensearch_chunks,
|
||||
)
|
||||
@@ -31,6 +56,9 @@ from onyx.db.opensearch_migration import (
|
||||
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit,
|
||||
)
|
||||
from onyx.db.opensearch_migration import should_document_migration_be_permanently_failed
|
||||
from onyx.db.opensearch_migration import (
|
||||
try_insert_opensearch_tenant_migration_record_with_commit,
|
||||
)
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
@@ -92,10 +120,14 @@ def _migrate_single_document(
|
||||
name=OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
|
||||
# Does not store the task's return value in the result backend.
|
||||
ignore_result=True,
|
||||
# When exceeded celery will raise a SoftTimeLimitExceeded in the task.
|
||||
soft_time_limit=60 * 5, # 5 minutes.
|
||||
# When exceeded the task will be forcefully terminated.
|
||||
time_limit=60 * 6, # 6 minutes.
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
# this config is not respected and timeout logic must be implemented in the
|
||||
# task.
|
||||
soft_time_limit=CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
# this config is not respected and timeout logic must be implemented in the
|
||||
# task.
|
||||
time_limit=CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
|
||||
# Passed in self to the task to get task metadata.
|
||||
bind=True,
|
||||
)
|
||||
@@ -107,7 +139,11 @@ def check_for_documents_for_opensearch_migration_task(
|
||||
table.
|
||||
|
||||
Should not execute meaningful logic at the same time as
|
||||
migrate_document_from_vespa_to_opensearch_task.
|
||||
migrate_documents_from_vespa_to_opensearch_task.
|
||||
|
||||
Effectively tries to populate as many migration records as possible within
|
||||
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of
|
||||
1000 documents.
|
||||
|
||||
Returns:
|
||||
None if OpenSearch migration is not enabled, or if the lock could not be
|
||||
@@ -121,29 +157,33 @@ def check_for_documents_for_opensearch_migration_task(
|
||||
return None
|
||||
|
||||
task_logger.info("Checking for documents for OpenSearch migration.")
|
||||
|
||||
task_start_time = time.monotonic()
|
||||
r = get_redis_client()
|
||||
|
||||
# Use a lock to prevent overlapping tasks. Only this task or
|
||||
# migrate_document_from_vespa_to_opensearch_task can interact with the
|
||||
# migrate_documents_from_vespa_to_opensearch_task can interact with the
|
||||
# OpenSearchMigration table at once.
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
# The maximum time the lock can be held for. Will automatically be
|
||||
# released after this time.
|
||||
timeout=60 * 6, # 6 minutes, same as the time limit for this task.
|
||||
timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
|
||||
# .acquire will block until the lock is acquired.
|
||||
blocking=True,
|
||||
# Wait for 2 minutes trying to acquire the lock.
|
||||
blocking_timeout=60 * 2, # 2 minutes.
|
||||
# Time to wait to acquire the lock.
|
||||
blocking_timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
|
||||
if not lock_beat.acquire():
|
||||
if not lock.acquire():
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration check task timed out waiting for the lock."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Acquired the OpenSearch migration check lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
num_documents_found_for_record_creation = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
@@ -153,60 +193,84 @@ def check_for_documents_for_opensearch_migration_task(
|
||||
)
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# For pagination, get the last ID we've inserted into
|
||||
# OpenSearchMigration.
|
||||
last_opensearch_migration_document_id = (
|
||||
get_last_opensearch_migration_document_id(db_session)
|
||||
)
|
||||
# Now get the next batch of doc IDs starting after the last ID.
|
||||
document_ids = get_paginated_document_batch(
|
||||
db_session,
|
||||
prev_ending_document_id=last_opensearch_migration_document_id,
|
||||
)
|
||||
|
||||
if not document_ids:
|
||||
task_logger.info(
|
||||
"No more documents to insert for OpenSearch migration."
|
||||
while (
|
||||
time.monotonic() - task_start_time
|
||||
< CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# For pagination, get the last ID we've inserted into
|
||||
# OpenSearchMigration.
|
||||
last_opensearch_migration_document_id = (
|
||||
get_last_opensearch_migration_document_id(db_session)
|
||||
)
|
||||
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
|
||||
db_session
|
||||
# Now get the next batch of doc IDs starting after the last ID.
|
||||
# We'll do 1000 documents per transaction/timeout check.
|
||||
document_ids = get_paginated_document_batch(
|
||||
db_session,
|
||||
limit=1000,
|
||||
prev_ending_document_id=last_opensearch_migration_document_id,
|
||||
)
|
||||
# TODO(andrei): Once we've done this enough times and the number
|
||||
# of documents matches the number of migration records, we can
|
||||
# be done with this task and update
|
||||
# document_migration_record_table_population_status.
|
||||
return True
|
||||
|
||||
# Create the migration records for the next batch of documents with
|
||||
# status PENDING.
|
||||
create_opensearch_migration_records_with_commit(db_session, document_ids)
|
||||
task_logger.info(
|
||||
f"Created {len(document_ids)} migration records for the next batch of documents."
|
||||
)
|
||||
if not document_ids:
|
||||
task_logger.info(
|
||||
"No more documents to insert for OpenSearch migration."
|
||||
)
|
||||
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
|
||||
db_session
|
||||
)
|
||||
# TODO(andrei): Once we've done this enough times and the
|
||||
# number of documents matches the number of migration
|
||||
# records, we can be done with this task and update
|
||||
# document_migration_record_table_population_status.
|
||||
return True
|
||||
|
||||
# Create the migration records for the next batch of documents
|
||||
# with status PENDING.
|
||||
create_opensearch_migration_records_with_commit(
|
||||
db_session, document_ids
|
||||
)
|
||||
num_documents_found_for_record_creation += len(document_ids)
|
||||
|
||||
# Try to create the singleton row in
|
||||
# OpenSearchTenantMigrationRecord if it doesn't already exist.
|
||||
# This is a reasonable place to put it because we already have a
|
||||
# lock, a session, and error handling, at the cost of running
|
||||
# this small set of logic for every batch.
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
except Exception:
|
||||
task_logger.exception("Error in the OpenSearch migration check task.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the check task."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Finished checking for documents for OpenSearch migration. Found {num_documents_found_for_record_creation} documents "
|
||||
f"to create migration records for in {time.monotonic() - task_start_time:.3f} seconds. However, this may include "
|
||||
"documents for which there already exist records."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# shared_task allows this task to be shared across celery app instances.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MIGRATE_DOCUMENT_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
name=OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
|
||||
# Does not store the task's return value in the result backend.
|
||||
ignore_result=True,
|
||||
# When exceeded celery will raise a SoftTimeLimitExceeded in the task.
|
||||
soft_time_limit=60 * 5, # 5 minutes.
|
||||
# When exceeded the task will be forcefully terminated.
|
||||
time_limit=60 * 6, # 6 minutes.
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
# this config is not respected and timeout logic must be implemented in the
|
||||
# task.
|
||||
soft_time_limit=MIGRATION_TASK_SOFT_TIME_LIMIT_S,
|
||||
# WARNING: This is here just for rigor but since we use threads for Celery
|
||||
# this config is not respected and timeout logic must be implemented in the
|
||||
# task.
|
||||
time_limit=MIGRATION_TASK_TIME_LIMIT_S,
|
||||
# Passed in self to the task to get task metadata.
|
||||
bind=True,
|
||||
)
|
||||
@@ -220,10 +284,13 @@ def migrate_documents_from_vespa_to_opensearch_task(
|
||||
Should not execute meaningful logic at the same time as
|
||||
check_for_documents_for_opensearch_migration_task.
|
||||
|
||||
Effectively tries to migrate as many documents as possible within
|
||||
MIGRATION_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of 5 documents.
|
||||
|
||||
Returns:
|
||||
None if OpenSearch migration is not enabled, or if the lock could not be
|
||||
acquired; effectively a no-op. True if the task completed
|
||||
successfully. False if the task failed.
|
||||
successfully. False if the task errored.
|
||||
"""
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
task_logger.warning(
|
||||
@@ -231,30 +298,36 @@ def migrate_documents_from_vespa_to_opensearch_task(
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info("Trying to migrate documents from Vespa to OpenSearch.")
|
||||
|
||||
task_logger.info("Trying a migration batch from Vespa to OpenSearch.")
|
||||
task_start_time = time.monotonic()
|
||||
r = get_redis_client()
|
||||
|
||||
# Use a lock to prevent overlapping tasks. Only this task or
|
||||
# check_for_documents_for_opensearch_migration_task can interact with the
|
||||
# OpenSearchMigration table at once.
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
# The maximum time the lock can be held for. Will automatically be
|
||||
# released after this time.
|
||||
timeout=60 * 6, # 6 minutes, same as the time limit for this task.
|
||||
timeout=MIGRATION_TASK_LOCK_TIMEOUT_S,
|
||||
# .acquire will block until the lock is acquired.
|
||||
blocking=True,
|
||||
# Wait for 2 minutes trying to acquire the lock.
|
||||
blocking_timeout=60 * 2, # 2 minutes.
|
||||
# Time to wait to acquire the lock.
|
||||
blocking_timeout=MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
|
||||
)
|
||||
|
||||
if not lock_beat.acquire():
|
||||
if not lock.acquire():
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration task timed out waiting for the lock."
|
||||
)
|
||||
return None
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Acquired the OpenSearch migration lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
num_documents_migrated = 0
|
||||
num_chunks_migrated = 0
|
||||
num_documents_failed = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
@@ -264,98 +337,111 @@ def migrate_documents_from_vespa_to_opensearch_task(
|
||||
)
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
records_needing_migration = (
|
||||
get_opensearch_migration_records_needing_migration(db_session)
|
||||
)
|
||||
if not records_needing_migration:
|
||||
task_logger.info(
|
||||
"No documents found that need to be migrated from Vespa to OpenSearch."
|
||||
)
|
||||
increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
|
||||
db_session
|
||||
)
|
||||
# TODO(andrei): Once we've done this enough times and
|
||||
# document_migration_record_table_population_status is done, we
|
||||
# can be done with this task and update
|
||||
# overall_document_migration_status accordingly. Note that this
|
||||
# includes marking connectors as needing reindexing if some
|
||||
# migrations failed.
|
||||
return True
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Trying to migrate {len(records_needing_migration)} documents from Vespa to OpenSearch."
|
||||
)
|
||||
|
||||
for record in records_needing_migration:
|
||||
try:
|
||||
# If the Document's chunk count is not known, it was
|
||||
# probably just indexed so fail here to give it a chance to
|
||||
# sync. If in the rare event this Document has not been
|
||||
# re-indexed in a very long time and is still under the
|
||||
# "old" embedding/indexing logic where chunk count was never
|
||||
# stored, we will eventually permanently fail and thus force
|
||||
# a re-index of this doc, which is a desireable outcome.
|
||||
if record.document.chunk_count is None:
|
||||
raise RuntimeError(
|
||||
f"Document {record.document_id} has no chunk count."
|
||||
)
|
||||
|
||||
chunks_migrated = _migrate_single_document(
|
||||
document_id=record.document_id,
|
||||
opensearch_document_index=opensearch_document_index,
|
||||
vespa_document_index=vespa_document_index,
|
||||
tenant_state=tenant_state,
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# We'll do 5 documents per transaction/timeout check.
|
||||
records_needing_migration = (
|
||||
get_opensearch_migration_records_needing_migration(
|
||||
db_session, limit=5
|
||||
)
|
||||
|
||||
# If the number of chunks in Vespa is not in sync with the
|
||||
# Document table for this doc let's not consider this
|
||||
# completed and let's let a subsequent run take care of it.
|
||||
if chunks_migrated != record.document.chunk_count:
|
||||
raise RuntimeError(
|
||||
f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks in Vespa "
|
||||
f"({record.document.chunk_count}) for document {record.document_id}."
|
||||
)
|
||||
|
||||
record.status = OpenSearchDocumentMigrationStatus.COMPLETED
|
||||
except Exception:
|
||||
record.status = OpenSearchDocumentMigrationStatus.FAILED
|
||||
record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
|
||||
task_logger.exception(
|
||||
f"Error migrating document {record.document_id} from Vespa to OpenSearch."
|
||||
)
|
||||
if not records_needing_migration:
|
||||
task_logger.info(
|
||||
"No documents found that need to be migrated from Vespa to OpenSearch."
|
||||
)
|
||||
finally:
|
||||
record.attempts_count += 1
|
||||
record.last_attempt_at = datetime.now(timezone.utc)
|
||||
if should_document_migration_be_permanently_failed(record):
|
||||
record.status = (
|
||||
OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
|
||||
)
|
||||
# TODO(andrei): Not necessarily here but if this happens
|
||||
# we'll need to mark the connector as needing reindex.
|
||||
increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
|
||||
db_session
|
||||
)
|
||||
# TODO(andrei): Once we've done this enough times and
|
||||
# document_migration_record_table_population_status is done, we
|
||||
# can be done with this task and update
|
||||
# overall_document_migration_status accordingly. Note that this
|
||||
# includes marking connectors as needing reindexing if some
|
||||
# migrations failed.
|
||||
return True
|
||||
|
||||
db_session.commit()
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(
|
||||
tenant_id=tenant_id, multitenant=MULTI_TENANT
|
||||
)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
index_name=search_settings.index_name, tenant_state=tenant_state
|
||||
)
|
||||
vespa_document_index = VespaDocumentIndex(
|
||||
index_name=search_settings.index_name,
|
||||
tenant_state=tenant_state,
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
for record in records_needing_migration:
|
||||
try:
|
||||
# If the Document's chunk count is not known, it was
|
||||
# probably just indexed so fail here to give it a chance to
|
||||
# sync. If in the rare event this Document has not been
|
||||
# re-indexed in a very long time and is still under the
|
||||
# "old" embedding/indexing logic where chunk count was never
|
||||
# stored, we will eventually permanently fail and thus force
|
||||
# a re-index of this doc, which is a desireable outcome.
|
||||
if record.document.chunk_count is None:
|
||||
raise RuntimeError(
|
||||
f"Document {record.document_id} has no chunk count."
|
||||
)
|
||||
|
||||
chunks_migrated = _migrate_single_document(
|
||||
document_id=record.document_id,
|
||||
opensearch_document_index=opensearch_document_index,
|
||||
vespa_document_index=vespa_document_index,
|
||||
tenant_state=tenant_state,
|
||||
)
|
||||
|
||||
# If the number of chunks in Vespa is not in sync with the
|
||||
# Document table for this doc let's not consider this
|
||||
# completed and let's let a subsequent run take care of it.
|
||||
if chunks_migrated != record.document.chunk_count:
|
||||
raise RuntimeError(
|
||||
f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks "
|
||||
f"in Vespa ({record.document.chunk_count}) for document {record.document_id}."
|
||||
)
|
||||
|
||||
record.status = OpenSearchDocumentMigrationStatus.COMPLETED
|
||||
num_documents_migrated += 1
|
||||
num_chunks_migrated += chunks_migrated
|
||||
except Exception:
|
||||
record.status = OpenSearchDocumentMigrationStatus.FAILED
|
||||
record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
|
||||
task_logger.exception(
|
||||
f"Error migrating document {record.document_id} from Vespa to OpenSearch."
|
||||
)
|
||||
num_documents_failed += 1
|
||||
finally:
|
||||
record.attempts_count += 1
|
||||
record.last_attempt_at = datetime.now(timezone.utc)
|
||||
if should_document_migration_be_permanently_failed(record):
|
||||
record.status = (
|
||||
OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
|
||||
)
|
||||
# TODO(andrei): Not necessarily here but if this happens
|
||||
# we'll need to mark the connector as needing reindex.
|
||||
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
task_logger.exception("Error in the OpenSearch migration task.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
else:
|
||||
task_logger.warning(
|
||||
"The OpenSearch migration lock was not owned on completion of the migration task."
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Finished a migration batch from Vespa to OpenSearch. Migrated {num_chunks_migrated} chunks "
|
||||
f"from {num_documents_migrated} documents in {time.monotonic() - task_start_time:.3f} seconds. "
|
||||
f"Failed to migrate {num_documents_failed} documents."
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -12,6 +12,7 @@ from retry import retry
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import HierarchyNode
|
||||
@@ -54,6 +57,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_queued_key(user_file_id: str | UUID) -> str:
|
||||
"""Key that exists while a process_single_user_file task is sitting in the queue.
|
||||
|
||||
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
before enqueuing and the worker deletes it as its first action. This prevents
|
||||
the beat from adding duplicate tasks for files that already have a live task
|
||||
in flight.
|
||||
"""
|
||||
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
|
||||
|
||||
|
||||
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
|
||||
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
|
||||
|
||||
@@ -117,7 +131,24 @@ def _get_document_chunk_count(
|
||||
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
|
||||
|
||||
Uses direct Redis locks to avoid overlapping runs.
|
||||
Three mechanisms prevent queue runaway:
|
||||
|
||||
1. **Queue depth backpressure** – if the broker queue already has more than
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
|
||||
entirely. Workers are clearly behind; adding more tasks would only make
|
||||
the backlog worse.
|
||||
|
||||
2. **Per-file queued guard** – before enqueuing a task we set a short-lived
|
||||
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
|
||||
already exists the file already has a live task in the queue, so we skip
|
||||
it. The worker deletes the key the moment it picks up the task so the
|
||||
next beat cycle can re-enqueue if the file is still PROCESSING.
|
||||
|
||||
3. **Task expiry** – every enqueued task carries an `expires` value equal to
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
|
||||
the queue after that deadline, Celery discards it without touching the DB.
|
||||
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
|
||||
Redis restart), stale tasks evict themselves rather than piling up forever.
|
||||
"""
|
||||
task_logger.info("check_user_file_processing - Starting")
|
||||
|
||||
@@ -132,7 +163,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
enqueued = 0
|
||||
skipped_guard = 0
|
||||
try:
|
||||
# --- Protection 1: queue depth backpressure ---
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
queue_len = celery_get_queue_length(
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
|
||||
)
|
||||
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
|
||||
task_logger.warning(
|
||||
f"check_user_file_processing - Queue depth {queue_len} exceeds "
|
||||
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file_ids = (
|
||||
db_session.execute(
|
||||
@@ -145,12 +190,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
for user_file_id in user_file_ids:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
# --- Protection 2: per-file queued guard ---
|
||||
queued_key = _user_file_queued_key(user_file_id)
|
||||
guard_set = redis_client.set(
|
||||
queued_key,
|
||||
1,
|
||||
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
nx=True,
|
||||
)
|
||||
if not guard_set:
|
||||
skipped_guard += 1
|
||||
continue
|
||||
|
||||
# --- Protection 3: task expiry ---
|
||||
# If task submission fails, clear the guard immediately so the
|
||||
# next beat cycle can retry enqueuing this file.
|
||||
try:
|
||||
self.app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={
|
||||
"user_file_id": str(user_file_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
|
||||
)
|
||||
except Exception:
|
||||
redis_client.delete(queued_key)
|
||||
raise
|
||||
enqueued += 1
|
||||
|
||||
finally:
|
||||
@@ -158,7 +226,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
|
||||
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
|
||||
f"tasks for tenant={tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -175,6 +244,12 @@ def process_single_user_file(
|
||||
start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
|
||||
@@ -22,11 +22,13 @@ from onyx.chat.prompt_utils import build_system_prompt
|
||||
from onyx.chat.prompt_utils import (
|
||||
get_default_base_system_prompt,
|
||||
)
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
@@ -36,6 +38,7 @@ from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallDebug
|
||||
from onyx.server.query_and_chat.streaming_models import TopLevelBranching
|
||||
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
|
||||
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
|
||||
@@ -57,6 +60,28 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
|
||||
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
|
||||
if model_provider not in {
|
||||
LlmProviderNames.BEDROCK,
|
||||
LlmProviderNames.BEDROCK_CONVERSE,
|
||||
}:
|
||||
return False
|
||||
|
||||
return any(
|
||||
(
|
||||
msg.message_type == MessageType.ASSISTANT
|
||||
and msg.tool_calls
|
||||
and len(msg.tool_calls) > 0
|
||||
)
|
||||
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
|
||||
for msg in simple_chat_history
|
||||
)
|
||||
|
||||
|
||||
def _try_fallback_tool_extraction(
|
||||
llm_step_result: LlmStepResult,
|
||||
tool_choice: ToolChoiceOptions,
|
||||
@@ -452,7 +477,12 @@ def run_llm_loop(
|
||||
elif out_of_cycles or ran_image_gen:
|
||||
# Last cycle, no tools allowed, just answer!
|
||||
tool_choice = ToolChoiceOptions.NONE
|
||||
final_tools = []
|
||||
# Bedrock requires tool config in requests that include toolUse/toolResult history.
|
||||
final_tools = (
|
||||
tools
|
||||
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
|
||||
else []
|
||||
)
|
||||
else:
|
||||
tool_choice = ToolChoiceOptions.AUTO
|
||||
final_tools = tools
|
||||
@@ -601,6 +631,19 @@ def run_llm_loop(
|
||||
tool_responses: list[ToolResponse] = []
|
||||
tool_calls = llm_step_result.tool_calls or []
|
||||
|
||||
if INTEGRATION_TESTS_MODE and tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=tool_call.placement,
|
||||
obj=ToolCallDebug(
|
||||
tool_call_id=tool_call.tool_call_id,
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_args,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if len(tool_calls) > 1:
|
||||
emitter.emit(
|
||||
Packet(
|
||||
|
||||
@@ -7,6 +7,7 @@ import re
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from contextvars import Token
|
||||
from uuid import UUID
|
||||
|
||||
from redis.client import Redis
|
||||
@@ -43,6 +44,7 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
|
||||
from onyx.chat.save_chat import save_chat_turn
|
||||
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
|
||||
from onyx.chat.stop_signal_checker import reset_cancel_status
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -69,6 +71,8 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -90,10 +94,6 @@ from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -318,6 +318,7 @@ def handle_stream_message_objects(
|
||||
) -> AnswerStream:
|
||||
tenant_id = get_current_tenant_id()
|
||||
processing_start_time = time.monotonic()
|
||||
mock_response_token: Token[str | None] | None = None
|
||||
|
||||
llm: LLM | None = None
|
||||
chat_session: ChatSession | None = None
|
||||
@@ -328,6 +329,14 @@ def handle_stream_message_objects(
|
||||
llm_user_identifier = "anonymous_user"
|
||||
else:
|
||||
llm_user_identifier = user.email or str(user_id)
|
||||
|
||||
if new_msg_req.mock_llm_response is not None:
|
||||
if not INTEGRATION_TESTS_MODE:
|
||||
raise ValueError(
|
||||
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
|
||||
)
|
||||
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
|
||||
|
||||
try:
|
||||
if not new_msg_req.chat_session_id:
|
||||
if not new_msg_req.chat_session_info:
|
||||
@@ -361,21 +370,16 @@ def handle_stream_message_objects(
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
)
|
||||
|
||||
# Track user message in PostHog for analytics
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.utils.telemetry",
|
||||
attribute="event_telemetry",
|
||||
fallback=noop_fallback,
|
||||
)(
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=user.email if not user.is_anonymous else tenant_id,
|
||||
event="user_message_sent",
|
||||
event=MilestoneRecordType.USER_MESSAGE_SENT,
|
||||
properties={
|
||||
"origin": new_msg_req.origin.value,
|
||||
"has_files": len(new_msg_req.file_descriptors) > 0,
|
||||
"has_project": chat_session.project_id is not None,
|
||||
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
|
||||
"deep_research": new_msg_req.deep_research,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -723,6 +727,9 @@ def handle_stream_message_objects(
|
||||
|
||||
db_session.rollback()
|
||||
finally:
|
||||
if mock_response_token is not None:
|
||||
reset_llm_mock_response(mock_response_token)
|
||||
|
||||
try:
|
||||
if redis_client is not None and chat_session is not None:
|
||||
set_processing_status(
|
||||
@@ -839,6 +846,7 @@ def stream_chat_message_objects(
|
||||
translated_new_msg_req = SendMessageRequest(
|
||||
message=new_msg_req.message,
|
||||
llm_override=new_msg_req.llm_override,
|
||||
mock_llm_response=new_msg_req.mock_llm_response,
|
||||
allowed_tool_ids=new_msg_req.allowed_tool_ids,
|
||||
forced_tool_id=forced_tool_id,
|
||||
file_descriptors=new_msg_req.file_descriptors,
|
||||
|
||||
@@ -75,7 +75,7 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
|
||||
# Auth Configs
|
||||
#####
|
||||
# Upgrades users from disabled auth to basic auth and shows warning.
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
|
||||
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
|
||||
if _auth_type_str == "disabled":
|
||||
logger.warning(
|
||||
"AUTH_TYPE='disabled' is no longer supported. "
|
||||
@@ -900,6 +900,9 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
|
||||
|
||||
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
|
||||
|
||||
# Limit on number of users a free trial tenant can invite (cloud only)
|
||||
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))
|
||||
|
||||
# Security and authentication
|
||||
DATA_PLANE_SECRET = os.environ.get(
|
||||
"DATA_PLANE_SECRET", ""
|
||||
|
||||
@@ -158,6 +158,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
|
||||
|
||||
# How long a queued user-file task is valid before workers discard it.
|
||||
# Should be longer than the beat interval (20 s) but short enough to prevent
|
||||
# indefinite queue growth. Workers drop tasks older than this without touching
|
||||
# the DB, so a shorter value = faster drain of stale duplicates.
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
|
||||
|
||||
# Maximum number of tasks allowed in the user-file-processing queue before the
|
||||
# beat generator stops adding more. Prevents unbounded queue growth when workers
|
||||
# fall behind.
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
|
||||
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
|
||||
@@ -351,6 +362,7 @@ class MilestoneRecordType(str, Enum):
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
USER_MESSAGE_SENT = "user_message_sent"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
@@ -434,6 +446,9 @@ class OnyxRedisLocks:
|
||||
# User file processing
|
||||
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
|
||||
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
|
||||
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
|
||||
# Prevents the beat from re-enqueuing the same file while a task is already queued.
|
||||
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
|
||||
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
|
||||
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
|
||||
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
|
||||
@@ -573,8 +588,8 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK = (
|
||||
"check_for_documents_for_opensearch_migration_task"
|
||||
)
|
||||
MIGRATE_DOCUMENT_FROM_VESPA_TO_OPENSEARCH_TASK = (
|
||||
"migrate_document_from_vespa_to_opensearch_task"
|
||||
MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK = (
|
||||
"migrate_documents_from_vespa_to_opensearch_task"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
BASE_URL = "https://api.gong.io"
|
||||
MAX_CALL_DETAILS_ATTEMPTS = 6
|
||||
CALL_DETAILS_DELAY = 30 # in seconds
|
||||
# Gong API limit is 3 calls/sec — stay safely under it
|
||||
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_fail = continue_on_fail
|
||||
self.auth_token_basic: str | None = None
|
||||
self.hide_user_info = hide_user_info
|
||||
self._last_request_time: float = 0.0
|
||||
|
||||
# urllib3 Retry already respects the Retry-After header by default
|
||||
# (respect_retry_after_header=True), so on 429 it will sleep for the
|
||||
# duration Gong specifies before retrying.
|
||||
retry_strategy = Retry(
|
||||
total=5,
|
||||
total=10,
|
||||
backoff_factor=2,
|
||||
status_forcelist=[429, 500, 502, 503, 504],
|
||||
)
|
||||
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
url = f"{GongConnector.BASE_URL}{endpoint}"
|
||||
return url
|
||||
|
||||
def _throttled_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> requests.Response:
|
||||
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
|
||||
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_request_time
|
||||
if elapsed < self.MIN_REQUEST_INTERVAL:
|
||||
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
|
||||
|
||||
response = self._session.request(method, url, **kwargs)
|
||||
self._last_request_time = time.monotonic()
|
||||
return response
|
||||
|
||||
def _get_workspace_id_map(self) -> dict[str, str]:
|
||||
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
|
||||
response = self._throttled_request(
|
||||
"GET", GongConnector.make_url("/v2/workspaces")
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
workspaces_details = response.json().get("workspaces")
|
||||
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
del body["filter"]["workspaceId"]
|
||||
|
||||
while True:
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
|
||||
)
|
||||
# If no calls in the range, just break out
|
||||
if response.status_code == 404:
|
||||
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
"contentSelector": {"exposedFields": {"parties": True}},
|
||||
}
|
||||
|
||||
response = self._session.post(
|
||||
GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
response = self._throttled_request(
|
||||
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# There's a likely race condition in the API where a transcript will have a
|
||||
# call id but the call to v2/calls/extensive will not return all of the id's
|
||||
# retry with exponential backoff has been observed to mitigate this
|
||||
# in ~2 minutes
|
||||
# in ~2 minutes. After max attempts, proceed with whatever we have —
|
||||
# the per-call loop below will skip missing IDs gracefully.
|
||||
current_attempt = 0
|
||||
while True:
|
||||
current_attempt += 1
|
||||
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
f"missing_call_ids={missing_call_ids}"
|
||||
)
|
||||
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
|
||||
raise RuntimeError(
|
||||
f"Attempt count exceeded for _get_call_details_by_ids: "
|
||||
f"missing_call_ids={missing_call_ids} "
|
||||
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
|
||||
logger.error(
|
||||
f"Giving up on missing call id's after "
|
||||
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
|
||||
f"missing_call_ids={missing_call_ids} — "
|
||||
f"proceeding with {len(call_details_map)} of "
|
||||
f"{len(transcript_call_ids)} calls"
|
||||
)
|
||||
break
|
||||
|
||||
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
|
||||
logger.warning(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
from pydantic import model_validator
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -167,6 +168,14 @@ class DocumentBase(BaseModel):
|
||||
# list of strings.
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
@field_validator("metadata", mode="before")
|
||||
@classmethod
|
||||
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
|
||||
return {
|
||||
key: [str(item) for item in val] if isinstance(val, list) else str(val)
|
||||
for key, val in v.items()
|
||||
}
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
@@ -228,14 +228,13 @@ class BuildSessionStatus(str, PyEnum):
|
||||
class SandboxStatus(str, PyEnum):
|
||||
PROVISIONING = "provisioning"
|
||||
RUNNING = "running"
|
||||
IDLE = "idle"
|
||||
SLEEPING = "sleeping" # Pod terminated, snapshots saved to S3
|
||||
TERMINATED = "terminated"
|
||||
FAILED = "failed"
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if sandbox is in an active state (running or idle)."""
|
||||
return self in (SandboxStatus.RUNNING, SandboxStatus.IDLE)
|
||||
"""Check if sandbox is in an active state (running)."""
|
||||
return self == SandboxStatus.RUNNING
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
"""Check if sandbox is in a terminal state."""
|
||||
|
||||
@@ -109,45 +109,38 @@ def can_user_access_llm_provider(
|
||||
is_admin: If True, bypass user group restrictions but still respect persona restrictions
|
||||
|
||||
Access logic:
|
||||
1. If is_public=True → everyone has access (public override)
|
||||
2. If is_public=False:
|
||||
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
|
||||
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
|
||||
- Only personas set → must use one of the personas (OR across personas, applies to admins)
|
||||
- Neither set → NOBODY has access unless admin (locked, admin-only)
|
||||
- is_public controls USER access (group bypass): when True, all users can access
|
||||
regardless of group membership. When False, user must be in a whitelisted group
|
||||
(or be admin).
|
||||
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
|
||||
This allows admins to make a provider available to all users while still
|
||||
restricting which personas (assistants) can use it.
|
||||
|
||||
Decision matrix:
|
||||
1. is_public=True, no personas set → everyone has access
|
||||
2. is_public=True, personas set → all users, but only whitelisted personas
|
||||
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
|
||||
4. is_public=False, only groups set → must be in group (admins bypass)
|
||||
5. is_public=False, only personas set → must use whitelisted persona
|
||||
6. is_public=False, neither set → admin-only (locked)
|
||||
"""
|
||||
# Public override - everyone has access
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Extract IDs once to avoid multiple iterations
|
||||
provider_group_ids = (
|
||||
{group.id for group in provider.groups} if provider.groups else set()
|
||||
)
|
||||
provider_persona_ids = (
|
||||
{p.id for p in provider.personas} if provider.personas else set()
|
||||
)
|
||||
|
||||
provider_group_ids = {g.id for g in (provider.groups or [])}
|
||||
provider_persona_ids = {p.id for p in (provider.personas or [])}
|
||||
has_groups = bool(provider_group_ids)
|
||||
has_personas = bool(provider_persona_ids)
|
||||
|
||||
# Both groups AND personas set → AND logic (must satisfy both)
|
||||
if has_groups and has_personas:
|
||||
# Admins bypass group check but still must satisfy persona restrictions
|
||||
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
|
||||
persona_allowed = persona.id in provider_persona_ids if persona else False
|
||||
return user_in_group and persona_allowed
|
||||
# Persona restrictions are always enforced when set, regardless of is_public
|
||||
if has_personas and not (persona and persona.id in provider_persona_ids):
|
||||
return False
|
||||
|
||||
if provider.is_public:
|
||||
return True
|
||||
|
||||
# Only groups set → user must be in one of the groups (admins bypass)
|
||||
if has_groups:
|
||||
return is_admin or bool(user_group_ids & provider_group_ids)
|
||||
|
||||
# Only personas set → persona must be in allowed list (applies to admins too)
|
||||
if has_personas:
|
||||
return persona.id in provider_persona_ids if persona else False
|
||||
|
||||
# Neither groups nor personas set, and not public → admins can access
|
||||
return is_admin
|
||||
# No groups: either persona-whitelisted (already passed) or admin-only if locked
|
||||
return has_personas or is_admin
|
||||
|
||||
|
||||
def validate_persona_ids_exist(
|
||||
@@ -428,7 +421,7 @@ def fetch_existing_models(
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
flow_types: list[LLMModelFlowType],
|
||||
flow_type_filter: list[LLMModelFlowType],
|
||||
only_public: bool = False,
|
||||
exclude_image_generation_providers: bool = True,
|
||||
) -> list[LLMProviderModel]:
|
||||
@@ -436,30 +429,27 @@ def fetch_existing_llm_providers(
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
flow_types: List of flow types to filter by
|
||||
flow_type_filter: List of flow types to filter by, empty list for no filter
|
||||
only_public: If True, only return public providers
|
||||
exclude_image_generation_providers: If True, exclude providers that are
|
||||
used for image generation configs
|
||||
"""
|
||||
providers_with_flows = (
|
||||
select(ModelConfiguration.llm_provider_id)
|
||||
.join(LLMModelFlow)
|
||||
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
|
||||
.distinct()
|
||||
)
|
||||
stmt = select(LLMProviderModel)
|
||||
|
||||
if flow_type_filter:
|
||||
providers_with_flows = (
|
||||
select(ModelConfiguration.llm_provider_id)
|
||||
.join(LLMModelFlow)
|
||||
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
|
||||
.distinct()
|
||||
)
|
||||
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
|
||||
|
||||
if exclude_image_generation_providers:
|
||||
stmt = select(LLMProviderModel).where(
|
||||
LLMProviderModel.id.in_(providers_with_flows)
|
||||
)
|
||||
else:
|
||||
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
|
||||
ImageGenerationConfig
|
||||
)
|
||||
stmt = select(LLMProviderModel).where(
|
||||
LLMProviderModel.id.in_(providers_with_flows)
|
||||
| LLMProviderModel.id.in_(image_gen_provider_ids)
|
||||
)
|
||||
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
|
||||
|
||||
stmt = stmt.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
@@ -722,13 +712,15 @@ def sync_auto_mode_models(
|
||||
changes += 1
|
||||
else:
|
||||
# Add new model - all models from GitHub config are visible
|
||||
new_model = ModelConfiguration(
|
||||
insert_new_model_configuration__no_commit(
|
||||
db_session=db_session,
|
||||
llm_provider_id=provider.id,
|
||||
name=model_config.name,
|
||||
display_name=model_config.display_name,
|
||||
model_name=model_config.name,
|
||||
supported_flows=[LLMModelFlowType.CHAT],
|
||||
is_visible=True,
|
||||
max_input_tokens=None,
|
||||
display_name=model_config.display_name,
|
||||
)
|
||||
db_session.add(new_model)
|
||||
changes += 1
|
||||
|
||||
# In Auto mode, default model is always set from GitHub config
|
||||
|
||||
@@ -9,6 +9,9 @@ from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
@@ -18,18 +21,21 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
|
||||
DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_MIGRATE = 500
|
||||
DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_CHECK_FOR_MIGRATION = 2000
|
||||
|
||||
|
||||
def get_paginated_document_batch(
|
||||
db_session: Session,
|
||||
limit: int = DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_CHECK_FOR_MIGRATION,
|
||||
limit: int,
|
||||
prev_ending_document_id: str | None = None,
|
||||
) -> list[str]:
|
||||
"""Gets a paginated batch of document IDs from the Document table.
|
||||
|
||||
We need some deterministic ordering to ensure that we don't miss any
|
||||
documents when paginating. This function uses the document ID. It is
|
||||
possible a document is inserted above a spot this function has already
|
||||
passed. In that event we assume that the document will be indexed into
|
||||
OpenSearch anyway and we don't need to migrate.
|
||||
TODO(andrei): Consider ordering on last_modified in addition to ID to better
|
||||
match get_opensearch_migration_records_needing_migration.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
limit: Number of document IDs to fetch.
|
||||
@@ -91,7 +97,7 @@ def create_opensearch_migration_records_with_commit(
|
||||
|
||||
def get_opensearch_migration_records_needing_migration(
|
||||
db_session: Session,
|
||||
limit: int = DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_MIGRATE,
|
||||
limit: int,
|
||||
) -> list[OpenSearchDocumentMigrationRecord]:
|
||||
"""Gets records of documents that need to be migrated.
|
||||
|
||||
@@ -165,6 +171,20 @@ def get_total_document_count(db_session: Session) -> int:
|
||||
return db_session.query(Document).count()
|
||||
|
||||
|
||||
def try_insert_opensearch_tenant_migration_record_with_commit(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Tries to insert the singleton row on OpenSearchTenantMigrationRecord.
|
||||
|
||||
If the row already exists, does nothing.
|
||||
"""
|
||||
stmt = insert(OpenSearchTenantMigrationRecord).on_conflict_do_nothing(
|
||||
index_elements=[text("(true)")]
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.llm.model_response import ModelResponseStream
|
||||
from onyx.llm.model_response import Usage
|
||||
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
|
||||
from onyx.llm.models import OPENAI_REASONING_EFFORT
|
||||
from onyx.llm.request_context import get_llm_mock_response
|
||||
from onyx.llm.utils import build_litellm_passthrough_kwargs
|
||||
from onyx.llm.utils import is_true_openai_model
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
@@ -378,7 +379,7 @@ class LitellmLLM(LLM):
|
||||
passthrough_kwargs["api_key"] = self._api_key or None
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
|
||||
18
backend/onyx/llm/request_context.py
Normal file
18
backend/onyx/llm/request_context.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import contextvars
|
||||
|
||||
|
||||
_LLM_MOCK_RESPONSE_CONTEXTVAR: contextvars.ContextVar[str | None] = (
|
||||
contextvars.ContextVar("llm_mock_response", default=None)
|
||||
)
|
||||
|
||||
|
||||
def get_llm_mock_response() -> str | None:
|
||||
return _LLM_MOCK_RESPONSE_CONTEXTVAR.get()
|
||||
|
||||
|
||||
def set_llm_mock_response(mock_response: str | None) -> contextvars.Token[str | None]:
|
||||
return _LLM_MOCK_RESPONSE_CONTEXTVAR.set(mock_response)
|
||||
|
||||
|
||||
def reset_llm_mock_response(token: contextvars.Token[str | None]) -> None:
|
||||
_LLM_MOCK_RESPONSE_CONTEXTVAR.reset(token)
|
||||
@@ -592,11 +592,8 @@ def build_slack_response_blocks(
|
||||
)
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if answer.citation_info:
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
|
||||
citations_divider = [DividerBlock()] if citations_blocks else []
|
||||
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
|
||||
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
|
||||
+ ai_feedback_block
|
||||
+ citations_divider
|
||||
+ citations_blocks
|
||||
+ document_blocks
|
||||
+ buttons_divider
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
import re
|
||||
from enum import Enum
|
||||
|
||||
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
|
||||
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")
|
||||
|
||||
LIKE_BLOCK_ACTION_ID = "feedback-like"
|
||||
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
|
||||
SHOW_EVERYONE_ACTION_ID = "show-everyone"
|
||||
|
||||
@@ -1,29 +1,163 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from mistune import create_markdown
|
||||
from mistune import HTMLRenderer
|
||||
|
||||
# Tags that should be replaced with a newline (line-break and block-level elements)
|
||||
_HTML_NEWLINE_TAG_PATTERN = re.compile(
|
||||
r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
|
||||
_HTML_TAG_PATTERN = re.compile(
|
||||
r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
|
||||
)
|
||||
|
||||
# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
|
||||
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")
|
||||
|
||||
# Matches the start of any markdown link: [text]( or [[n]](
|
||||
# The inner group handles nested brackets for citation links like [[1]](.
|
||||
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")
|
||||
|
||||
# Matches Slack-style links <url|text> that LLMs sometimes output directly.
|
||||
# Mistune doesn't recognise this syntax, so text() would escape the angle
|
||||
# brackets and Slack would render them as literal text instead of links.
|
||||
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
|
||||
|
||||
|
||||
def _sanitize_html(text: str) -> str:
|
||||
"""Strip HTML tags from a text fragment.
|
||||
|
||||
Block-level closing tags and <br> are converted to newlines.
|
||||
All other HTML tags are removed. Autolinks (<https://...>) are preserved.
|
||||
"""
|
||||
text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
|
||||
text = _HTML_TAG_PATTERN.sub("", text)
|
||||
return text
|
||||
|
||||
|
||||
def _transform_outside_code_blocks(
|
||||
message: str, transform: Callable[[str], str]
|
||||
) -> str:
|
||||
"""Apply *transform* only to text outside fenced code blocks."""
|
||||
parts = _FENCED_CODE_BLOCK_PATTERN.split(message)
|
||||
code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
|
||||
|
||||
result: list[str] = []
|
||||
for i, part in enumerate(parts):
|
||||
result.append(transform(part))
|
||||
if i < len(code_blocks):
|
||||
result.append(code_blocks[i])
|
||||
|
||||
return "".join(result)
|
||||
|
||||
|
||||
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
|
||||
"""Extract markdown link destination, allowing nested parentheses in the URL."""
|
||||
depth = 0
|
||||
i = start_idx
|
||||
|
||||
while i < len(message):
|
||||
curr = message[i]
|
||||
if curr == "\\":
|
||||
i += 2
|
||||
continue
|
||||
|
||||
if curr == "(":
|
||||
depth += 1
|
||||
elif curr == ")":
|
||||
if depth == 0:
|
||||
return message[start_idx:i], i
|
||||
depth -= 1
|
||||
i += 1
|
||||
|
||||
return message[start_idx:], None
|
||||
|
||||
|
||||
def _normalize_link_destinations(message: str) -> str:
|
||||
"""Wrap markdown link URLs in angle brackets so the parser handles special chars safely.
|
||||
|
||||
Markdown link syntax [text](url) breaks when the URL contains unescaped
|
||||
parentheses, spaces, or other special characters. Wrapping the URL in angle
|
||||
brackets — [text](<url>) — tells the parser to treat everything inside as
|
||||
a literal URL. This applies to all links, not just citations.
|
||||
"""
|
||||
if "](" not in message:
|
||||
return message
|
||||
|
||||
normalized_parts: list[str] = []
|
||||
cursor = 0
|
||||
|
||||
while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
|
||||
normalized_parts.append(message[cursor : match.end()])
|
||||
destination_start = match.end()
|
||||
destination, end_idx = _extract_link_destination(message, destination_start)
|
||||
if end_idx is None:
|
||||
normalized_parts.append(message[destination_start:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
already_wrapped = destination.startswith("<") and destination.endswith(">")
|
||||
if destination and not already_wrapped:
|
||||
destination = f"<{destination}>"
|
||||
|
||||
normalized_parts.append(destination)
|
||||
normalized_parts.append(")")
|
||||
cursor = end_idx + 1
|
||||
|
||||
normalized_parts.append(message[cursor:])
|
||||
return "".join(normalized_parts)
|
||||
|
||||
|
||||
def _convert_slack_links_to_markdown(message: str) -> str:
|
||||
"""Convert Slack-style <url|text> links to standard markdown [text](url).
|
||||
|
||||
LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
|
||||
recognise it, so the angle brackets would be escaped by text() and Slack
|
||||
would render the link as literal text instead of a clickable link.
|
||||
"""
|
||||
return _transform_outside_code_blocks(
|
||||
message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text)
|
||||
)
|
||||
|
||||
|
||||
def format_slack_message(message: str | None) -> str:
|
||||
if message is None:
|
||||
return ""
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough"])
|
||||
result = md(message)
|
||||
message = _transform_outside_code_blocks(message, _sanitize_html)
|
||||
message = _convert_slack_links_to_markdown(message)
|
||||
normalized_message = _normalize_link_destinations(message)
|
||||
md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
|
||||
result = md(normalized_message)
|
||||
# With HTMLRenderer, result is always str (not AST list)
|
||||
assert isinstance(result, str)
|
||||
return result
|
||||
return result.rstrip("\n")
|
||||
|
||||
|
||||
class SlackRenderer(HTMLRenderer):
|
||||
"""Renders markdown as Slack mrkdwn format instead of HTML.
|
||||
|
||||
Overrides all HTMLRenderer methods that produce HTML tags to ensure
|
||||
no raw HTML ever appears in Slack messages.
|
||||
"""
|
||||
|
||||
SPECIALS: dict[str, str] = {"&": "&", "<": "<", ">": ">"}
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._table_headers: list[str] = []
|
||||
self._current_row_cells: list[str] = []
|
||||
|
||||
def escape_special(self, text: str) -> str:
|
||||
for special, replacement in self.SPECIALS.items():
|
||||
text = text.replace(special, replacement)
|
||||
return text
|
||||
|
||||
def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002
|
||||
return f"*{text}*\n"
|
||||
return f"*{text}*\n\n"
|
||||
|
||||
def emphasis(self, text: str) -> str:
|
||||
return f"_{text}_"
|
||||
@@ -42,7 +176,7 @@ class SlackRenderer(HTMLRenderer):
|
||||
count += 1
|
||||
prefix = f"{count}. " if ordered else "• "
|
||||
lines[i] = f"{prefix}{line[4:]}"
|
||||
return "\n".join(lines)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def list_item(self, text: str) -> str:
|
||||
return f"li: {text}\n"
|
||||
@@ -64,7 +198,73 @@ class SlackRenderer(HTMLRenderer):
|
||||
return f"`{text}`"
|
||||
|
||||
def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002
|
||||
return f"```\n{code}\n```\n"
|
||||
return f"```\n{code.rstrip(chr(10))}\n```\n\n"
|
||||
|
||||
def linebreak(self) -> str:
|
||||
return "\n"
|
||||
|
||||
def thematic_break(self) -> str:
|
||||
return "---\n\n"
|
||||
|
||||
def block_quote(self, text: str) -> str:
|
||||
lines = text.strip().split("\n")
|
||||
quoted = "\n".join(f">{line}" for line in lines)
|
||||
return quoted + "\n\n"
|
||||
|
||||
def block_html(self, html: str) -> str:
|
||||
return _sanitize_html(html) + "\n\n"
|
||||
|
||||
def block_error(self, text: str) -> str:
|
||||
return f"```\n{text}\n```\n\n"
|
||||
|
||||
def text(self, text: str) -> str:
|
||||
# Only escape the three entities Slack recognizes: & < >
|
||||
# HTMLRenderer.text() also escapes " to " which Slack renders
|
||||
# as literal " text since Slack doesn't recognize that entity.
|
||||
return self.escape_special(text)
|
||||
|
||||
# -- Table rendering (converts markdown tables to vertical cards) --
|
||||
|
||||
def table_cell(
|
||||
self, text: str, align: str | None = None, head: bool = False # noqa: ARG002
|
||||
) -> str:
|
||||
if head:
|
||||
self._table_headers.append(text.strip())
|
||||
else:
|
||||
self._current_row_cells.append(text.strip())
|
||||
return ""
|
||||
|
||||
def table_head(self, text: str) -> str: # noqa: ARG002
|
||||
self._current_row_cells = []
|
||||
return ""
|
||||
|
||||
def table_row(self, text: str) -> str: # noqa: ARG002
|
||||
cells = self._current_row_cells
|
||||
self._current_row_cells = []
|
||||
# First column becomes the bold title, remaining columns are bulleted fields
|
||||
lines: list[str] = []
|
||||
if cells:
|
||||
title = cells[0]
|
||||
if title:
|
||||
# Avoid double-wrapping if cell already contains bold markup
|
||||
if title.startswith("*") and title.endswith("*") and len(title) > 1:
|
||||
lines.append(title)
|
||||
else:
|
||||
lines.append(f"*{title}*")
|
||||
for i, cell in enumerate(cells[1:], start=1):
|
||||
if i < len(self._table_headers):
|
||||
lines.append(f" • {self._table_headers[i]}: {cell}")
|
||||
else:
|
||||
lines.append(f" • {cell}")
|
||||
return "\n".join(lines) + "\n\n"
|
||||
|
||||
def table_body(self, text: str) -> str:
|
||||
return text
|
||||
|
||||
def table(self, text: str) -> str:
|
||||
self._table_headers = []
|
||||
self._current_row_cells = []
|
||||
return text + "\n"
|
||||
|
||||
def paragraph(self, text: str) -> str:
|
||||
return f"{text}\n"
|
||||
return f"{text}\n\n"
|
||||
|
||||
@@ -18,15 +18,18 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
|
||||
from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.models import ThreadMessage
|
||||
from onyx.onyxbot.slack.utils import get_channel_from_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
|
||||
from onyx.onyxbot.slack.utils import SlackRateLimiter
|
||||
@@ -41,6 +44,51 @@ srl = SlackRateLimiter()
|
||||
RT = TypeVar("RT") # return type
|
||||
|
||||
|
||||
def resolve_channel_references(
|
||||
message: str,
|
||||
client: WebClient,
|
||||
logger: OnyxLoggingAdapter,
|
||||
) -> tuple[str, list[Tag]]:
|
||||
"""Parse Slack channel references from a message, resolve IDs to names,
|
||||
replace the raw markup with readable #channel-name, and return channel tags
|
||||
for search filtering."""
|
||||
tags: list[Tag] = []
|
||||
channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message)
|
||||
seen_channel_ids: set[str] = set()
|
||||
|
||||
for channel_id, channel_name_from_markup in channel_matches:
|
||||
if channel_id in seen_channel_ids:
|
||||
continue
|
||||
seen_channel_ids.add(channel_id)
|
||||
|
||||
channel_name = channel_name_from_markup or None
|
||||
|
||||
if not channel_name:
|
||||
try:
|
||||
channel_info = get_channel_from_id(client=client, channel_id=channel_id)
|
||||
channel_name = channel_info.get("name") or None
|
||||
except Exception:
|
||||
logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
|
||||
|
||||
if not channel_name:
|
||||
continue
|
||||
|
||||
# Replace raw Slack markup with readable channel name
|
||||
if channel_name_from_markup:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}|{channel_name_from_markup}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
f"<#{channel_id}>",
|
||||
f"#{channel_name}",
|
||||
)
|
||||
tags.append(Tag(tag_key="Channel", tag_value=channel_name))
|
||||
|
||||
return message, tags
|
||||
|
||||
|
||||
def rate_limits(
|
||||
client: WebClient, channel: str, thread_ts: Optional[str]
|
||||
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
|
||||
@@ -157,6 +205,20 @@ def handle_regular_answer(
|
||||
user_message = messages[-1]
|
||||
history_messages = messages[:-1]
|
||||
|
||||
# Resolve any <#CHANNEL_ID> references in the user message to readable
|
||||
# channel names and extract channel tags for search filtering
|
||||
resolved_message, channel_tags = resolve_channel_references(
|
||||
message=user_message.message,
|
||||
client=client,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
user_message = ThreadMessage(
|
||||
message=resolved_message,
|
||||
sender=user_message.sender,
|
||||
role=user_message.role,
|
||||
)
|
||||
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client,
|
||||
channel_id=channel,
|
||||
@@ -207,6 +269,7 @@ def handle_regular_answer(
|
||||
source_type=None,
|
||||
document_set=document_set_names,
|
||||
time_cutoff=None,
|
||||
tags=channel_tags if channel_tags else None,
|
||||
)
|
||||
|
||||
new_message_request = SendMessageRequest(
|
||||
@@ -231,6 +294,16 @@ def handle_regular_answer(
|
||||
slack_context_str=slack_context_str,
|
||||
)
|
||||
|
||||
# If a channel filter was applied but no results were found, override
|
||||
# the LLM response to avoid hallucinated answers about unindexed channels
|
||||
if channel_tags and not answer.citation_info and not answer.top_documents:
|
||||
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
|
||||
answer.answer = (
|
||||
f"No indexed data found for {channel_names}. "
|
||||
"This channel may not be indexed, or there may be no messages "
|
||||
"matching your query within it."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -285,6 +358,7 @@ def handle_regular_answer(
|
||||
only_respond_if_citations
|
||||
and not answer.citation_info
|
||||
and not message_info.bypass_filters
|
||||
and not channel_tags
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
|
||||
@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -92,6 +92,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import cleanup_google_drive_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
@@ -556,6 +557,43 @@ def _normalize_file_names_for_backwards_compatibility(
|
||||
return file_names + file_locations[len(file_names) :]
|
||||
|
||||
|
||||
def _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
require_editable: bool,
|
||||
) -> ConnectorCredentialPair:
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
|
||||
has_requested_access = verify_user_has_access_to_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=require_editable,
|
||||
)
|
||||
if has_requested_access:
|
||||
return cc_pair
|
||||
|
||||
# Special case: global curators should be able to manage files
|
||||
# for public file connectors even when they are not the creator.
|
||||
if (
|
||||
require_editable
|
||||
and user.role == UserRole.GLOBAL_CURATOR
|
||||
and cc_pair.access_type == AccessType.PUBLIC
|
||||
):
|
||||
return cc_pair
|
||||
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied. User cannot manage files for this connector.",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
@@ -567,7 +605,7 @@ def upload_files_api(
|
||||
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
|
||||
def list_connector_files(
|
||||
connector_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ConnectorFilesResponse:
|
||||
"""List all files in a file connector."""
|
||||
@@ -580,6 +618,13 @@ def list_connector_files(
|
||||
status_code=400, detail="This endpoint only works with file connectors"
|
||||
)
|
||||
|
||||
_ = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=False,
|
||||
)
|
||||
|
||||
file_locations = connector.connector_specific_config.get("file_locations", [])
|
||||
file_names = connector.connector_specific_config.get("file_names", [])
|
||||
|
||||
@@ -629,7 +674,7 @@ def update_connector_files(
|
||||
connector_id: int,
|
||||
files: list[UploadFile] | None = File(None),
|
||||
file_ids_to_remove: str = Form("[]"),
|
||||
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
"""
|
||||
@@ -647,12 +692,13 @@ def update_connector_files(
|
||||
)
|
||||
|
||||
# Get the connector-credential pair for indexing/pruning triggers
|
||||
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
|
||||
if cc_pair is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No Connector-Credential Pair found for this connector",
|
||||
)
|
||||
# and validate user permissions for file management.
|
||||
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
|
||||
connector_id=connector_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
require_editable=True,
|
||||
)
|
||||
|
||||
# Parse file IDs to remove
|
||||
try:
|
||||
|
||||
@@ -243,6 +243,7 @@ class WebappInfo(BaseModel):
|
||||
has_webapp: bool # Whether a webapp exists in outputs/web
|
||||
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
|
||||
status: str # Sandbox status (running, terminated, etc.)
|
||||
ready: bool # Whether the NextJS dev server is actually responding
|
||||
|
||||
|
||||
# ===== File Upload Models =====
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.models import BuildMessage
|
||||
from onyx.db.models import User
|
||||
@@ -32,6 +33,8 @@ from onyx.server.features.build.api.models import SuggestionBubble
|
||||
from onyx.server.features.build.api.models import SuggestionTheme
|
||||
from onyx.server.features.build.api.models import UploadResponse
|
||||
from onyx.server.features.build.api.models import WebappInfo
|
||||
from onyx.server.features.build.configs import SANDBOX_BACKEND
|
||||
from onyx.server.features.build.configs import SandboxBackend
|
||||
from onyx.server.features.build.db.build_session import allocate_nextjs_port
|
||||
from onyx.server.features.build.db.build_session import get_build_session
|
||||
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
|
||||
@@ -362,14 +365,13 @@ def restore_session(
|
||||
lock_key = f"sandbox_restore:{sandbox.id}"
|
||||
lock = redis_client.lock(lock_key, timeout=RESTORE_LOCK_TIMEOUT_SECONDS)
|
||||
|
||||
# blocking=True means wait if another restore is in progress
|
||||
acquired = lock.acquire(
|
||||
blocking=True, blocking_timeout=RESTORE_LOCK_TIMEOUT_SECONDS
|
||||
)
|
||||
# Non-blocking: if another restore is already running, return 409 immediately
|
||||
# instead of making the user wait. The frontend will retry.
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Restore operation timed out waiting for lock",
|
||||
status_code=409,
|
||||
detail="Restore already in progress",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -379,15 +381,11 @@ def restore_session(
|
||||
# Also re-check if session workspace exists (another request may have
|
||||
# restored it while we were waiting)
|
||||
if sandbox.status == SandboxStatus.RUNNING:
|
||||
# Verify pod is healthy before proceeding
|
||||
is_healthy = sandbox_manager.health_check(sandbox.id, timeout=10.0)
|
||||
if is_healthy and sandbox_manager.session_workspace_exists(
|
||||
sandbox.id, session_id
|
||||
):
|
||||
logger.info(
|
||||
f"Session {session_id} workspace was restored by another request"
|
||||
)
|
||||
# Update heartbeat to mark sandbox as active
|
||||
session.status = BuildSessionStatus.ACTIVE
|
||||
update_sandbox_heartbeat(db_session, sandbox.id)
|
||||
base_response = SessionResponse.from_model(session, sandbox)
|
||||
return DetailedSessionResponse.from_session_response(
|
||||
@@ -410,69 +408,82 @@ def restore_session(
|
||||
# Fall through to TERMINATED handling below
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
llm_config = session_manager._get_llm_config(None, None)
|
||||
|
||||
if sandbox.status in (SandboxStatus.SLEEPING, SandboxStatus.TERMINATED):
|
||||
# 1. Re-provision the pod
|
||||
logger.info(f"Re-provisioning {sandbox.status.value} sandbox {sandbox.id}")
|
||||
llm_config = session_manager._get_llm_config(None, None)
|
||||
# Mark as PROVISIONING before the long-running provision() call
|
||||
# so other requests know work is in progress
|
||||
update_sandbox_status__no_commit(
|
||||
db_session, sandbox.id, SandboxStatus.PROVISIONING
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
sandbox_manager.provision(
|
||||
sandbox_id=sandbox.id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
llm_config=llm_config,
|
||||
)
|
||||
|
||||
# Mark as RUNNING after successful provision
|
||||
update_sandbox_status__no_commit(
|
||||
db_session, sandbox.id, SandboxStatus.RUNNING
|
||||
)
|
||||
db_session.commit()
|
||||
db_session.refresh(sandbox)
|
||||
|
||||
# 2. Check if session workspace needs to be loaded
|
||||
if sandbox.status == SandboxStatus.RUNNING:
|
||||
if not sandbox_manager.session_workspace_exists(sandbox.id, session_id):
|
||||
# Get latest snapshot and restore it
|
||||
snapshot = get_latest_snapshot_for_session(db_session, session_id)
|
||||
if snapshot:
|
||||
# Allocate a new port for the restored session
|
||||
new_port = allocate_nextjs_port(db_session)
|
||||
session.nextjs_port = new_port
|
||||
workspace_exists = sandbox_manager.session_workspace_exists(
|
||||
sandbox.id, session_id
|
||||
)
|
||||
|
||||
if not workspace_exists:
|
||||
# Allocate port if not already set (needed for both snapshot restore and fresh setup)
|
||||
if not session.nextjs_port:
|
||||
session.nextjs_port = allocate_nextjs_port(db_session)
|
||||
# Commit port allocation before long-running operations
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Restoring snapshot for session {session_id} "
|
||||
f"from {snapshot.storage_path} with port {new_port}"
|
||||
)
|
||||
# Only Kubernetes backend supports snapshot restoration
|
||||
snapshot = None
|
||||
if SANDBOX_BACKEND == SandboxBackend.KUBERNETES:
|
||||
snapshot = get_latest_snapshot_for_session(db_session, session_id)
|
||||
|
||||
if snapshot:
|
||||
try:
|
||||
sandbox_manager.restore_snapshot(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
snapshot_storage_path=snapshot.storage_path,
|
||||
tenant_id=tenant_id,
|
||||
nextjs_port=new_port,
|
||||
nextjs_port=session.nextjs_port,
|
||||
llm_config=llm_config,
|
||||
use_demo_data=session.demo_data_enabled,
|
||||
)
|
||||
session.status = BuildSessionStatus.ACTIVE
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
# Clear the port allocation on failure so it can be reused
|
||||
logger.error(
|
||||
f"Failed to restore session {session_id}, "
|
||||
f"clearing port {new_port}: {e}"
|
||||
f"Snapshot restore failed for session {session_id}: {e}"
|
||||
)
|
||||
session.nextjs_port = None
|
||||
db_session.commit()
|
||||
raise
|
||||
else:
|
||||
# No snapshot - set up fresh workspace
|
||||
logger.info(
|
||||
f"No snapshot found for session {session_id}, "
|
||||
f"setting up fresh workspace"
|
||||
)
|
||||
llm_config = session_manager._get_llm_config(None, None)
|
||||
sandbox_manager.setup_session_workspace(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
llm_config=llm_config,
|
||||
nextjs_port=session.nextjs_port or 3010,
|
||||
nextjs_port=session.nextjs_port,
|
||||
)
|
||||
session.status = BuildSessionStatus.ACTIVE
|
||||
db_session.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Sandbox {sandbox.id} status is {sandbox.status} after "
|
||||
f"re-provision, expected RUNNING"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restore session {session_id}: {e}", exc_info=True)
|
||||
|
||||
@@ -18,7 +18,6 @@ from onyx.db.models import BuildMessage
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import Sandbox
|
||||
from onyx.db.models import Snapshot
|
||||
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_END
|
||||
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_START
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
@@ -269,27 +268,6 @@ def update_artifact(
|
||||
logger.info(f"Updated artifact {artifact_id}")
|
||||
|
||||
|
||||
# Snapshot operations
|
||||
def create_snapshot(
|
||||
session_id: UUID,
|
||||
storage_path: str,
|
||||
size_bytes: int,
|
||||
db_session: Session,
|
||||
) -> Snapshot:
|
||||
"""Create a new snapshot record."""
|
||||
snapshot = Snapshot(
|
||||
session_id=session_id,
|
||||
storage_path=storage_path,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
db_session.add(snapshot)
|
||||
db_session.commit()
|
||||
db_session.refresh(snapshot)
|
||||
|
||||
logger.info(f"Created snapshot {snapshot.id} for session {session_id}")
|
||||
return snapshot
|
||||
|
||||
|
||||
# Message operations
|
||||
def create_message(
|
||||
session_id: UUID,
|
||||
@@ -501,6 +479,32 @@ def allocate_nextjs_port(db_session: Session) -> int:
|
||||
)
|
||||
|
||||
|
||||
def mark_user_sessions_idle__no_commit(db_session: Session, user_id: UUID) -> int:
|
||||
"""Mark all ACTIVE sessions for a user as IDLE.
|
||||
|
||||
Called when a sandbox goes to sleep so the frontend knows these sessions
|
||||
need restoration before they can be used again.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
user_id: The user whose sessions should be marked idle
|
||||
|
||||
Returns:
|
||||
Number of sessions updated
|
||||
"""
|
||||
result = (
|
||||
db_session.query(BuildSession)
|
||||
.filter(
|
||||
BuildSession.user_id == user_id,
|
||||
BuildSession.status == BuildSessionStatus.ACTIVE,
|
||||
)
|
||||
.update({BuildSession.status: BuildSessionStatus.IDLE})
|
||||
)
|
||||
db_session.flush()
|
||||
logger.info(f"Marked {result} sessions as IDLE for user {user_id}")
|
||||
return result
|
||||
|
||||
|
||||
def clear_nextjs_ports_for_user(db_session: Session, user_id: UUID) -> int:
|
||||
"""Clear nextjs_port for all sessions belonging to a user.
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ def get_idle_sandboxes(
|
||||
)
|
||||
|
||||
stmt = select(Sandbox).where(
|
||||
Sandbox.status.in_([SandboxStatus.RUNNING, SandboxStatus.IDLE]),
|
||||
Sandbox.status == SandboxStatus.RUNNING,
|
||||
or_(
|
||||
Sandbox.last_heartbeat < threshold_time,
|
||||
and_(
|
||||
@@ -147,27 +147,30 @@ def get_running_sandbox_count_by_tenant(
|
||||
since Sandbox model no longer has tenant_id. This function returns
|
||||
the count of all running sandboxes.
|
||||
"""
|
||||
stmt = select(func.count(Sandbox.id)).where(
|
||||
Sandbox.status.in_([SandboxStatus.RUNNING, SandboxStatus.IDLE])
|
||||
)
|
||||
stmt = select(func.count(Sandbox.id)).where(Sandbox.status == SandboxStatus.RUNNING)
|
||||
result = db_session.execute(stmt).scalar()
|
||||
return result or 0
|
||||
|
||||
|
||||
def create_snapshot(
|
||||
def create_snapshot__no_commit(
|
||||
db_session: Session,
|
||||
session_id: UUID,
|
||||
storage_path: str,
|
||||
size_bytes: int,
|
||||
) -> Snapshot:
|
||||
"""Create a snapshot record for a session."""
|
||||
"""Create a snapshot record for a session.
|
||||
|
||||
NOTE: Uses flush() instead of commit(). The caller (cleanup task) is
|
||||
responsible for committing after all snapshots + status updates are done,
|
||||
so the entire operation is atomic.
|
||||
"""
|
||||
snapshot = Snapshot(
|
||||
session_id=session_id,
|
||||
storage_path=storage_path,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
db_session.add(snapshot)
|
||||
db_session.commit()
|
||||
db_session.flush()
|
||||
return snapshot
|
||||
|
||||
|
||||
|
||||
@@ -183,13 +183,14 @@ class SandboxManager(ABC):
|
||||
session_id: UUID,
|
||||
tenant_id: str,
|
||||
) -> SnapshotResult | None:
|
||||
"""Create a snapshot of a session's outputs directory.
|
||||
"""Create a snapshot of a session's outputs and attachments directories.
|
||||
|
||||
Captures only the session-specific outputs:
|
||||
sessions/$session_id/outputs/
|
||||
Captures session-specific user data:
|
||||
- sessions/$session_id/outputs/ (generated artifacts, web apps)
|
||||
- sessions/$session_id/attachments/ (user uploaded files)
|
||||
|
||||
Does NOT include: venv, skills, AGENTS.md, opencode.json, attachments
|
||||
Does NOT include: shared files/ directory
|
||||
Does NOT include: venv, skills, AGENTS.md, opencode.json, files symlink
|
||||
(these are regenerated during restore)
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -197,14 +198,45 @@ class SandboxManager(ABC):
|
||||
tenant_id: Tenant identifier for storage path
|
||||
|
||||
Returns:
|
||||
SnapshotResult with storage path and size, or None if
|
||||
snapshots are disabled for this backend
|
||||
SnapshotResult with storage path and size, or None if:
|
||||
- Snapshots are disabled for this backend
|
||||
- No outputs directory exists (nothing to snapshot)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot creation fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def restore_snapshot(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
snapshot_storage_path: str,
|
||||
tenant_id: str,
|
||||
nextjs_port: int,
|
||||
llm_config: LLMProviderConfig,
|
||||
use_demo_data: bool = False,
|
||||
) -> None:
|
||||
"""Restore a session workspace from a snapshot.
|
||||
|
||||
For Kubernetes: Downloads and extracts the snapshot, regenerates config files.
|
||||
For Local: No-op since workspaces persist on disk (no snapshots).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to restore
|
||||
snapshot_storage_path: Path to the snapshot in storage
|
||||
tenant_id: Tenant identifier for storage access
|
||||
nextjs_port: Port number for the NextJS dev server
|
||||
llm_config: LLM provider configuration for opencode.json
|
||||
use_demo_data: If True, symlink files/ to demo data
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot restoration fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def session_workspace_exists(
|
||||
self,
|
||||
@@ -225,36 +257,6 @@ class SandboxManager(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def restore_snapshot(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
snapshot_storage_path: str,
|
||||
tenant_id: str,
|
||||
nextjs_port: int,
|
||||
) -> None:
|
||||
"""Restore a snapshot into a session's workspace directory.
|
||||
|
||||
Downloads the snapshot from storage, extracts it into
|
||||
sessions/$session_id/outputs/, and starts the NextJS server.
|
||||
|
||||
For Kubernetes backend, this downloads from S3 and streams
|
||||
into the pod via kubectl exec (since the pod has no S3 access).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to restore
|
||||
snapshot_storage_path: Path to the snapshot in storage
|
||||
tenant_id: Tenant identifier for storage access
|
||||
nextjs_port: Port number for the NextJS dev server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot restoration fails
|
||||
FileNotFoundError: If snapshot does not exist
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def health_check(self, sandbox_id: UUID, timeout: float = 60.0) -> bool:
|
||||
"""Check if the sandbox is healthy.
|
||||
|
||||
@@ -1583,9 +1583,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@isaacs/brace-expansion": {
|
||||
"version": "5.0.0",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.0.tgz",
|
||||
"integrity": "sha512-ZT55BDLV0yv0RBm2czMiZ+SqCGO7AvmOM3G/w2xhVPH+te0aKgFjmBvGlL1dH+ql2tgGO3MVrbb3jCKyvpgnxA==",
|
||||
"version": "5.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
|
||||
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@isaacs/balanced-match": "^4.0.1"
|
||||
@@ -1640,9 +1640,9 @@
|
||||
}
|
||||
},
|
||||
"node_modules/@modelcontextprotocol/sdk": {
|
||||
"version": "1.25.3",
|
||||
"resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.25.3.tgz",
|
||||
"integrity": "sha512-vsAMBMERybvYgKbg/l4L1rhS7VXV1c0CtyJg72vwxONVX0l4ZfKVAnZEWTQixJGTzKnELjQ59e4NbdFDALRiAQ==",
|
||||
"version": "1.26.0",
|
||||
"resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.26.0.tgz",
|
||||
"integrity": "sha512-Y5RmPncpiDtTXDbLKswIJzTqu2hyBKxTNsgKqKclDbhIgg1wgtf1fRuvxgTnRfcnxtvvgbIEcqUOzZrJ6iSReg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@hono/node-server": "^1.19.9",
|
||||
@@ -1653,14 +1653,15 @@
|
||||
"cross-spawn": "^7.0.5",
|
||||
"eventsource": "^3.0.2",
|
||||
"eventsource-parser": "^3.0.0",
|
||||
"express": "^5.0.1",
|
||||
"express-rate-limit": "^7.5.0",
|
||||
"jose": "^6.1.1",
|
||||
"express": "^5.2.1",
|
||||
"express-rate-limit": "^8.2.1",
|
||||
"hono": "^4.11.4",
|
||||
"jose": "^6.1.3",
|
||||
"json-schema-typed": "^8.0.2",
|
||||
"pkce-challenge": "^5.0.0",
|
||||
"raw-body": "^3.0.0",
|
||||
"zod": "^3.25 || ^4.0",
|
||||
"zod-to-json-schema": "^3.25.0"
|
||||
"zod-to-json-schema": "^3.25.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
@@ -6757,10 +6758,13 @@
|
||||
}
|
||||
},
|
||||
"node_modules/express-rate-limit": {
|
||||
"version": "7.5.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-7.5.1.tgz",
|
||||
"integrity": "sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==",
|
||||
"version": "8.2.1",
|
||||
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
|
||||
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ip-address": "10.0.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 16"
|
||||
},
|
||||
@@ -7424,7 +7428,6 @@
|
||||
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
|
||||
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=16.9.0"
|
||||
}
|
||||
@@ -7552,6 +7555,15 @@
|
||||
"node": ">=12"
|
||||
}
|
||||
},
|
||||
"node_modules/ip-address": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
|
||||
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 12"
|
||||
}
|
||||
},
|
||||
"node_modules/ipaddr.js": {
|
||||
"version": "1.9.1",
|
||||
"resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
|
||||
|
||||
@@ -65,7 +65,6 @@ from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_END
|
||||
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_START
|
||||
from onyx.server.features.build.configs import SANDBOX_S3_BUCKET
|
||||
from onyx.server.features.build.configs import SANDBOX_SERVICE_ACCOUNT_NAME
|
||||
from onyx.server.features.build.s3.s3_client import build_s3_client
|
||||
from onyx.server.features.build.sandbox.base import SandboxManager
|
||||
from onyx.server.features.build.sandbox.kubernetes.internal.acp_exec_client import (
|
||||
ACPEvent,
|
||||
@@ -409,6 +408,10 @@ done
|
||||
],
|
||||
volume_mounts=[
|
||||
client.V1VolumeMount(name="files", mount_path="/workspace/files"),
|
||||
# Mount sessions directory so file-sync can create snapshots
|
||||
client.V1VolumeMount(
|
||||
name="workspace", mount_path="/workspace/sessions"
|
||||
),
|
||||
],
|
||||
resources=client.V1ResourceRequirements(
|
||||
# Reduced resources since sidecar is mostly idle (sleeping)
|
||||
@@ -442,6 +445,10 @@ done
|
||||
client.V1VolumeMount(
|
||||
name="files", mount_path="/workspace/files", read_only=True
|
||||
),
|
||||
# Mount sessions directory (shared with file-sync for snapshots)
|
||||
client.V1VolumeMount(
|
||||
name="workspace", mount_path="/workspace/sessions"
|
||||
),
|
||||
],
|
||||
resources=client.V1ResourceRequirements(
|
||||
requests={"cpu": "500m", "memory": "1Gi"},
|
||||
@@ -583,6 +590,60 @@ done
|
||||
),
|
||||
)
|
||||
|
||||
def _ensure_service_exists(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""Ensure a ClusterIP service exists for the sandbox pod.
|
||||
|
||||
Handles the case where a service is in Terminating state (has a
|
||||
deletion_timestamp) by waiting for deletion and recreating it.
|
||||
This prevents a race condition where provision reuses an existing pod
|
||||
but the old service is still being deleted.
|
||||
"""
|
||||
service_name = self._get_service_name(str(sandbox_id))
|
||||
|
||||
try:
|
||||
svc = self._core_api.read_namespaced_service(
|
||||
name=service_name,
|
||||
namespace=self._namespace,
|
||||
)
|
||||
# Service exists - check if it's being deleted
|
||||
if svc.metadata.deletion_timestamp:
|
||||
logger.info(
|
||||
f"Service {service_name} is terminating, waiting for deletion"
|
||||
)
|
||||
self._wait_for_resource_deletion("service", service_name)
|
||||
# Now create a fresh service
|
||||
service = self._create_sandbox_service(sandbox_id, tenant_id)
|
||||
self._core_api.create_namespaced_service(
|
||||
namespace=self._namespace,
|
||||
body=service,
|
||||
)
|
||||
logger.info(f"Recreated Service {service_name} after termination")
|
||||
else:
|
||||
logger.debug(f"Service {service_name} already exists and is active")
|
||||
|
||||
except ApiException as e:
|
||||
if e.status == 404:
|
||||
# Service doesn't exist, create it
|
||||
logger.info(f"Creating missing Service {service_name}")
|
||||
service = self._create_sandbox_service(sandbox_id, tenant_id)
|
||||
try:
|
||||
self._core_api.create_namespaced_service(
|
||||
namespace=self._namespace,
|
||||
body=service,
|
||||
)
|
||||
except ApiException as svc_e:
|
||||
if svc_e.status != 409: # Ignore AlreadyExists
|
||||
raise
|
||||
logger.debug(
|
||||
f"Service {service_name} was created by another request"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
def _get_init_container_logs(self, pod_name: str, container_name: str) -> str:
|
||||
"""Get logs from an init container.
|
||||
|
||||
@@ -798,34 +859,14 @@ done
|
||||
)
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
service_name = self._get_service_name(str(sandbox_id))
|
||||
|
||||
# Check if pod already exists and is healthy (idempotency check)
|
||||
if self._pod_exists_and_healthy(pod_name):
|
||||
logger.info(
|
||||
f"Pod {pod_name} already exists and is healthy, reusing existing pod"
|
||||
)
|
||||
# Ensure service exists too
|
||||
try:
|
||||
self._core_api.read_namespaced_service(
|
||||
name=service_name,
|
||||
namespace=self._namespace,
|
||||
)
|
||||
except ApiException as e:
|
||||
if e.status == 404:
|
||||
# Service doesn't exist, create it
|
||||
logger.debug(f"Creating missing Service {service_name}")
|
||||
service = self._create_sandbox_service(sandbox_id, tenant_id)
|
||||
try:
|
||||
self._core_api.create_namespaced_service(
|
||||
namespace=self._namespace,
|
||||
body=service,
|
||||
)
|
||||
except ApiException as svc_e:
|
||||
if svc_e.status != 409: # Ignore AlreadyExists
|
||||
raise
|
||||
else:
|
||||
raise
|
||||
# Ensure service exists and is not terminating
|
||||
self._ensure_service_exists(sandbox_id, tenant_id)
|
||||
|
||||
# Wait for pod to be ready if it's still pending
|
||||
logger.info(f"Waiting for existing pod {pod_name} to become ready...")
|
||||
@@ -880,20 +921,8 @@ done
|
||||
else:
|
||||
raise
|
||||
|
||||
# 2. Create Service (idempotent - ignore 409)
|
||||
logger.debug(f"Creating Service {service_name}")
|
||||
service = self._create_sandbox_service(sandbox_id, tenant_id)
|
||||
try:
|
||||
self._core_api.create_namespaced_service(
|
||||
namespace=self._namespace,
|
||||
body=service,
|
||||
)
|
||||
except ApiException as e:
|
||||
if e.status != 409: # Ignore AlreadyExists
|
||||
raise
|
||||
logger.warning(
|
||||
f"During provisioning, discovered that service {service_name} already exists. Reusing"
|
||||
)
|
||||
# 2. Create Service (handles terminating services)
|
||||
self._ensure_service_exists(sandbox_id, tenant_id)
|
||||
|
||||
# 3. Wait for pod to be ready
|
||||
logger.info(f"Waiting for pod {pod_name} to become ready...")
|
||||
@@ -1335,10 +1364,12 @@ echo "Session cleanup complete"
|
||||
session_id: UUID,
|
||||
tenant_id: str,
|
||||
) -> SnapshotResult | None:
|
||||
"""Create a snapshot of a session's outputs directory.
|
||||
"""Create a snapshot of a session's outputs and attachments directories.
|
||||
|
||||
For Kubernetes backend, we exec into the pod to create the snapshot.
|
||||
Only captures sessions/$session_id/outputs/
|
||||
For Kubernetes backend, we exec into the file-sync container to create
|
||||
the snapshot and upload to S3. Captures:
|
||||
- sessions/$session_id/outputs/ (generated artifacts, web apps)
|
||||
- sessions/$session_id/attachments/ (user uploaded files)
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1346,7 +1377,7 @@ echo "Session cleanup complete"
|
||||
tenant_id: Tenant identifier for storage path
|
||||
|
||||
Returns:
|
||||
SnapshotResult with storage path and size
|
||||
SnapshotResult with storage path and size, or None if nothing to snapshot
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot creation fails
|
||||
@@ -1356,26 +1387,40 @@ echo "Session cleanup complete"
|
||||
pod_name = self._get_pod_name(sandbox_id_str)
|
||||
snapshot_id = str(uuid4())
|
||||
|
||||
session_path = f"/workspace/sessions/{session_id_str}"
|
||||
# Use shlex.quote for safety (UUIDs are safe but good practice)
|
||||
safe_session_path = shlex.quote(f"/workspace/sessions/{session_id_str}")
|
||||
s3_path = (
|
||||
f"s3://{self._s3_bucket}/{tenant_id}/snapshots/"
|
||||
f"{session_id_str}/{snapshot_id}.tar.gz"
|
||||
)
|
||||
|
||||
# Exec into pod to create and upload snapshot (session outputs only)
|
||||
# Exec into pod to create and upload snapshot (outputs + attachments)
|
||||
# Uses s5cmd pipe to stream tar.gz directly to S3
|
||||
# Only snapshot if outputs/ exists. Include attachments/ only if non-empty.
|
||||
exec_command = [
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
f'tar -czf - -C {session_path} outputs | aws s3 cp - {s3_path} --tagging "Type=snapshot"',
|
||||
f"""
|
||||
set -eo pipefail
|
||||
cd {safe_session_path}
|
||||
if [ ! -d outputs ]; then
|
||||
echo "EMPTY_SNAPSHOT"
|
||||
exit 0
|
||||
fi
|
||||
dirs="outputs"
|
||||
[ -d attachments ] && [ "$(ls -A attachments 2>/dev/null)" ] && dirs="$dirs attachments"
|
||||
tar -czf - $dirs | /s5cmd pipe {s3_path}
|
||||
echo "SNAPSHOT_CREATED"
|
||||
""",
|
||||
]
|
||||
|
||||
try:
|
||||
# Use exec to run snapshot command in sandbox container
|
||||
# Use exec to run snapshot command in file-sync container (has s5cmd)
|
||||
resp = k8s_stream(
|
||||
self._stream_core_api.connect_get_namespaced_pod_exec,
|
||||
name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
container="file-sync",
|
||||
command=exec_command,
|
||||
stderr=True,
|
||||
stdin=False,
|
||||
@@ -1385,6 +1430,17 @@ echo "Session cleanup complete"
|
||||
|
||||
logger.debug(f"Snapshot exec output: {resp}")
|
||||
|
||||
# Check if nothing was snapshotted
|
||||
if "EMPTY_SNAPSHOT" in resp:
|
||||
logger.info(
|
||||
f"No outputs or attachments to snapshot for session {session_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Verify upload succeeded
|
||||
if "SNAPSHOT_CREATED" not in resp:
|
||||
raise RuntimeError(f"Snapshot upload may have failed. Output: {resp}")
|
||||
|
||||
except ApiException as e:
|
||||
raise RuntimeError(f"Failed to create snapshot: {e}") from e
|
||||
|
||||
@@ -1392,9 +1448,8 @@ echo "Session cleanup complete"
|
||||
# In production, you might want to query S3 for the actual size
|
||||
size_bytes = 0
|
||||
|
||||
storage_path = (
|
||||
f"sandbox-snapshots/{tenant_id}/{session_id_str}/{snapshot_id}.tar.gz"
|
||||
)
|
||||
# Storage path must match the S3 upload path (without s3://bucket/ prefix)
|
||||
storage_path = f"{tenant_id}/snapshots/{session_id_str}/{snapshot_id}.tar.gz"
|
||||
|
||||
logger.info(f"Created snapshot for session {session_id}")
|
||||
|
||||
@@ -1426,7 +1481,7 @@ echo "Session cleanup complete"
|
||||
exec_command = [
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
f'[ -d "{session_path}" ] && echo "EXISTS" || echo "NOT_EXISTS"',
|
||||
f'[ -d "{session_path}" ] && echo "WORKSPACE_FOUND" || echo "WORKSPACE_MISSING"',
|
||||
]
|
||||
|
||||
try:
|
||||
@@ -1442,7 +1497,12 @@ echo "Session cleanup complete"
|
||||
tty=False,
|
||||
)
|
||||
|
||||
return "EXISTS" in resp
|
||||
result = "WORKSPACE_FOUND" in resp
|
||||
logger.info(
|
||||
f"[WORKSPACE_CHECK] session={session_id}, "
|
||||
f"path={session_path}, raw_resp={resp!r}, result={result}"
|
||||
)
|
||||
return result
|
||||
|
||||
except ApiException as e:
|
||||
logger.warning(
|
||||
@@ -1457,14 +1517,21 @@ echo "Session cleanup complete"
|
||||
snapshot_storage_path: str,
|
||||
tenant_id: str, # noqa: ARG002
|
||||
nextjs_port: int,
|
||||
llm_config: LLMProviderConfig,
|
||||
use_demo_data: bool = False,
|
||||
) -> None:
|
||||
"""Download snapshot from S3, extract into session workspace, and start NextJS.
|
||||
"""Download snapshot from S3 via s5cmd, extract, regenerate config, and start NextJS.
|
||||
|
||||
Since the sandbox pod doesn't have S3 access, this method:
|
||||
1. Downloads snapshot from S3 (using boto3 directly)
|
||||
2. Creates the session directory structure in pod
|
||||
3. Streams the tar.gz into the pod via kubectl exec
|
||||
4. Starts the NextJS dev server
|
||||
Uses the file-sync sidecar container (which has s5cmd + S3 credentials
|
||||
via IRSA) to stream the snapshot directly from S3 into the session
|
||||
directory. This avoids downloading to the backend server and the
|
||||
base64 encoding overhead of piping through kubectl exec.
|
||||
|
||||
Steps:
|
||||
1. Exec s5cmd cat in file-sync container to stream snapshot from S3
|
||||
2. Pipe directly to tar for extraction in the shared workspace volume
|
||||
3. Regenerate configuration files (AGENTS.md, opencode.json, files symlink)
|
||||
4. Start the NextJS dev server
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
@@ -1472,87 +1539,56 @@ echo "Session cleanup complete"
|
||||
snapshot_storage_path: Path to the snapshot in S3 (relative path)
|
||||
tenant_id: Tenant identifier for storage access
|
||||
nextjs_port: Port number for the NextJS dev server
|
||||
llm_config: LLM provider configuration for opencode.json
|
||||
use_demo_data: If True, symlink files/ to demo data; else to user files
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot restoration fails
|
||||
FileNotFoundError: If snapshot does not exist
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
pod_name = self._get_pod_name(str(sandbox_id))
|
||||
session_path = f"/workspace/sessions/{session_id}"
|
||||
safe_session_path = shlex.quote(session_path)
|
||||
|
||||
# Build full S3 path
|
||||
s3_key = snapshot_storage_path
|
||||
s3_path = f"s3://{self._s3_bucket}/{snapshot_storage_path}"
|
||||
|
||||
logger.info(f"Restoring snapshot for session {session_id} from {s3_key}")
|
||||
|
||||
# Download snapshot from S3 - uses IAM roles (IRSA)
|
||||
s3_client = build_s3_client()
|
||||
tmp_path: str | None = None
|
||||
# Stream snapshot directly from S3 via s5cmd in file-sync container.
|
||||
# Mirrors the upload pattern: upload uses `tar | s5cmd pipe`,
|
||||
# restore uses `s5cmd cat | tar`. Both run in file-sync container
|
||||
# which has s5cmd and S3 credentials (IRSA). The shared workspace
|
||||
# volume makes extracted files immediately visible to the sandbox
|
||||
# container.
|
||||
restore_script = f"""
|
||||
set -eo pipefail
|
||||
mkdir -p {safe_session_path}
|
||||
/s5cmd cat {s3_path} | tar -xzf - -C {safe_session_path}
|
||||
echo "SNAPSHOT_RESTORED"
|
||||
"""
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".tar.gz", delete=False
|
||||
) as tmp_file:
|
||||
tmp_path = tmp_file.name
|
||||
|
||||
try:
|
||||
s3_client.download_file(self._s3_bucket, s3_key, tmp_path)
|
||||
except s3_client.exceptions.NoSuchKey:
|
||||
raise FileNotFoundError(
|
||||
f"Snapshot not found: s3://{self._s3_bucket}/{s3_key}"
|
||||
)
|
||||
|
||||
# Create session directory structure in pod
|
||||
# Use shlex.quote to prevent shell injection
|
||||
safe_session_path = shlex.quote(session_path)
|
||||
setup_script = f"""
|
||||
set -e
|
||||
mkdir -p {safe_session_path}/outputs
|
||||
"""
|
||||
k8s_stream(
|
||||
self._stream_core_api.connect_get_namespaced_pod_exec,
|
||||
name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
command=["/bin/sh", "-c", setup_script],
|
||||
stderr=True,
|
||||
stdin=False,
|
||||
stdout=True,
|
||||
tty=False,
|
||||
)
|
||||
|
||||
# Stream tar.gz into pod and extract
|
||||
# We use kubectl exec with stdin to pipe the tar file
|
||||
with open(tmp_path, "rb") as tar_file:
|
||||
tar_data = tar_file.read()
|
||||
|
||||
# Use base64 encoding to safely transfer binary data
|
||||
import base64
|
||||
|
||||
tar_b64 = base64.b64encode(tar_data).decode("ascii")
|
||||
|
||||
# Extract in the session directory (tar was created with outputs/ as root)
|
||||
extract_script = f"""
|
||||
set -e
|
||||
cd {safe_session_path}
|
||||
echo '{tar_b64}' | base64 -d | tar -xzf -
|
||||
"""
|
||||
resp = k8s_stream(
|
||||
self._stream_core_api.connect_get_namespaced_pod_exec,
|
||||
name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
command=["/bin/sh", "-c", extract_script],
|
||||
container="file-sync",
|
||||
command=["/bin/sh", "-c", restore_script],
|
||||
stderr=True,
|
||||
stdin=False,
|
||||
stdout=True,
|
||||
tty=False,
|
||||
)
|
||||
|
||||
logger.debug(f"Snapshot restore output: {resp}")
|
||||
logger.info(f"Restored snapshot for session {session_id}")
|
||||
if "SNAPSHOT_RESTORED" not in resp:
|
||||
raise RuntimeError(f"Snapshot restore may have failed. Output: {resp}")
|
||||
|
||||
# Regenerate configuration files that aren't in the snapshot
|
||||
# These are regenerated to ensure they match the current system state
|
||||
self._regenerate_session_config(
|
||||
pod_name=pod_name,
|
||||
session_path=safe_session_path,
|
||||
llm_config=llm_config,
|
||||
nextjs_port=nextjs_port,
|
||||
use_demo_data=use_demo_data,
|
||||
)
|
||||
|
||||
# Start NextJS dev server (check node_modules since restoring from snapshot)
|
||||
start_script = _build_nextjs_start_script(
|
||||
@@ -1569,23 +1605,95 @@ echo '{tar_b64}' | base64 -d | tar -xzf -
|
||||
stdout=True,
|
||||
tty=False,
|
||||
)
|
||||
logger.info(
|
||||
f"Started NextJS server for session {session_id} on port {nextjs_port}"
|
||||
)
|
||||
|
||||
except ApiException as e:
|
||||
raise RuntimeError(f"Failed to restore snapshot: {e}") from e
|
||||
finally:
|
||||
# Cleanup temp file
|
||||
if tmp_path:
|
||||
try:
|
||||
import os
|
||||
|
||||
os.unlink(tmp_path)
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(
|
||||
f"Failed to cleanup temp file {tmp_path}: {cleanup_error}"
|
||||
)
|
||||
def _regenerate_session_config(
|
||||
self,
|
||||
pod_name: str,
|
||||
session_path: str,
|
||||
llm_config: LLMProviderConfig,
|
||||
nextjs_port: int,
|
||||
use_demo_data: bool,
|
||||
) -> None:
|
||||
"""Regenerate session configuration files after snapshot restore.
|
||||
|
||||
Creates:
|
||||
- AGENTS.md (agent instructions)
|
||||
- opencode.json (LLM configuration)
|
||||
- files symlink (to demo data or user files)
|
||||
|
||||
Args:
|
||||
pod_name: The pod name to exec into
|
||||
session_path: Path to the session directory (already shlex.quoted)
|
||||
llm_config: LLM provider configuration
|
||||
nextjs_port: Port for NextJS (used in AGENTS.md)
|
||||
use_demo_data: Whether to use demo data or user files
|
||||
"""
|
||||
# Generate AGENTS.md content
|
||||
agent_instructions = self._load_agent_instructions(
|
||||
files_path=None, # Container script handles this at runtime
|
||||
provider=llm_config.provider,
|
||||
model_name=llm_config.model_name,
|
||||
nextjs_port=nextjs_port,
|
||||
disabled_tools=OPENCODE_DISABLED_TOOLS,
|
||||
user_name=None, # Not stored, regenerate without personalization
|
||||
user_role=None,
|
||||
use_demo_data=use_demo_data,
|
||||
include_org_info=False, # Don't include org_info for restored sessions
|
||||
)
|
||||
|
||||
# Generate opencode.json
|
||||
opencode_config = build_opencode_config(
|
||||
provider=llm_config.provider,
|
||||
model_name=llm_config.model_name,
|
||||
api_key=llm_config.api_key if llm_config.api_key else None,
|
||||
api_base=llm_config.api_base,
|
||||
disabled_tools=OPENCODE_DISABLED_TOOLS,
|
||||
)
|
||||
opencode_json = json.dumps(opencode_config)
|
||||
|
||||
# Escape for shell (single quotes)
|
||||
opencode_json_escaped = opencode_json.replace("'", "'\\''")
|
||||
agent_instructions_escaped = agent_instructions.replace("'", "'\\''")
|
||||
|
||||
# Build files symlink setup
|
||||
if use_demo_data:
|
||||
symlink_target = "/workspace/demo_data"
|
||||
else:
|
||||
symlink_target = "/workspace/files"
|
||||
|
||||
config_script = f"""
|
||||
set -e
|
||||
|
||||
# Create files symlink
|
||||
echo "Creating files symlink to {symlink_target}"
|
||||
ln -sf {symlink_target} {session_path}/files
|
||||
|
||||
# Write agent instructions
|
||||
echo "Writing AGENTS.md"
|
||||
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
|
||||
|
||||
# Write opencode config
|
||||
echo "Writing opencode.json"
|
||||
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
|
||||
|
||||
echo "Session config regeneration complete"
|
||||
"""
|
||||
|
||||
logger.info("Regenerating session configuration files")
|
||||
k8s_stream(
|
||||
self._stream_core_api.connect_get_namespaced_pod_exec,
|
||||
name=pod_name,
|
||||
namespace=self._namespace,
|
||||
container="sandbox",
|
||||
command=["/bin/sh", "-c", config_script],
|
||||
stderr=True,
|
||||
stdin=False,
|
||||
stdout=True,
|
||||
tty=False,
|
||||
)
|
||||
logger.info("Session configuration files regenerated")
|
||||
|
||||
def health_check(self, sandbox_id: UUID, timeout: float = 60.0) -> bool:
|
||||
"""Check if the sandbox pod is healthy (can exec into it).
|
||||
|
||||
@@ -608,34 +608,14 @@ class LocalSandboxManager(SandboxManager):
|
||||
session_id: UUID,
|
||||
tenant_id: str,
|
||||
) -> SnapshotResult | None:
|
||||
"""Create a snapshot of a session's outputs directory.
|
||||
"""Not implemented for local backend - workspaces persist on disk.
|
||||
|
||||
Returns None if snapshots are disabled (local backend).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to snapshot
|
||||
tenant_id: Tenant identifier for storage path
|
||||
|
||||
Returns:
|
||||
SnapshotResult with storage path and size, or None if
|
||||
snapshots are disabled for this backend
|
||||
Local sandboxes don't use snapshots since the filesystem persists.
|
||||
This should never be called for local backend.
|
||||
"""
|
||||
session_path = self._get_session_path(sandbox_id, session_id)
|
||||
# SnapshotManager expects string session_id for storage path
|
||||
_, storage_path, size_bytes = self._snapshot_manager.create_snapshot(
|
||||
session_path,
|
||||
str(session_id),
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created snapshot for session {session_id}, size: {size_bytes} bytes"
|
||||
)
|
||||
|
||||
return SnapshotResult(
|
||||
storage_path=storage_path,
|
||||
size_bytes=size_bytes,
|
||||
raise NotImplementedError(
|
||||
"create_snapshot is not supported for local backend. "
|
||||
"Local sandboxes persist on disk and don't use snapshots."
|
||||
)
|
||||
|
||||
def session_workspace_exists(
|
||||
@@ -663,52 +643,23 @@ class LocalSandboxManager(SandboxManager):
|
||||
snapshot_storage_path: str,
|
||||
tenant_id: str, # noqa: ARG002
|
||||
nextjs_port: int,
|
||||
llm_config: LLMProviderConfig,
|
||||
use_demo_data: bool = False,
|
||||
) -> None:
|
||||
"""Restore a snapshot into a session's workspace directory and start NextJS.
|
||||
"""Not implemented for local backend - workspaces persist on disk.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to restore
|
||||
snapshot_storage_path: Path to the snapshot in storage
|
||||
tenant_id: Tenant identifier for storage access
|
||||
nextjs_port: Port number for the NextJS dev server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot restoration fails
|
||||
FileNotFoundError: If snapshot does not exist
|
||||
Local sandboxes don't use snapshots since the filesystem persists.
|
||||
This should never be called for local backend.
|
||||
"""
|
||||
session_path = self._get_session_path(sandbox_id, session_id)
|
||||
|
||||
# Ensure session directory exists
|
||||
session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use SnapshotManager to restore
|
||||
self._snapshot_manager.restore_snapshot(
|
||||
storage_path=snapshot_storage_path,
|
||||
target_path=session_path,
|
||||
raise NotImplementedError(
|
||||
"restore_snapshot is not supported for local backend. "
|
||||
"Local sandboxes persist on disk and don't use snapshots."
|
||||
)
|
||||
|
||||
logger.info(f"Restored snapshot for session {session_id}")
|
||||
|
||||
# Start NextJS dev server
|
||||
web_dir = session_path / "outputs" / "web"
|
||||
if web_dir.exists():
|
||||
logger.info(f"Starting Next.js server at {web_dir} on port {nextjs_port}")
|
||||
nextjs_process = self._process_manager.start_nextjs_server(
|
||||
web_dir, nextjs_port
|
||||
)
|
||||
# Store process for clean shutdown on session delete
|
||||
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
|
||||
logger.info(
|
||||
f"Started NextJS server for session {session_id} on port {nextjs_port}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Web directory not found at {web_dir}, skipping NextJS startup"
|
||||
)
|
||||
|
||||
def health_check(
|
||||
self, sandbox_id: UUID, timeout: float = 60.0 # noqa: ARG002
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
timeout: float = 60.0, # noqa: ARG002
|
||||
) -> bool:
|
||||
"""Check if the sandbox is healthy (folder exists).
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
|
||||
from onyx.server.features.build.configs import SANDBOX_IDLE_TIMEOUT_SECONDS
|
||||
from onyx.server.features.build.configs import SandboxBackend
|
||||
from onyx.server.features.build.db.build_session import clear_nextjs_ports_for_user
|
||||
from onyx.server.features.build.db.build_session import (
|
||||
mark_user_sessions_idle__no_commit,
|
||||
)
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.sandbox.base import get_sandbox_manager
|
||||
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
|
||||
@@ -75,12 +78,11 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.server.features.build.db.sandbox import create_snapshot
|
||||
from onyx.server.features.build.db.sandbox import create_snapshot__no_commit
|
||||
from onyx.server.features.build.db.sandbox import get_idle_sandboxes
|
||||
from onyx.server.features.build.db.sandbox import (
|
||||
update_sandbox_status__no_commit,
|
||||
)
|
||||
from onyx.server.features.build.sandbox import get_sandbox_manager
|
||||
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
|
||||
@@ -128,7 +130,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
)
|
||||
if snapshot_result:
|
||||
# Create DB record for the snapshot
|
||||
create_snapshot(
|
||||
create_snapshot__no_commit(
|
||||
db_session,
|
||||
session_id,
|
||||
snapshot_result.storage_path,
|
||||
@@ -154,7 +156,15 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
|
||||
f"{sandbox.user_id}"
|
||||
)
|
||||
|
||||
# Mark sandbox as SLEEPING (not TERMINATED)
|
||||
# Mark all active sessions as IDLE
|
||||
idled = mark_user_sessions_idle__no_commit(
|
||||
db_session, sandbox.user_id
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Marked {idled} sessions as IDLE for user "
|
||||
f"{sandbox.user_id}"
|
||||
)
|
||||
|
||||
update_sandbox_status__no_commit(
|
||||
db_session, sandbox_id, SandboxStatus.SLEEPING
|
||||
)
|
||||
@@ -272,7 +282,7 @@ def sync_sandbox_files(
|
||||
task_logger.debug(f"No sandbox found for user {user_id}, skipping sync")
|
||||
return False
|
||||
|
||||
if sandbox.status not in [SandboxStatus.RUNNING, SandboxStatus.IDLE]:
|
||||
if sandbox.status != SandboxStatus.RUNNING:
|
||||
task_logger.debug(
|
||||
f"Sandbox {sandbox.id} not running (status={sandbox.status}), "
|
||||
f"skipping sync"
|
||||
|
||||
@@ -1675,7 +1675,8 @@ class SessionManager:
|
||||
user_id: The user ID to verify ownership
|
||||
|
||||
Returns:
|
||||
Dict with has_webapp, webapp_url, and status, or None if session not found
|
||||
Dict with has_webapp, webapp_url, status, and ready,
|
||||
or None if session not found
|
||||
"""
|
||||
# Verify session ownership
|
||||
session = get_build_session(session_id, user_id, self._db_session)
|
||||
@@ -1684,20 +1685,51 @@ class SessionManager:
|
||||
|
||||
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
|
||||
if sandbox is None:
|
||||
return {"has_webapp": False, "webapp_url": None, "status": "no_sandbox"}
|
||||
return {
|
||||
"has_webapp": False,
|
||||
"webapp_url": None,
|
||||
"status": "no_sandbox",
|
||||
"ready": False,
|
||||
}
|
||||
|
||||
# Return the proxy URL - the proxy handles routing to the correct sandbox
|
||||
# for both local and Kubernetes environments
|
||||
webapp_url = None
|
||||
ready = False
|
||||
if session.nextjs_port:
|
||||
webapp_url = f"{WEB_DOMAIN}/api/build/sessions/{session_id}/webapp"
|
||||
|
||||
# Quick health check: can the API server reach the NextJS dev server?
|
||||
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
|
||||
|
||||
return {
|
||||
"has_webapp": session.nextjs_port is not None,
|
||||
"webapp_url": webapp_url,
|
||||
"status": sandbox.status.value,
|
||||
"ready": ready,
|
||||
}
|
||||
|
||||
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:
|
||||
"""Check if the NextJS dev server is responding.
|
||||
|
||||
Does a quick HTTP GET to the sandbox's internal URL with a short timeout.
|
||||
Returns True if the server responds with any status code, False on timeout
|
||||
or connection error.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from onyx.server.features.build.sandbox.base import get_sandbox_manager
|
||||
|
||||
try:
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
internal_url = sandbox_manager.get_webapp_url(sandbox_id, port)
|
||||
with httpx.Client(timeout=2.0) as client:
|
||||
resp = client.get(internal_url)
|
||||
# Any response (even 500) means the server is up
|
||||
return resp.status_code < 500
|
||||
except (httpx.TimeoutException, httpx.ConnectError, Exception):
|
||||
return False
|
||||
|
||||
def download_webapp_zip(
|
||||
self,
|
||||
session_id: UUID,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.document import get_accessible_documents_for_hierarchy_node_paginated
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
@@ -22,10 +25,25 @@ from onyx.server.features.hierarchy.models import HierarchyNodeDocumentsResponse
|
||||
from onyx.server.features.hierarchy.models import HierarchyNodesResponse
|
||||
from onyx.server.features.hierarchy.models import HierarchyNodeSummary
|
||||
|
||||
OPENSEARCH_NOT_ENABLED_MESSAGE = (
|
||||
"Per-source knowledge selection is coming soon in v3.0! "
|
||||
"OpenSearch indexing must be enabled to use this feature."
|
||||
)
|
||||
|
||||
router = APIRouter(prefix=HIERARCHY_NODES_PREFIX)
|
||||
|
||||
|
||||
def _require_opensearch() -> None:
|
||||
if (
|
||||
not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
or not ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=OPENSEARCH_NOT_ENABLED_MESSAGE,
|
||||
)
|
||||
|
||||
|
||||
def _get_user_access_info(
|
||||
user: User | None, db_session: Session
|
||||
) -> tuple[str | None, list[str]]:
|
||||
@@ -40,6 +58,7 @@ def list_accessible_hierarchy_nodes(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodesResponse:
|
||||
_require_opensearch()
|
||||
user_email, external_group_ids = _get_user_access_info(user, db_session)
|
||||
nodes = get_accessible_hierarchy_nodes_for_source(
|
||||
db_session=db_session,
|
||||
@@ -66,6 +85,7 @@ def list_accessible_hierarchy_node_documents(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> HierarchyNodeDocumentsResponse:
|
||||
_require_opensearch()
|
||||
user_email, external_group_ids = _get_user_access_info(user, db_session)
|
||||
cursor = documents_request.cursor
|
||||
sort_field = documents_request.sort_field
|
||||
|
||||
@@ -255,7 +255,7 @@ def list_llm_providers(
|
||||
llm_provider_list: list[LLMProviderView] = []
|
||||
for llm_provider_model in fetch_existing_llm_providers(
|
||||
db_session=db_session,
|
||||
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
|
||||
flow_type_filter=[],
|
||||
exclude_image_generation_providers=not include_image_gen,
|
||||
):
|
||||
from_model_start = datetime.now(timezone.utc)
|
||||
@@ -503,9 +503,7 @@ def list_llm_provider_basics(
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.debug("Starting to fetch user-accessible LLM providers")
|
||||
|
||||
all_providers = fetch_existing_llm_providers(
|
||||
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
|
||||
)
|
||||
all_providers = fetch_existing_llm_providers(db_session, [])
|
||||
user_group_ids = fetch_user_group_ids(db_session, user)
|
||||
is_admin = user.role == UserRole.ADMIN
|
||||
|
||||
@@ -514,9 +512,9 @@ def list_llm_provider_basics(
|
||||
for provider in all_providers:
|
||||
# Use centralized access control logic with persona=None since we're
|
||||
# listing providers without a specific persona context. This correctly:
|
||||
# - Includes all public providers
|
||||
# - Includes public providers WITHOUT persona restrictions
|
||||
# - Includes providers user can access via group membership
|
||||
# - Excludes persona-only restricted providers (requires specific persona)
|
||||
# - Excludes providers with persona restrictions (requires specific persona)
|
||||
# - Excludes non-public providers with no restrictions (admin-only)
|
||||
if can_user_access_llm_provider(
|
||||
provider, user_group_ids, persona=None, is_admin=is_admin
|
||||
@@ -541,7 +539,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
|
||||
available to the user when using this persona, respecting all RBAC restrictions.
|
||||
Public providers are always included.
|
||||
Public providers are included unless they have persona restrictions that exclude this persona.
|
||||
"""
|
||||
persona = fetch_persona_with_groups(db_session, persona_id)
|
||||
if not persona:
|
||||
@@ -555,7 +553,7 @@ def get_valid_model_names_for_persona(
|
||||
|
||||
valid_models = []
|
||||
for llm_provider_model in all_providers:
|
||||
# Public providers always included, restricted checked via RBAC
|
||||
# Check access with persona context — respects all RBAC restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
@@ -576,7 +574,7 @@ def list_llm_providers_for_persona(
|
||||
"""Get LLM providers for a specific persona.
|
||||
|
||||
Returns providers that the user can access when using this persona:
|
||||
- All public providers (is_public=True) - ALWAYS included
|
||||
- Public providers (respecting persona restrictions if set)
|
||||
- Restricted providers user can access via group/persona restrictions
|
||||
|
||||
This endpoint is used for background fetching of restricted providers
|
||||
@@ -605,7 +603,7 @@ def list_llm_providers_for_persona(
|
||||
llm_provider_list: list[LLMProviderDescriptor] = []
|
||||
|
||||
for llm_provider_model in all_providers:
|
||||
# Use simplified access check - public providers always included
|
||||
# Check access with persona context — respects persona restrictions
|
||||
if can_user_access_llm_provider(
|
||||
llm_provider_model, user_group_ids, persona, is_admin=is_admin
|
||||
):
|
||||
|
||||
@@ -30,12 +30,14 @@ from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import enforce_seat_limit
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
@@ -90,6 +92,7 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.models import MinimalUserSnapshot
|
||||
from onyx.server.usage_limits import is_tenant_on_trial_fn
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -391,14 +394,20 @@ def bulk_invite_users(
|
||||
if e not in existing_users and e not in already_invited
|
||||
]
|
||||
|
||||
# Limit bulk invites for trial tenants to prevent email spam
|
||||
# Only count new invites, not re-invites of existing users
|
||||
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
|
||||
current_invited = len(already_invited)
|
||||
if current_invited + len(emails_needing_seats) > NUM_FREE_TRIAL_USER_INVITES:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="You have hit your invite limit. "
|
||||
"Please upgrade for unlimited invites.",
|
||||
)
|
||||
|
||||
# Check seat availability for new users
|
||||
# Only for self-hosted (non-multi-tenant) deployments
|
||||
if not MULTI_TENANT and emails_needing_seats:
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=len(emails_needing_seats))
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
if emails_needing_seats:
|
||||
enforce_seat_limit(db_session, seats_needed=len(emails_needing_seats))
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
@@ -414,10 +423,10 @@ def bulk_invite_users(
|
||||
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
|
||||
number_of_invited_users = write_invited_users(all_emails)
|
||||
|
||||
# send out email invitations if enabled
|
||||
# send out email invitations only to new users (not already invited or existing)
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in new_invited_emails:
|
||||
for email in emails_needing_seats:
|
||||
send_user_email_invite(email, current_user, AUTH_TYPE)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending email invite to invited users: {e}")
|
||||
@@ -564,12 +573,7 @@ def activate_user_api(
|
||||
|
||||
# Check seat availability before activating
|
||||
# Only for self-hosted (non-multi-tenant) deployments
|
||||
if not MULTI_TENANT:
|
||||
result = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license", "check_seat_availability", None
|
||||
)(db_session, seats_needed=1)
|
||||
if result is not None and not result.available:
|
||||
raise HTTPException(status_code=402, detail=result.error_message)
|
||||
enforce_seat_limit(db_session)
|
||||
|
||||
activate_user(user_to_activate, db_session)
|
||||
|
||||
@@ -593,11 +597,17 @@ def get_valid_domains(
|
||||
|
||||
@router.get("/users", tags=PUBLIC_API_TAGS)
|
||||
def list_all_users_basic_info(
|
||||
include_api_keys: bool = False,
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserSnapshot]:
|
||||
users = get_all_users(db_session)
|
||||
return [MinimalUserSnapshot(id=user.id, email=user.email) for user in users]
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.role != UserRole.SLACK_USER
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
|
||||
@router.get("/get-user-role", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@@ -87,6 +87,8 @@ class SendMessageRequest(BaseModel):
|
||||
message: str
|
||||
|
||||
llm_override: LLMOverride | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
|
||||
allowed_tool_ids: list[int] | None = None
|
||||
forced_tool_id: int | None = None
|
||||
@@ -191,6 +193,8 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# allows the caller to override the Persona / Prompt
|
||||
# these do not persist in the chat thread details
|
||||
llm_override: LLMOverride | None = None
|
||||
# Test-only override for deterministic LiteLLM mock responses.
|
||||
mock_llm_response: str | None = None
|
||||
prompt_override: PromptOverride | None = None
|
||||
|
||||
# Allows the caller to override the temperature for the chat session
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Union
|
||||
|
||||
@@ -37,6 +38,7 @@ class StreamingType(Enum):
|
||||
REASONING_DELTA = "reasoning_delta"
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
|
||||
DEEP_RESEARCH_PLAN_START = "deep_research_plan_start"
|
||||
DEEP_RESEARCH_PLAN_DELTA = "deep_research_plan_delta"
|
||||
@@ -127,6 +129,14 @@ class CitationInfo(BaseObj):
|
||||
document_id: str
|
||||
|
||||
|
||||
class ToolCallDebug(BaseObj):
|
||||
type: Literal["tool_call_debug"] = StreamingType.TOOL_CALL_DEBUG.value
|
||||
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
# Tool Packets
|
||||
################################################
|
||||
@@ -318,6 +328,7 @@ PacketObj = Union[
|
||||
ReasoningDone,
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -57,9 +57,11 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
deep_research_enabled: bool | None = None
|
||||
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
|
||||
# Whether EE features are unlocked for use.
|
||||
# Depends on license status: True when the user has a valid license
|
||||
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
|
||||
# or the license is expired (GATED_ACCESS).
|
||||
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
|
||||
ee_features_enabled: bool = False
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from onyx.file_processing.html_utils import ParsedHTML
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -21,10 +22,22 @@ from onyx.utils.web_content import title_from_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_READ_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
|
||||
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
|
||||
DEFAULT_MAX_WORKERS = 5
|
||||
|
||||
|
||||
def _failed_result(url: str) -> WebContent:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
|
||||
class OnyxWebCrawler(WebContentProvider):
|
||||
@@ -37,12 +50,14 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
max_pdf_size_bytes: int | None = None,
|
||||
max_html_size_bytes: int | None = None,
|
||||
) -> None:
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._read_timeout_seconds = timeout_seconds
|
||||
self._connect_timeout_seconds = connect_timeout_seconds
|
||||
self._max_pdf_size_bytes = max_pdf_size_bytes
|
||||
self._max_html_size_bytes = max_html_size_bytes
|
||||
self._headers = {
|
||||
@@ -51,75 +66,68 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
}
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
results: list[WebContent] = []
|
||||
for url in urls:
|
||||
results.append(self._fetch_url(url))
|
||||
return results
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
return list(executor.map(self._fetch_url_safe, urls))
|
||||
|
||||
def _fetch_url_safe(self, url: str) -> WebContent:
|
||||
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
|
||||
try:
|
||||
return self._fetch_url(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler unexpected error for %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
def _fetch_url(self, url: str) -> WebContent:
|
||||
try:
|
||||
# Use SSRF-safe request to prevent DNS rebinding attacks
|
||||
response = ssrf_safe_get(
|
||||
url, headers=self._headers, timeout=self._timeout_seconds
|
||||
url,
|
||||
headers=self._headers,
|
||||
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
|
||||
)
|
||||
except SSRFException as exc:
|
||||
logger.error(
|
||||
"SSRF protection blocked request to %s: %s",
|
||||
"SSRF protection blocked request to %s (%s)",
|
||||
url,
|
||||
str(exc),
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - network failures vary
|
||||
return _failed_result(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler failed to fetch %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
content_sniff = response.content[:1024] if response.content else None
|
||||
content = response.content
|
||||
|
||||
content_sniff = content[:1024] if content else None
|
||||
if is_pdf_resource(url, content_type, content_sniff):
|
||||
if (
|
||||
self._max_pdf_size_bytes is not None
|
||||
and len(response.content) > self._max_pdf_size_bytes
|
||||
and len(content) > self._max_pdf_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"PDF content too large (%d bytes) for %s, max is %d",
|
||||
len(response.content),
|
||||
len(content),
|
||||
url,
|
||||
self._max_pdf_size_bytes,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
text_content, metadata = extract_pdf_text(response.content)
|
||||
return _failed_result(url)
|
||||
text_content, metadata = extract_pdf_text(content)
|
||||
title = title_from_pdf_metadata(metadata) or title_from_url(url)
|
||||
return WebContent(
|
||||
title=title,
|
||||
@@ -131,25 +139,19 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
|
||||
if (
|
||||
self._max_html_size_bytes is not None
|
||||
and len(response.content) > self._max_html_size_bytes
|
||||
and len(content) > self._max_html_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"HTML content too large (%d bytes) for %s, max is %d",
|
||||
len(response.content),
|
||||
len(content),
|
||||
url,
|
||||
self._max_html_size_bytes,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
try:
|
||||
decoded_html = decode_html_bytes(
|
||||
response.content,
|
||||
content,
|
||||
content_type=content_type,
|
||||
fallback_encoding=response.apparent_encoding or response.encoding,
|
||||
)
|
||||
|
||||
@@ -47,6 +47,7 @@ from onyx.tools.tool_implementations.web_search.utils import (
|
||||
from onyx.tools.tool_implementations.web_search.utils import MAX_CHARS_PER_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.url import normalize_url as normalize_web_content_url
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
@@ -791,7 +792,9 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
|
||||
for url in all_urls:
|
||||
doc_id = url_to_doc_id.get(url)
|
||||
indexed_section = indexed_by_doc_id.get(doc_id) if doc_id else None
|
||||
crawled_section = crawled_by_url.get(url)
|
||||
# WebContent.link is normalized (query/fragment stripped). Match on the
|
||||
# same normalized form to avoid dropping successful crawl results.
|
||||
crawled_section = crawled_by_url.get(normalize_web_content_url(url))
|
||||
|
||||
if indexed_section and indexed_section.combined_content:
|
||||
# Prefer indexed
|
||||
|
||||
@@ -0,0 +1,260 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.tools.tool_implementations.web_search.models import (
|
||||
WebSearchProvider,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
BRAVE_WEB_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
|
||||
BRAVE_MAX_RESULTS_PER_REQUEST = 20
|
||||
BRAVE_SAFESEARCH_OPTIONS = {"off", "moderate", "strict"}
|
||||
BRAVE_FRESHNESS_OPTIONS = {"pd", "pw", "pm", "py"}
|
||||
|
||||
|
||||
class RetryableBraveSearchError(Exception):
|
||||
"""Error type used to trigger retry for transient Brave search failures."""
|
||||
|
||||
|
||||
class BraveClient(WebSearchProvider):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
*,
|
||||
num_results: int = 10,
|
||||
timeout_seconds: int = 10,
|
||||
country: str | None = None,
|
||||
search_lang: str | None = None,
|
||||
ui_lang: str | None = None,
|
||||
safesearch: str | None = None,
|
||||
freshness: str | None = None,
|
||||
) -> None:
|
||||
if timeout_seconds <= 0:
|
||||
raise ValueError("Brave provider config 'timeout_seconds' must be > 0.")
|
||||
|
||||
self._headers = {
|
||||
"Accept": "application/json",
|
||||
"X-Subscription-Token": api_key,
|
||||
}
|
||||
logger.debug(f"Count of results passed to BraveClient: {num_results}")
|
||||
self._num_results = max(1, min(num_results, BRAVE_MAX_RESULTS_PER_REQUEST))
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._country = _normalize_country(country)
|
||||
self._search_lang = _normalize_language_code(
|
||||
search_lang, field_name="search_lang"
|
||||
)
|
||||
self._ui_lang = _normalize_language_code(ui_lang, field_name="ui_lang")
|
||||
self._safesearch = _normalize_option(
|
||||
safesearch,
|
||||
field_name="safesearch",
|
||||
allowed_values=BRAVE_SAFESEARCH_OPTIONS,
|
||||
)
|
||||
self._freshness = _normalize_option(
|
||||
freshness,
|
||||
field_name="freshness",
|
||||
allowed_values=BRAVE_FRESHNESS_OPTIONS,
|
||||
)
|
||||
|
||||
def _build_search_params(self, query: str) -> dict[str, str]:
|
||||
params = {
|
||||
"q": query,
|
||||
"count": str(self._num_results),
|
||||
}
|
||||
if self._country:
|
||||
params["country"] = self._country
|
||||
if self._search_lang:
|
||||
params["search_lang"] = self._search_lang
|
||||
if self._ui_lang:
|
||||
params["ui_lang"] = self._ui_lang
|
||||
if self._safesearch:
|
||||
params["safesearch"] = self._safesearch
|
||||
if self._freshness:
|
||||
params["freshness"] = self._freshness
|
||||
return params
|
||||
|
||||
@retry_builder(
|
||||
tries=3,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
exceptions=(RetryableBraveSearchError,),
|
||||
)
|
||||
def _search_with_retries(self, query: str) -> list[WebSearchResult]:
|
||||
params = self._build_search_params(query)
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
BRAVE_WEB_SEARCH_URL,
|
||||
headers=self._headers,
|
||||
params=params,
|
||||
timeout=self._timeout_seconds,
|
||||
)
|
||||
except requests.RequestException as exc:
|
||||
raise RetryableBraveSearchError(
|
||||
f"Brave search request failed: {exc}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as exc:
|
||||
error_msg = _build_error_message(response)
|
||||
if _is_retryable_status(response.status_code):
|
||||
raise RetryableBraveSearchError(error_msg) from exc
|
||||
raise ValueError(error_msg) from exc
|
||||
|
||||
data = response.json()
|
||||
web_results = (data.get("web") or {}).get("results") or []
|
||||
|
||||
results: list[WebSearchResult] = []
|
||||
for result in web_results:
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
|
||||
link = _clean_string(result.get("url"))
|
||||
if not link:
|
||||
continue
|
||||
|
||||
title = _clean_string(result.get("title"))
|
||||
description = _clean_string(result.get("description"))
|
||||
|
||||
results.append(
|
||||
WebSearchResult(
|
||||
title=title,
|
||||
link=link,
|
||||
snippet=description,
|
||||
author=None,
|
||||
published_date=None,
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def search(self, query: str) -> list[WebSearchResult]:
|
||||
try:
|
||||
return self._search_with_retries(query)
|
||||
except RetryableBraveSearchError as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
def test_connection(self) -> dict[str, str]:
|
||||
try:
|
||||
test_results = self.search("test")
|
||||
if not test_results or not any(result.link for result in test_results):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Brave API key validation failed: search returned no results.",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except (ValueError, requests.RequestException) as e:
|
||||
error_msg = str(e)
|
||||
lower = error_msg.lower()
|
||||
if (
|
||||
"status 401" in lower
|
||||
or "status 403" in lower
|
||||
or "api key" in lower
|
||||
or "auth" in lower
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid Brave API key: {error_msg}",
|
||||
) from e
|
||||
if "status 429" in lower or "rate limit" in lower:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Brave API rate limit exceeded: {error_msg}",
|
||||
) from e
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Brave API key validation failed: {error_msg}",
|
||||
) from e
|
||||
|
||||
logger.info("Web search provider test succeeded for Brave.")
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
def _build_error_message(response: requests.Response) -> str:
|
||||
return (
|
||||
"Brave search failed "
|
||||
f"(status {response.status_code}): {_extract_error_detail(response)}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_error_detail(response: requests.Response) -> str:
|
||||
try:
|
||||
payload: Any = response.json()
|
||||
except Exception:
|
||||
text = response.text.strip()
|
||||
return text[:200] if text else "No error details"
|
||||
|
||||
if isinstance(payload, dict):
|
||||
error = payload.get("error")
|
||||
if isinstance(error, dict):
|
||||
detail = error.get("detail") or error.get("message")
|
||||
if isinstance(detail, str):
|
||||
return detail
|
||||
if isinstance(error, str):
|
||||
return error
|
||||
|
||||
message = payload.get("message")
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
|
||||
return str(payload)[:200]
|
||||
|
||||
|
||||
def _is_retryable_status(status_code: int) -> bool:
|
||||
return status_code == 429 or status_code >= 500
|
||||
|
||||
|
||||
def _clean_string(value: Any) -> str:
|
||||
return value.strip() if isinstance(value, str) else ""
|
||||
|
||||
|
||||
def _normalize_country(country: str | None) -> str | None:
|
||||
if country is None:
|
||||
return None
|
||||
normalized = country.strip().upper()
|
||||
if not normalized:
|
||||
return None
|
||||
if len(normalized) != 2 or not normalized.isalpha():
|
||||
raise ValueError(
|
||||
"Brave provider config 'country' must be a 2-letter ISO country code."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_language_code(value: str | None, *, field_name: str) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if len(normalized) > 20:
|
||||
raise ValueError(f"Brave provider config '{field_name}' is too long.")
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_option(
|
||||
value: str | None,
|
||||
*,
|
||||
field_name: str,
|
||||
allowed_values: set[str],
|
||||
) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
normalized = value.strip().lower()
|
||||
if not normalized:
|
||||
return None
|
||||
if normalized not in allowed_values:
|
||||
allowed = ", ".join(sorted(allowed_values))
|
||||
raise ValueError(
|
||||
f"Brave provider config '{field_name}' must be one of: {allowed}."
|
||||
)
|
||||
return normalized
|
||||
@@ -13,6 +13,9 @@ from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
|
||||
from onyx.tools.tool_implementations.web_search.clients.brave_client import (
|
||||
BraveClient,
|
||||
)
|
||||
from onyx.tools.tool_implementations.web_search.clients.exa_client import (
|
||||
ExaClient,
|
||||
)
|
||||
@@ -35,16 +38,76 @@ from shared_configs.enums import WebSearchProviderType
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _parse_positive_int_config(
|
||||
*,
|
||||
raw_value: str | None,
|
||||
default: int,
|
||||
provider_name: str,
|
||||
config_key: str,
|
||||
) -> int:
|
||||
if not raw_value:
|
||||
return default
|
||||
try:
|
||||
value = int(raw_value)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
f"{provider_name} provider config '{config_key}' must be an integer."
|
||||
) from exc
|
||||
if value <= 0:
|
||||
raise ValueError(
|
||||
f"{provider_name} provider config '{config_key}' must be greater than 0."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def provider_requires_api_key(provider_type: WebSearchProviderType) -> bool:
|
||||
"""Return True if the given provider type requires an API key.
|
||||
This list is most likely just going to contain SEARXNG. The way it works is that it uses public search engines that do not
|
||||
require an API key. You can also set it up in a way which requires a key but SearXNG itself does not require a key.
|
||||
"""
|
||||
return provider_type != WebSearchProviderType.SEARXNG
|
||||
|
||||
|
||||
def build_search_provider_from_config(
|
||||
provider_type: WebSearchProviderType,
|
||||
api_key: str,
|
||||
api_key: str | None,
|
||||
config: dict[str, str] | None, # TODO use a typed object
|
||||
) -> WebSearchProvider:
|
||||
config = config or {}
|
||||
num_results = int(config.get("num_results") or DEFAULT_MAX_RESULTS)
|
||||
|
||||
# SearXNG does not require an API key
|
||||
if provider_type == WebSearchProviderType.SEARXNG:
|
||||
searxng_base_url = config.get("searxng_base_url")
|
||||
if not searxng_base_url:
|
||||
raise ValueError("Please provide a URL for your private SearXNG instance.")
|
||||
return SearXNGClient(
|
||||
searxng_base_url,
|
||||
num_results=num_results,
|
||||
)
|
||||
|
||||
# All other providers require an API key
|
||||
if not api_key:
|
||||
raise ValueError(f"API key is required for {provider_type.value} provider.")
|
||||
|
||||
if provider_type == WebSearchProviderType.EXA:
|
||||
return ExaClient(api_key=api_key, num_results=num_results)
|
||||
if provider_type == WebSearchProviderType.BRAVE:
|
||||
return BraveClient(
|
||||
api_key=api_key,
|
||||
num_results=num_results,
|
||||
timeout_seconds=_parse_positive_int_config(
|
||||
raw_value=config.get("timeout_seconds"),
|
||||
default=10,
|
||||
provider_name="Brave",
|
||||
config_key="timeout_seconds",
|
||||
),
|
||||
country=config.get("country"),
|
||||
search_lang=config.get("search_lang"),
|
||||
ui_lang=config.get("ui_lang"),
|
||||
safesearch=config.get("safesearch"),
|
||||
freshness=config.get("freshness"),
|
||||
)
|
||||
if provider_type == WebSearchProviderType.SERPER:
|
||||
return SerperClient(api_key=api_key, num_results=num_results)
|
||||
if provider_type == WebSearchProviderType.GOOGLE_PSE:
|
||||
@@ -64,20 +127,13 @@ def build_search_provider_from_config(
|
||||
num_results=num_results,
|
||||
timeout_seconds=int(config.get("timeout_seconds") or 10),
|
||||
)
|
||||
if provider_type == WebSearchProviderType.SEARXNG:
|
||||
searxng_base_url = config.get("searxng_base_url")
|
||||
if not searxng_base_url:
|
||||
raise ValueError("Please provide a URL for your private SearXNG instance.")
|
||||
return SearXNGClient(
|
||||
searxng_base_url,
|
||||
num_results=num_results,
|
||||
)
|
||||
raise ValueError(f"Unknown provider type: {provider_type.value}")
|
||||
|
||||
|
||||
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
|
||||
return build_search_provider_from_config(
|
||||
provider_type=WebSearchProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key or "",
|
||||
api_key=provider_model.api_key,
|
||||
config=provider_model.config or {},
|
||||
)
|
||||
|
||||
|
||||
@@ -146,7 +146,7 @@ MAX_REDIRECTS = 10
|
||||
def _make_ssrf_safe_request(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: int = 15,
|
||||
timeout: float | tuple[float, float] = 15,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
@@ -204,7 +204,7 @@ def _make_ssrf_safe_request(
|
||||
def ssrf_safe_get(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: int = 15,
|
||||
timeout: float | tuple[float, float] = 15,
|
||||
follow_redirects: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
|
||||
@@ -36,7 +36,7 @@ global_version = OnyxVersion()
|
||||
# Eventually, ENABLE_PAID_ENTERPRISE_EDITION_FEATURES will be removed
|
||||
# and license enforcement will be the only mechanism for EE features.
|
||||
_LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -265,7 +265,7 @@ fastapi==0.128.0
|
||||
# onyx
|
||||
fastapi-limiter==0.1.6
|
||||
# via onyx
|
||||
fastapi-users==15.0.2
|
||||
fastapi-users==15.0.4
|
||||
# via
|
||||
# fastapi-users-db-sqlalchemy
|
||||
# onyx
|
||||
@@ -362,23 +362,14 @@ greenlet==3.2.4
|
||||
# sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -762,19 +753,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# onnxruntime
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# ddtrace
|
||||
# google-api-core
|
||||
@@ -850,7 +829,7 @@ pygithub==2.5.0
|
||||
# via onyx
|
||||
pygments==2.19.2
|
||||
# via rich
|
||||
pyjwt==2.10.1
|
||||
pyjwt==2.11.0
|
||||
# via
|
||||
# fastapi-users
|
||||
# mcp
|
||||
@@ -919,7 +898,7 @@ python-json-logger==4.0.0
|
||||
# via pydocket
|
||||
python-magic==0.4.27
|
||||
# via unstructured
|
||||
python-multipart==0.0.21
|
||||
python-multipart==0.0.22
|
||||
# via
|
||||
# fastapi-users
|
||||
# mcp
|
||||
|
||||
@@ -123,7 +123,7 @@ execnet==2.1.2
|
||||
# via pytest-xdist
|
||||
executing==2.2.1
|
||||
# via stack-data
|
||||
faker==37.1.0
|
||||
faker==40.1.2
|
||||
# via onyx
|
||||
fastapi==0.128.0
|
||||
# via
|
||||
@@ -195,23 +195,14 @@ greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or
|
||||
# via sqlalchemy
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -326,7 +317,7 @@ oauthlib==3.2.2
|
||||
# via
|
||||
# kubernetes
|
||||
# requests-oauthlib
|
||||
onyx-devtools==0.5.0
|
||||
onyx-devtools==0.6.2
|
||||
# via onyx
|
||||
openai==2.14.0
|
||||
# via
|
||||
@@ -388,16 +379,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -442,7 +424,7 @@ pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# ipython-pygments-lexers
|
||||
pyjwt==2.10.1
|
||||
pyjwt==2.11.0
|
||||
# via mcp
|
||||
pyparsing==3.2.5
|
||||
# via matplotlib
|
||||
@@ -462,7 +444,7 @@ pytest-dotenv==0.5.2
|
||||
# via onyx
|
||||
pytest-repeat==0.9.4
|
||||
# via onyx
|
||||
pytest-xdist==3.6.1
|
||||
pytest-xdist==3.8.0
|
||||
# via onyx
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
@@ -477,7 +459,7 @@ python-dotenv==1.1.1
|
||||
# litellm
|
||||
# pydantic-settings
|
||||
# pytest-dotenv
|
||||
python-multipart==0.0.21
|
||||
python-multipart==0.0.22
|
||||
# via mcp
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
@@ -640,7 +622,7 @@ typing-inspection==0.4.2
|
||||
# mcp
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
tzdata==2025.2
|
||||
tzdata==2025.2 ; sys_platform == 'win32'
|
||||
# via faker
|
||||
urllib3==2.6.3
|
||||
# via
|
||||
|
||||
@@ -152,23 +152,14 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -265,16 +256,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -309,7 +291,7 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pyjwt==2.10.1
|
||||
pyjwt==2.11.0
|
||||
# via mcp
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
@@ -322,7 +304,7 @@ python-dotenv==1.1.1
|
||||
# via
|
||||
# litellm
|
||||
# pydantic-settings
|
||||
python-multipart==0.0.21
|
||||
python-multipart==0.0.22
|
||||
# via mcp
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
|
||||
@@ -177,23 +177,14 @@ googleapis-common-protos==1.72.0
|
||||
# grpcio-status
|
||||
grpc-google-iam-v1==0.14.3
|
||||
# via google-cloud-resource-manager
|
||||
grpcio==1.67.1 ; python_full_version < '3.14'
|
||||
grpcio==1.76.0
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio==1.76.0 ; python_full_version >= '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
grpcio-status==1.67.1 ; python_full_version < '3.14'
|
||||
# via google-api-core
|
||||
grpcio-status==1.76.0 ; python_full_version >= '3.14'
|
||||
grpcio-status==1.76.0
|
||||
# via google-api-core
|
||||
h11==0.16.0
|
||||
# via
|
||||
@@ -351,16 +342,7 @@ proto-plus==1.26.1
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
protobuf==5.29.5 ; python_full_version < '3.14'
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
# google-cloud-resource-manager
|
||||
# googleapis-common-protos
|
||||
# grpc-google-iam-v1
|
||||
# grpcio-status
|
||||
# proto-plus
|
||||
protobuf==6.33.4 ; python_full_version >= '3.14'
|
||||
protobuf==6.33.5
|
||||
# via
|
||||
# google-api-core
|
||||
# google-cloud-aiplatform
|
||||
@@ -397,7 +379,7 @@ pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.12.0
|
||||
# via mcp
|
||||
pyjwt==2.10.1
|
||||
pyjwt==2.11.0
|
||||
# via mcp
|
||||
python-dateutil==2.8.2
|
||||
# via
|
||||
@@ -410,7 +392,7 @@ python-dotenv==1.1.1
|
||||
# via
|
||||
# litellm
|
||||
# pydantic-settings
|
||||
python-multipart==0.0.21
|
||||
python-multipart==0.0.22
|
||||
# via mcp
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
|
||||
@@ -26,6 +26,7 @@ class WebSearchProviderType(str, Enum):
|
||||
SERPER = "serper"
|
||||
EXA = "exa"
|
||||
SEARXNG = "searxng"
|
||||
BRAVE = "brave"
|
||||
|
||||
|
||||
class WebContentProviderType(str, Enum):
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
External dependency unit tests for user file processing queue protections.
|
||||
|
||||
Verifies that the three mechanisms added to check_user_file_processing work
|
||||
correctly:
|
||||
|
||||
1. Queue depth backpressure – when the broker queue exceeds
|
||||
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
|
||||
|
||||
2. Per-file Redis guard key – if the guard key for a file already exists in
|
||||
Redis, that file is skipped even though it is still in PROCESSING status.
|
||||
|
||||
3. Task expiry – every send_task call carries expires=
|
||||
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
|
||||
discarded by workers automatically.
|
||||
|
||||
Also verifies that process_single_user_file clears the guard key the moment
|
||||
it is picked up by a worker.
|
||||
|
||||
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
|
||||
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
|
||||
on the task class so no real broker is needed.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import PropertyMock
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_lock_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_user_file_queued_key,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
check_user_file_processing,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_single_user_file,
|
||||
)
|
||||
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PATCH_QUEUE_LEN = (
|
||||
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
|
||||
)
|
||||
|
||||
|
||||
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
|
||||
"""Insert a UserFile in PROCESSING status and return it."""
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=UserFileStatus.PROCESSING,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
|
||||
"""Patch the ``app`` property on *task*'s class so that ``self.app``
|
||||
inside the task function returns *mock_app*.
|
||||
|
||||
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
|
||||
the actual task instance. We patch ``app`` on that instance's class
|
||||
(a unique Celery-generated Task subclass) so the mock is scoped to this
|
||||
task only.
|
||||
"""
|
||||
task_instance = task.run.__self__
|
||||
with patch.object(
|
||||
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestQueueDepthBackpressure:
|
||||
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
|
||||
|
||||
def test_no_tasks_enqueued_when_queue_over_limit(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""When the queue depth exceeds the limit the beat cycle is skipped."""
|
||||
user = create_test_user(db_session, "bp_user")
|
||||
_create_processing_user_file(db_session, user.id)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(
|
||||
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
|
||||
),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
mock_app.send_task.assert_not_called()
|
||||
|
||||
|
||||
class TestPerFileGuardKey:
|
||||
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
|
||||
|
||||
def test_guarded_file_not_re_enqueued(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""A file whose guard key is already set in Redis is skipped."""
|
||||
user = create_test_user(db_session, "guard_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# send_task must not have been called with this specific file's ID
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
kwargs = call.kwargs.get("kwargs", {})
|
||||
assert kwargs.get("user_file_id") != str(
|
||||
uf.id
|
||||
), f"File {uf.id} should have been skipped because its guard key exists"
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
def test_guard_key_exists_in_redis_after_enqueue(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""After a file is enqueued its guard key is present in Redis with a TTL."""
|
||||
user = create_test_user(db_session, "guard_set_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key) # clean slate
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
assert redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be set in Redis after enqueue"
|
||||
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
|
||||
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
|
||||
f"Guard key TTL {ttl}s is outside the expected range "
|
||||
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestTaskExpiry:
|
||||
"""Protection 3: every send_task call includes an expires value."""
|
||||
|
||||
def test_send_task_called_with_expires(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""send_task is called with the correct queue, task name, and expires."""
|
||||
user = create_test_user(db_session, "expires_user")
|
||||
uf = _create_processing_user_file(db_session, user.id)
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(uf.id)
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
mock_app = MagicMock()
|
||||
|
||||
try:
|
||||
with (
|
||||
_patch_task_app(check_user_file_processing, mock_app),
|
||||
patch(_PATCH_QUEUE_LEN, return_value=0),
|
||||
):
|
||||
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
|
||||
|
||||
# At least one task should have been submitted (for our file)
|
||||
assert (
|
||||
mock_app.send_task.call_count >= 1
|
||||
), "Expected at least one task to be submitted"
|
||||
|
||||
# Every submitted task must carry expires
|
||||
for call in mock_app.send_task.call_args_list:
|
||||
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
|
||||
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
|
||||
assert (
|
||||
call.kwargs.get("expires")
|
||||
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
|
||||
), (
|
||||
"Task must be submitted with the correct expires value to prevent "
|
||||
"stale task accumulation"
|
||||
)
|
||||
finally:
|
||||
redis_client.delete(guard_key)
|
||||
|
||||
|
||||
class TestWorkerClearsGuardKey:
|
||||
"""process_single_user_file removes the guard key when it picks up a task."""
|
||||
|
||||
def test_guard_key_deleted_on_pickup(
|
||||
self,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""The guard key is deleted before the worker does any real work.
|
||||
|
||||
We simulate an already-locked file so process_single_user_file returns
|
||||
early – but crucially, after the guard key deletion.
|
||||
"""
|
||||
user_file_id = str(uuid4())
|
||||
|
||||
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
|
||||
guard_key = _user_file_queued_key(user_file_id)
|
||||
|
||||
# Simulate the guard key set when the beat enqueued the task
|
||||
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
|
||||
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
|
||||
|
||||
# Hold the per-file processing lock so the worker exits early without
|
||||
# touching the database or file store.
|
||||
lock_key = _user_file_lock_key(user_file_id)
|
||||
processing_lock = redis_client.lock(lock_key, timeout=10)
|
||||
acquired = processing_lock.acquire(blocking=False)
|
||||
assert acquired, "Should be able to acquire the processing lock for this test"
|
||||
|
||||
try:
|
||||
process_single_user_file.run(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
)
|
||||
finally:
|
||||
if processing_lock.owned():
|
||||
processing_lock.release()
|
||||
|
||||
assert not redis_client.exists(
|
||||
guard_key
|
||||
), "Guard key should be deleted when the worker picks up the task"
|
||||
@@ -553,7 +553,7 @@ class TestDefaultProviderEndpoint:
|
||||
|
||||
try:
|
||||
existing_providers = fetch_existing_llm_providers(
|
||||
db_session, flow_types=[LLMModelFlowType.CHAT]
|
||||
db_session, flow_type_filter=[LLMModelFlowType.CHAT]
|
||||
)
|
||||
provider_names_to_restore: list[str] = []
|
||||
|
||||
|
||||
@@ -14,9 +14,12 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import LLMModelFlowType
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import sync_auto_mode_models
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
@@ -606,3 +609,95 @@ class TestAutoModeSyncFeature:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_1_name)
|
||||
_cleanup_provider(db_session, provider_2_name)
|
||||
|
||||
|
||||
class TestAutoModeMissingFlows:
|
||||
"""Regression test: sync_auto_mode_models must create LLMModelFlow rows
|
||||
for every ModelConfiguration it inserts, otherwise the provider vanishes
|
||||
from listing queries that join through LLMModelFlow."""
|
||||
|
||||
def test_sync_auto_mode_creates_flow_rows(
|
||||
self,
|
||||
db_session: Session,
|
||||
provider_name: str,
|
||||
) -> None:
|
||||
"""
|
||||
Steps:
|
||||
1. Create a provider with no model configs (empty shell).
|
||||
2. Call sync_auto_mode_models to add models from a mock config.
|
||||
3. Assert every new ModelConfiguration has at least one LLMModelFlow.
|
||||
4. Assert fetch_existing_llm_providers (which joins through
|
||||
LLMModelFlow) returns the provider.
|
||||
"""
|
||||
mock_recommendations = _create_mock_llm_recommendations(
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
default_model_name="gpt-4o",
|
||||
additional_models=["gpt-4o-mini"],
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Create provider with no model configs
|
||||
put_llm_provider(
|
||||
llm_provider_upsert_request=LLMProviderUpsertRequest(
|
||||
name=provider_name,
|
||||
provider=LlmProviderNames.OPENAI,
|
||||
api_key="sk-test-key-00000000000000000000000000000000000",
|
||||
api_key_changed=True,
|
||||
is_auto_mode=True,
|
||||
default_model_name="gpt-4o",
|
||||
model_configurations=[],
|
||||
),
|
||||
is_creation=True,
|
||||
_=_create_mock_admin(),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Step 2: Run sync_auto_mode_models (simulating the periodic sync)
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
sync_auto_mode_models(
|
||||
db_session=db_session,
|
||||
provider=provider,
|
||||
llm_recommendations=mock_recommendations,
|
||||
)
|
||||
|
||||
# Step 3: Every ModelConfiguration must have at least one LLMModelFlow
|
||||
db_session.expire_all()
|
||||
provider = fetch_existing_llm_provider(
|
||||
name=provider_name, db_session=db_session
|
||||
)
|
||||
assert provider is not None
|
||||
|
||||
synced_model_names = {mc.name for mc in provider.model_configurations}
|
||||
assert "gpt-4o" in synced_model_names
|
||||
assert "gpt-4o-mini" in synced_model_names
|
||||
|
||||
for mc in provider.model_configurations:
|
||||
assert len(mc.llm_model_flows) > 0, (
|
||||
f"ModelConfiguration '{mc.name}' (id={mc.id}) has no "
|
||||
f"LLMModelFlow rows — it will be invisible to listing queries"
|
||||
)
|
||||
|
||||
flow_types = {f.llm_model_flow_type for f in mc.llm_model_flows}
|
||||
assert (
|
||||
LLMModelFlowType.CHAT in flow_types
|
||||
), f"ModelConfiguration '{mc.name}' is missing a CHAT flow"
|
||||
|
||||
# Step 4: The provider must appear in fetch_existing_llm_providers
|
||||
listed_providers = fetch_existing_llm_providers(
|
||||
db_session=db_session,
|
||||
flow_type_filter=[LLMModelFlowType.CHAT],
|
||||
)
|
||||
listed_provider_names = {p.name for p in listed_providers}
|
||||
assert provider_name in listed_provider_names, (
|
||||
f"Provider '{provider_name}' not returned by "
|
||||
f"fetch_existing_llm_providers — models are missing flow rows"
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.rollback()
|
||||
_cleanup_provider(db_session, provider_name)
|
||||
|
||||
@@ -14,6 +14,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.tasks import (
|
||||
check_for_documents_for_opensearch_migration_task,
|
||||
)
|
||||
@@ -25,14 +28,12 @@ from onyx.configs.constants import SOURCE_TYPE
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import OpenSearchDocumentMigrationStatus
|
||||
from onyx.db.enums import OpenSearchTenantMigrationStatus
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import OpenSearchDocumentMigrationRecord
|
||||
from onyx.db.models import OpenSearchTenantMigrationRecord
|
||||
from onyx.db.opensearch_migration import create_opensearch_migration_records_with_commit
|
||||
from onyx.db.opensearch_migration import get_last_opensearch_migration_document_id
|
||||
from onyx.db.opensearch_migration import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.client import OpenSearchClient
|
||||
@@ -525,6 +526,40 @@ class TestCheckForDocumentsForOpenSearchMigrationTask:
|
||||
>= 1
|
||||
)
|
||||
|
||||
def test_creates_singleton_migration_record(
|
||||
self,
|
||||
db_session: Session,
|
||||
clean_migration_tables: None, # noqa: ARG002
|
||||
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""Tests that singleton migration record is created."""
|
||||
# Under test.
|
||||
result = check_for_documents_for_opensearch_migration_task(
|
||||
tenant_id=get_current_tenant_id()
|
||||
)
|
||||
|
||||
# Postcondition.
|
||||
assert result is True
|
||||
# Expire the session cache to see the committed changes from the task.
|
||||
db_session.expire_all()
|
||||
# Verify the singleton migration record was created.
|
||||
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
assert tenant_record is not None
|
||||
assert (
|
||||
tenant_record.document_migration_record_table_population_status
|
||||
== OpenSearchTenantMigrationStatus.PENDING
|
||||
)
|
||||
assert (
|
||||
tenant_record.num_times_observed_no_additional_docs_to_populate_migration_table
|
||||
== 1
|
||||
)
|
||||
assert (
|
||||
tenant_record.overall_document_migration_status
|
||||
== OpenSearchTenantMigrationStatus.PENDING
|
||||
)
|
||||
assert tenant_record.num_times_observed_no_additional_docs_to_migrate == 0
|
||||
assert tenant_record.last_updated_at is not None
|
||||
|
||||
|
||||
class TestMigrateDocumentsFromVespaToOpenSearchTask:
|
||||
"""Tests migrate_documents_from_vespa_to_opensearch_task."""
|
||||
@@ -665,7 +700,11 @@ class TestMigrateDocumentsFromVespaToOpenSearchTask:
|
||||
.first()
|
||||
)
|
||||
assert record is not None
|
||||
assert record.status == OpenSearchDocumentMigrationStatus.FAILED
|
||||
# In practice the task keeps trying docs until it either runs out of
|
||||
# time or the lock is lost, which will not happen during this test.
|
||||
# Because of this the migration record will just shift to permanently
|
||||
# failed. Let's just test for that here.
|
||||
assert record.status == OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
|
||||
# Verify chunks were indexed in OpenSearch.
|
||||
for document_id in doc_ids_that_have_chunks:
|
||||
chunks = _get_document_chunks_from_opensearch(
|
||||
@@ -764,7 +803,11 @@ class TestMigrateDocumentsFromVespaToOpenSearchTask:
|
||||
.first()
|
||||
)
|
||||
assert record is not None
|
||||
assert record.status == OpenSearchDocumentMigrationStatus.FAILED
|
||||
# In practice the task keeps trying docs until it either runs out of
|
||||
# time or the lock is lost, which will not happen during this test.
|
||||
# Because of this the migration record will just shift to permanently
|
||||
# failed. Let's just test for that here.
|
||||
assert record.status == OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
|
||||
assert record.error_message is not None
|
||||
assert "no chunk count" in record.error_message.lower()
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ COPY ./tests/* /app/tests/
|
||||
|
||||
FROM base AS openapi-schema
|
||||
COPY ./scripts/onyx_openapi_schema.py /app/scripts/onyx_openapi_schema.py
|
||||
RUN python scripts/onyx_openapi_schema.py --filename openapi.json
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
RUN LICENSE_ENFORCEMENT_ENABLED=false python scripts/onyx_openapi_schema.py --filename openapi.json
|
||||
|
||||
FROM openapitools/openapi-generator-cli:latest AS openapi-client
|
||||
WORKDIR /local
|
||||
|
||||
@@ -24,6 +24,7 @@ from tests.integration.common_utils.test_models import DATestChatSession
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.test_models import ErrorResponse
|
||||
from tests.integration.common_utils.test_models import StreamedResponse
|
||||
from tests.integration.common_utils.test_models import ToolCallDebug
|
||||
from tests.integration.common_utils.test_models import ToolName
|
||||
from tests.integration.common_utils.test_models import ToolResult
|
||||
|
||||
@@ -40,6 +41,7 @@ class StreamPacketObj(TypedDict, total=False):
|
||||
"image_generation_start",
|
||||
"image_generation_heartbeat",
|
||||
"image_generation_final",
|
||||
"tool_call_debug",
|
||||
]
|
||||
content: str
|
||||
final_documents: list[dict[str, Any]]
|
||||
@@ -47,6 +49,9 @@ class StreamPacketObj(TypedDict, total=False):
|
||||
images: list[dict[str, Any]]
|
||||
queries: list[str]
|
||||
documents: list[dict[str, Any]]
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class PlacementData(TypedDict, total=False):
|
||||
@@ -109,6 +114,7 @@ class ChatSessionManager:
|
||||
use_existing_user_message: bool = False,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
chat_session: DATestChatSession | None = None,
|
||||
mock_llm_response: str | None = None,
|
||||
) -> StreamedResponse:
|
||||
chat_message_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -120,6 +126,7 @@ class ChatSessionManager:
|
||||
query_override=query_override,
|
||||
regenerate=regenerate,
|
||||
llm_override=llm_override,
|
||||
mock_llm_response=mock_llm_response,
|
||||
prompt_override=prompt_override,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
@@ -179,6 +186,7 @@ class ChatSessionManager:
|
||||
alternate_assistant_id: int | None = None,
|
||||
use_existing_user_message: bool = False,
|
||||
forced_tool_ids: list[int] | None = None,
|
||||
mock_llm_response: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send a message and simulate client disconnect before stream completes.
|
||||
@@ -210,6 +218,7 @@ class ChatSessionManager:
|
||||
query_override=query_override,
|
||||
regenerate=regenerate,
|
||||
llm_override=llm_override,
|
||||
mock_llm_response=mock_llm_response,
|
||||
prompt_override=prompt_override,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
@@ -253,6 +262,7 @@ class ChatSessionManager:
|
||||
],
|
||||
)
|
||||
ind_to_tool_use: dict[int, ToolResult] = {}
|
||||
tool_call_debug: list[ToolCallDebug] = []
|
||||
top_documents: list[SearchDoc] = []
|
||||
heartbeat_packets: list[StreamPacketData] = []
|
||||
full_message = ""
|
||||
@@ -330,6 +340,16 @@ class ChatSessionManager:
|
||||
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
|
||||
)
|
||||
ind_to_tool_use[ind].documents.extend(docs)
|
||||
elif packet_type_str == StreamingType.TOOL_CALL_DEBUG.value:
|
||||
tool_call_debug.append(
|
||||
ToolCallDebug(
|
||||
tool_call_id=str(data_obj.get("tool_call_id", "")),
|
||||
tool_name=str(data_obj.get("tool_name", "")),
|
||||
tool_args=cast(
|
||||
dict[str, Any], data_obj.get("tool_args") or {}
|
||||
),
|
||||
)
|
||||
)
|
||||
# If there's an error, assistant_message_id might not be present
|
||||
if not assistant_message_id and not error:
|
||||
raise ValueError("Assistant message id not found")
|
||||
@@ -338,6 +358,7 @@ class ChatSessionManager:
|
||||
assistant_message_id=assistant_message_id or -1, # Use -1 for error cases
|
||||
top_documents=top_documents,
|
||||
used_tools=list(ind_to_tool_use.values()),
|
||||
tool_call_debug=tool_call_debug,
|
||||
heartbeat_packets=[dict(packet) for packet in heartbeat_packets],
|
||||
error=error,
|
||||
)
|
||||
|
||||
@@ -202,6 +202,12 @@ class ToolResult(BaseModel):
|
||||
images: list[GeneratedImage] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolCallDebug(BaseModel):
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: str
|
||||
stack_trace: str
|
||||
@@ -212,6 +218,7 @@ class StreamedResponse(BaseModel):
|
||||
assistant_message_id: int
|
||||
top_documents: list[SearchDoc]
|
||||
used_tools: list[ToolResult]
|
||||
tool_call_debug: list[ToolCallDebug] = Field(default_factory=list)
|
||||
error: ErrorResponse | None = None
|
||||
|
||||
# Track heartbeat packets for image generation and other tools
|
||||
|
||||
@@ -34,9 +34,7 @@ def _schema_exists(schema_name: str) -> bool:
|
||||
class TestTenantProvisioningRollback:
|
||||
"""Integration tests for provisioning failure and rollback."""
|
||||
|
||||
def test_failed_provisioning_cleans_up_schema(
|
||||
self, reset_multitenant: None
|
||||
) -> None:
|
||||
def test_failed_provisioning_cleans_up_schema(self) -> None:
|
||||
"""
|
||||
When setup_tenant fails after schema creation, rollback should
|
||||
clean up the orphaned schema.
|
||||
@@ -79,9 +77,7 @@ class TestTenantProvisioningRollback:
|
||||
created_tenant_id
|
||||
), f"Schema {created_tenant_id} should have been rolled back"
|
||||
|
||||
def test_drop_schema_works_with_uuid_tenant_id(
|
||||
self, reset_multitenant: None
|
||||
) -> None:
|
||||
def test_drop_schema_works_with_uuid_tenant_id(self) -> None:
|
||||
"""
|
||||
drop_schema should work with UUID-format tenant IDs.
|
||||
|
||||
|
||||
@@ -240,6 +240,116 @@ def test_can_user_access_llm_provider_or_logic(
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_with_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers should still enforce persona restrictions.
|
||||
|
||||
Regression test for the bug where is_public=True caused
|
||||
can_user_access_llm_provider() to return True immediately,
|
||||
bypassing persona whitelist checks entirely.
|
||||
"""
|
||||
admin_user, _basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Public provider with persona restrictions
|
||||
public_restricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-persona-restricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
non_whitelisted_persona = _create_persona(
|
||||
db_session,
|
||||
name="non-whitelisted-persona",
|
||||
provider_name=public_restricted.name,
|
||||
)
|
||||
|
||||
# Only whitelist one persona
|
||||
db_session.add(
|
||||
LLMProvider__Persona(
|
||||
llm_provider_id=public_restricted.id,
|
||||
persona_id=whitelisted_persona.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
db_session.refresh(public_restricted)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
assert admin_model is not None
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
|
||||
# Whitelisted persona — should be allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
whitelisted_persona,
|
||||
)
|
||||
|
||||
# Non-whitelisted persona — should be denied despite is_public=True
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
non_whitelisted_persona,
|
||||
)
|
||||
|
||||
# No persona context (e.g. global provider list) — should be denied
|
||||
# because provider has persona restrictions set
|
||||
assert not can_user_access_llm_provider(
|
||||
public_restricted,
|
||||
admin_group_ids,
|
||||
persona=None,
|
||||
)
|
||||
|
||||
|
||||
def test_public_provider_without_persona_restrictions(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
"""Public providers with no persona restrictions remain accessible to all."""
|
||||
admin_user, basic_user = users
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
public_unrestricted = _create_llm_provider(
|
||||
db_session,
|
||||
name="public-unrestricted",
|
||||
default_model_name="gpt-4o",
|
||||
is_public=True,
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
any_persona = _create_persona(
|
||||
db_session,
|
||||
name="any-persona",
|
||||
provider_name=public_unrestricted.name,
|
||||
)
|
||||
|
||||
admin_model = db_session.get(User, admin_user.id)
|
||||
basic_model = db_session.get(User, basic_user.id)
|
||||
assert admin_model is not None
|
||||
assert basic_model is not None
|
||||
|
||||
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
|
||||
basic_group_ids = fetch_user_group_ids(db_session, basic_model)
|
||||
|
||||
# Any user, any persona — all allowed
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, basic_group_ids, any_persona
|
||||
)
|
||||
assert can_user_access_llm_provider(
|
||||
public_unrestricted, admin_group_ids, persona=None
|
||||
)
|
||||
|
||||
|
||||
def test_get_llm_for_persona_falls_back_when_access_denied(
|
||||
users: tuple[DATestUser, DATestUser],
|
||||
) -> None:
|
||||
|
||||
@@ -0,0 +1,234 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.user import DATestUser
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
def _upload_connector_file(
|
||||
*,
|
||||
user_performing_action: DATestUser,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
) -> tuple[str, str]:
|
||||
headers = user_performing_action.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
|
||||
files=[("files", (file_name, io.BytesIO(content), "text/plain"))],
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
return payload["file_paths"][0], payload["file_names"][0]
|
||||
|
||||
|
||||
def _update_connector_files(
|
||||
*,
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
file_ids_to_remove: list[str],
|
||||
new_file_name: str,
|
||||
new_file_content: bytes,
|
||||
) -> requests.Response:
|
||||
headers = user_performing_action.headers.copy()
|
||||
headers.pop("Content-Type", None)
|
||||
|
||||
return requests.post(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files/update",
|
||||
data={"file_ids_to_remove": json.dumps(file_ids_to_remove)},
|
||||
files=[("files", (new_file_name, io.BytesIO(new_file_content), "text/plain"))],
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
def _list_connector_files(
|
||||
*,
|
||||
connector_id: int,
|
||||
user_performing_action: DATestUser,
|
||||
) -> requests.Response:
|
||||
return requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
@pytest.mark.usefixtures("reset")
|
||||
def test_only_global_curator_can_update_public_file_connector_files() -> None:
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
global_curator_creator = UserManager.create(name="global_curator_creator")
|
||||
global_curator_creator = UserManager.set_role(
|
||||
user_to_set=global_curator_creator,
|
||||
target_role=UserRole.GLOBAL_CURATOR,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
global_curator_editor = UserManager.create(name="global_curator_editor")
|
||||
global_curator_editor = UserManager.set_role(
|
||||
user_to_set=global_curator_editor,
|
||||
target_role=UserRole.GLOBAL_CURATOR,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
curator_user = UserManager.create(name="curator_user")
|
||||
curator_group = UserGroupManager.create(
|
||||
name="curator_group",
|
||||
user_ids=[curator_user.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[curator_group],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.set_curator_status(
|
||||
test_user_group=curator_group,
|
||||
user_to_set_as_curator=curator_user,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
initial_file_id, initial_file_name = _upload_connector_file(
|
||||
user_performing_action=global_curator_creator,
|
||||
file_name="initial-file.txt",
|
||||
content=b"initial file content",
|
||||
)
|
||||
|
||||
connector = ConnectorManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
name="public_file_connector",
|
||||
source=DocumentSource.FILE,
|
||||
connector_specific_config={
|
||||
"file_locations": [initial_file_id],
|
||||
"file_names": [initial_file_name],
|
||||
"zip_metadata_file_id": None,
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
)
|
||||
credential = CredentialManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name="public_file_connector_credential",
|
||||
)
|
||||
CCPairManager.create(
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
user_performing_action=global_curator_creator,
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
name="public_file_connector_cc_pair",
|
||||
)
|
||||
|
||||
curator_list_response = _list_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=curator_user,
|
||||
)
|
||||
curator_list_response.raise_for_status()
|
||||
curator_list_payload = curator_list_response.json()
|
||||
assert any(f["file_id"] == initial_file_id for f in curator_list_payload["files"])
|
||||
|
||||
global_curator_list_response = _list_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
)
|
||||
global_curator_list_response.raise_for_status()
|
||||
global_curator_list_payload = global_curator_list_response.json()
|
||||
assert any(
|
||||
f["file_id"] == initial_file_id for f in global_curator_list_payload["files"]
|
||||
)
|
||||
|
||||
denied_response = _update_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=curator_user,
|
||||
file_ids_to_remove=[initial_file_id],
|
||||
new_file_name="curator-file.txt",
|
||||
new_file_content=b"curator updated file",
|
||||
)
|
||||
assert denied_response.status_code == 403
|
||||
|
||||
allowed_response = _update_connector_files(
|
||||
connector_id=connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
file_ids_to_remove=[initial_file_id],
|
||||
new_file_name="global-curator-file.txt",
|
||||
new_file_content=b"global curator updated file",
|
||||
)
|
||||
allowed_response.raise_for_status()
|
||||
|
||||
payload = allowed_response.json()
|
||||
assert initial_file_id not in payload["file_paths"]
|
||||
assert "global-curator-file.txt" in payload["file_names"]
|
||||
|
||||
creator_group = UserGroupManager.create(
|
||||
name="creator_group",
|
||||
user_ids=[global_curator_creator.id],
|
||||
cc_pair_ids=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
UserGroupManager.wait_for_sync(
|
||||
user_groups_to_check=[creator_group],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
private_file_id, private_file_name = _upload_connector_file(
|
||||
user_performing_action=global_curator_creator,
|
||||
file_name="private-initial-file.txt",
|
||||
content=b"private initial file content",
|
||||
)
|
||||
|
||||
private_connector = ConnectorManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
name="private_file_connector",
|
||||
source=DocumentSource.FILE,
|
||||
connector_specific_config={
|
||||
"file_locations": [private_file_id],
|
||||
"file_names": [private_file_name],
|
||||
"zip_metadata_file_id": None,
|
||||
},
|
||||
access_type=AccessType.PRIVATE,
|
||||
groups=[creator_group.id],
|
||||
)
|
||||
private_credential = CredentialManager.create(
|
||||
user_performing_action=global_curator_creator,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=False,
|
||||
groups=[creator_group.id],
|
||||
name="private_file_connector_credential",
|
||||
)
|
||||
CCPairManager.create(
|
||||
connector_id=private_connector.id,
|
||||
credential_id=private_credential.id,
|
||||
user_performing_action=global_curator_creator,
|
||||
access_type=AccessType.PRIVATE,
|
||||
groups=[creator_group.id],
|
||||
name="private_file_connector_cc_pair",
|
||||
)
|
||||
|
||||
private_denied_response = _update_connector_files(
|
||||
connector_id=private_connector.id,
|
||||
user_performing_action=global_curator_editor,
|
||||
file_ids_to_remove=[private_file_id],
|
||||
new_file_name="global-curator-private-file.txt",
|
||||
new_file_content=b"global curator private update",
|
||||
)
|
||||
assert private_denied_response.status_code == 403
|
||||
155
backend/tests/integration/tests/users/test_seat_limit.py
Normal file
155
backend/tests/integration/tests/users/test_seat_limit.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Integration tests for seat limit enforcement on user creation paths.
|
||||
|
||||
Verifies that when a license with a seat limit is active, new user
|
||||
creation (registration, invite, reactivation) is blocked with HTTP 402.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
import redis
|
||||
import requests
|
||||
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.license.models import PlanType
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
|
||||
# TenantRedis prefixes every key with "{tenant_id}:".
|
||||
# Single-tenant deployments use "public" as the tenant id.
|
||||
_LICENSE_REDIS_KEY = "public:license:metadata"
|
||||
|
||||
|
||||
def _seed_license(r: redis.Redis, seats: int) -> None:
|
||||
"""Write a LicenseMetadata entry into Redis with the given seat cap."""
|
||||
now = datetime.utcnow()
|
||||
metadata = LicenseMetadata(
|
||||
tenant_id="public",
|
||||
organization_name="Test Org",
|
||||
seats=seats,
|
||||
used_seats=0, # check_seat_availability recalculates from DB
|
||||
plan_type=PlanType.ANNUAL,
|
||||
issued_at=now,
|
||||
expires_at=now + timedelta(days=365),
|
||||
status=ApplicationStatus.ACTIVE,
|
||||
source=LicenseSource.MANUAL_UPLOAD,
|
||||
)
|
||||
r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
|
||||
|
||||
|
||||
def _clear_license(r: redis.Redis) -> None:
|
||||
r.delete(_LICENSE_REDIS_KEY)
|
||||
|
||||
|
||||
def _redis() -> redis.Redis:
|
||||
return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registration_blocked_when_seats_full(reset: None) -> None: # noqa: ARG001
|
||||
"""POST /auth/register returns 402 when the seat limit is reached."""
|
||||
r = _redis()
|
||||
|
||||
# First user is admin — occupies 1 seat
|
||||
UserManager.create(name="admin_user")
|
||||
|
||||
# License allows exactly 1 seat → already full
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url=f"{API_SERVER_URL}/auth/register",
|
||||
json={
|
||||
"email": "blocked@example.com",
|
||||
"username": "blocked@example.com",
|
||||
"password": "TestPassword123!",
|
||||
},
|
||||
headers=GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 402
|
||||
finally:
|
||||
_clear_license(r)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Invitation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invite_blocked_when_seats_full(reset: None) -> None: # noqa: ARG001
|
||||
"""PUT /manage/admin/users returns 402 when the seat limit is reached."""
|
||||
r = _redis()
|
||||
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
response = requests.put(
|
||||
url=f"{API_SERVER_URL}/manage/admin/users",
|
||||
json={"emails": ["newuser@example.com"]},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 402
|
||||
finally:
|
||||
_clear_license(r)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reactivation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_reactivation_blocked_when_seats_full(reset: None) -> None: # noqa: ARG001
|
||||
"""PATCH /manage/admin/activate-user returns 402 when seats are full."""
|
||||
r = _redis()
|
||||
|
||||
admin_user = UserManager.create(name="admin_user")
|
||||
basic_user = UserManager.create(name="basic_user")
|
||||
|
||||
# Deactivate the basic user (frees a seat in the DB count)
|
||||
UserManager.set_status(
|
||||
basic_user, target_status=False, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
# Set license to 1 seat — only admin counts now
|
||||
_seed_license(r, seats=1)
|
||||
|
||||
try:
|
||||
response = requests.patch(
|
||||
url=f"{API_SERVER_URL}/manage/admin/activate-user",
|
||||
json={"user_email": basic_user.email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
assert response.status_code == 402
|
||||
finally:
|
||||
_clear_license(r)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# No license → no enforcement
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_registration_allowed_without_license(reset: None) -> None: # noqa: ARG001
|
||||
"""Without a license in Redis, registration is unrestricted."""
|
||||
r = _redis()
|
||||
|
||||
# Make sure there is no cached license
|
||||
_clear_license(r)
|
||||
|
||||
UserManager.create(name="admin_user")
|
||||
|
||||
# Second user should register without issue
|
||||
second_user = UserManager.create(name="second_user")
|
||||
assert second_user is not None
|
||||
@@ -17,6 +17,7 @@ class TestOnyxWebCrawler:
|
||||
content from public websites correctly.
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_fetches_public_url_successfully(self, admin_user: DATestUser) -> None:
|
||||
"""Test that the crawler can fetch content from a public URL."""
|
||||
response = requests.post(
|
||||
@@ -40,6 +41,7 @@ class TestOnyxWebCrawler:
|
||||
assert "This domain is for use in" in content
|
||||
assert "documentation" in content or "illustrative" in content
|
||||
|
||||
@pytest.mark.skip(reason="Temporarily disabled")
|
||||
def test_fetches_multiple_urls(self, admin_user: DATestUser) -> None:
|
||||
"""Test that the crawler can fetch multiple URLs in one request."""
|
||||
response = requests.post(
|
||||
|
||||
@@ -101,6 +101,33 @@ class TestMakeBillingRequest:
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Bad request" in exc_info.value.message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.service._get_headers")
|
||||
@patch("ee.onyx.server.billing.service._get_base_url")
|
||||
async def test_follows_redirects(
|
||||
self,
|
||||
mock_base_url: MagicMock,
|
||||
mock_headers: MagicMock,
|
||||
) -> None:
|
||||
"""AsyncClient must be created with follow_redirects=True.
|
||||
|
||||
The target server (cloud data plane for self-hosted, control
|
||||
plane for cloud) may sit behind nginx that returns 308
|
||||
(HTTP→HTTPS). httpx does not follow redirects by default,
|
||||
so we must explicitly opt in.
|
||||
"""
|
||||
from ee.onyx.server.billing.service import _make_billing_request
|
||||
|
||||
mock_base_url.return_value = "http://api.example.com"
|
||||
mock_headers.return_value = {"Authorization": "Bearer token"}
|
||||
mock_response = make_mock_response({"ok": True})
|
||||
mock_client = make_mock_http_client("get", response=mock_response)
|
||||
|
||||
with patch("httpx.AsyncClient", mock_client):
|
||||
await _make_billing_request(method="GET", path="/test")
|
||||
|
||||
mock_client.assert_called_once_with(timeout=30.0, follow_redirects=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("ee.onyx.server.billing.service._get_headers")
|
||||
@patch("ee.onyx.server.billing.service._get_base_url")
|
||||
|
||||
@@ -51,7 +51,6 @@ class TestApplyLicenseStatusToSettings:
|
||||
@pytest.mark.parametrize(
|
||||
"license_status,expected_app_status,expected_ee_enabled",
|
||||
[
|
||||
(None, ApplicationStatus.ACTIVE, False),
|
||||
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
|
||||
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
|
||||
],
|
||||
@@ -84,6 +83,56 @@ class TestApplyLicenseStatusToSettings:
|
||||
assert result.application_status == expected_app_status
|
||||
assert result.ee_features_enabled is expected_ee_enabled
|
||||
|
||||
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.refresh_license_cache", return_value=None)
|
||||
@patch("ee.onyx.server.settings.api.get_session_with_current_tenant")
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_no_license_with_ee_flag_gates_access(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
_mock_get_session: MagicMock,
|
||||
_mock_refresh: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""No license + ENTERPRISE_EDITION_ENABLED=true → GATED_ACCESS."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.GATED_ACCESS
|
||||
assert result.ee_features_enabled is False
|
||||
|
||||
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.refresh_license_cache", return_value=None)
|
||||
@patch("ee.onyx.server.settings.api.get_session_with_current_tenant")
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
|
||||
def test_no_license_without_ee_flag_allows_community(
|
||||
self,
|
||||
mock_get_metadata: MagicMock,
|
||||
mock_get_tenant: MagicMock,
|
||||
_mock_get_session: MagicMock,
|
||||
_mock_refresh: MagicMock,
|
||||
base_settings: Settings,
|
||||
) -> None:
|
||||
"""No license + ENTERPRISE_EDITION_ENABLED=false → community mode (no gating)."""
|
||||
from ee.onyx.server.settings.api import apply_license_status_to_settings
|
||||
|
||||
mock_get_tenant.return_value = "test_tenant"
|
||||
mock_get_metadata.return_value = None
|
||||
|
||||
result = apply_license_status_to_settings(base_settings)
|
||||
assert result.application_status == ApplicationStatus.ACTIVE
|
||||
assert result.ee_features_enabled is False
|
||||
|
||||
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
|
||||
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
|
||||
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
|
||||
@@ -105,9 +154,10 @@ class TestApplyLicenseStatusToSettings:
|
||||
assert result.ee_features_enabled is False
|
||||
|
||||
|
||||
class TestSettingsDefaultEEDisabled:
|
||||
"""Verify the Settings model defaults ee_features_enabled to False."""
|
||||
class TestSettingsDefaults:
|
||||
"""Verify Settings model defaults for CE deployments."""
|
||||
|
||||
def test_default_ee_features_disabled(self) -> None:
|
||||
"""CE default: ee_features_enabled is False."""
|
||||
settings = Settings()
|
||||
assert settings.ee_features_enabled is False
|
||||
|
||||
@@ -427,6 +427,37 @@ class TestForwardToControlPlane:
|
||||
assert exc_info.value.status_code == 502
|
||||
assert "Failed to connect to control plane" in str(exc_info.value.detail)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_follows_redirects(self) -> None:
|
||||
"""Test that AsyncClient is created with follow_redirects=True.
|
||||
|
||||
The control plane may sit behind a reverse proxy that returns
|
||||
308 (HTTP→HTTPS). httpx does not follow redirects by default,
|
||||
so we must explicitly opt in.
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"ok": True}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"ee.onyx.server.tenants.proxy.generate_data_plane_token"
|
||||
) as mock_token,
|
||||
patch("ee.onyx.server.tenants.proxy.httpx.AsyncClient") as mock_client,
|
||||
patch(
|
||||
"ee.onyx.server.tenants.proxy.CONTROL_PLANE_API_BASE_URL",
|
||||
"http://control.example.com",
|
||||
),
|
||||
):
|
||||
mock_token.return_value = "cp_token"
|
||||
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
await forward_to_control_plane("GET", "/test")
|
||||
|
||||
mock_client.assert_called_once_with(timeout=30.0, follow_redirects=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_method(self) -> None:
|
||||
"""Test that unsupported HTTP methods raise ValueError."""
|
||||
|
||||
@@ -384,6 +384,29 @@ class TestWhitelistBehavior:
|
||||
verify_email_is_invited("Allowed@Example.Com")
|
||||
|
||||
|
||||
class TestSeatLimitEnforcement:
|
||||
"""Seat limits block new user creation on self-hosted deployments."""
|
||||
|
||||
def test_adding_user_fails_when_seats_full(self) -> None:
|
||||
from onyx.auth.users import enforce_seat_limit
|
||||
|
||||
seat_result = MagicMock(available=False, error_message="Seat limit reached")
|
||||
with patch(
|
||||
"onyx.auth.users.fetch_ee_implementation_or_noop",
|
||||
return_value=lambda *_a, **_kw: seat_result,
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
enforce_seat_limit(MagicMock())
|
||||
|
||||
assert exc.value.status_code == 402
|
||||
|
||||
def test_seat_limit_only_enforced_for_self_hosted(self) -> None:
|
||||
from onyx.auth.users import enforce_seat_limit
|
||||
|
||||
with patch("onyx.auth.users.MULTI_TENANT", True):
|
||||
enforce_seat_limit(MagicMock()) # should not raise
|
||||
|
||||
|
||||
class TestCaseInsensitiveEmailMatching:
|
||||
"""Test case-insensitive email matching for existing user checks."""
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.llm_loop import _should_keep_bedrock_tool_definitions
|
||||
from onyx.chat.llm_loop import construct_message_history
|
||||
from onyx.chat.models import ChatLoadedFile
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
@@ -10,6 +11,17 @@ from onyx.chat.models import ProjectFileMetadata
|
||||
from onyx.chat.models import ToolCallSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
|
||||
|
||||
class _StubConfig:
|
||||
def __init__(self, model_provider: str) -> None:
|
||||
self.model_provider = model_provider
|
||||
|
||||
|
||||
class _StubLLM:
|
||||
def __init__(self, model_provider: str) -> None:
|
||||
self.config = _StubConfig(model_provider=model_provider)
|
||||
|
||||
|
||||
def create_message(
|
||||
@@ -568,3 +580,34 @@ class TestConstructMessageHistory:
|
||||
assert '"contents"' in project_message.message
|
||||
assert "Project file 0 content" in project_message.message
|
||||
assert "Project file 1 content" in project_message.message
|
||||
|
||||
|
||||
class TestBedrockToolConfigGuard:
|
||||
def test_bedrock_with_tool_history_keeps_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.BEDROCK)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_assistant_with_tool_call("tc_1", "search", 5),
|
||||
create_tool_response("tc_1", "Tool output", 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is True
|
||||
|
||||
def test_bedrock_without_tool_history_does_not_keep_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.BEDROCK)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_message("Answer", MessageType.ASSISTANT, 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is False
|
||||
|
||||
def test_non_bedrock_with_tool_history_does_not_keep_tool_definitions(self) -> None:
|
||||
llm = _StubLLM(LlmProviderNames.OPENAI)
|
||||
history = [
|
||||
create_message("Question", MessageType.USER, 5),
|
||||
create_assistant_with_tool_call("tc_1", "search", 5),
|
||||
create_tool_response("tc_1", "Tool output", 5),
|
||||
]
|
||||
|
||||
assert _should_keep_bedrock_tool_definitions(llm, history) is False
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentBase
|
||||
from onyx.connectors.models import TextSection
|
||||
|
||||
|
||||
def _minimal_doc_kwargs(metadata: dict) -> dict:
|
||||
return {
|
||||
"id": "test-doc",
|
||||
"sections": [TextSection(text="hello", link="http://example.com")],
|
||||
"source": DocumentSource.NOT_APPLICABLE,
|
||||
"semantic_identifier": "Test Doc",
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
|
||||
def test_int_values_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"count": 42}))
|
||||
assert doc.metadata == {"count": "42"}
|
||||
|
||||
|
||||
def test_float_values_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"score": 3.14}))
|
||||
assert doc.metadata == {"score": "3.14"}
|
||||
|
||||
|
||||
def test_bool_values_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"active": True}))
|
||||
assert doc.metadata == {"active": "True"}
|
||||
|
||||
|
||||
def test_list_of_ints_coerced_to_list_of_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"ids": [1, 2, 3]}))
|
||||
assert doc.metadata == {"ids": ["1", "2", "3"]}
|
||||
|
||||
|
||||
def test_list_of_mixed_types_coerced_to_list_of_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"tags": ["a", 1, True, 2.5]}))
|
||||
assert doc.metadata == {"tags": ["a", "1", "True", "2.5"]}
|
||||
|
||||
|
||||
def test_list_of_dicts_coerced_to_list_of_str() -> None:
|
||||
raw = {"nested": [{"key": "val"}, {"key2": "val2"}]}
|
||||
doc = Document(**_minimal_doc_kwargs(raw))
|
||||
assert doc.metadata == {"nested": ["{'key': 'val'}", "{'key2': 'val2'}"]}
|
||||
|
||||
|
||||
def test_dict_value_coerced_to_str() -> None:
|
||||
raw = {"info": {"inner_key": "inner_val"}}
|
||||
doc = Document(**_minimal_doc_kwargs(raw))
|
||||
assert doc.metadata == {"info": "{'inner_key': 'inner_val'}"}
|
||||
|
||||
|
||||
def test_none_value_coerced_to_str() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"empty": None}))
|
||||
assert doc.metadata == {"empty": "None"}
|
||||
|
||||
|
||||
def test_already_valid_str_values_unchanged() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"key": "value"}))
|
||||
assert doc.metadata == {"key": "value"}
|
||||
|
||||
|
||||
def test_already_valid_list_of_str_unchanged() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({"tags": ["a", "b", "c"]}))
|
||||
assert doc.metadata == {"tags": ["a", "b", "c"]}
|
||||
|
||||
|
||||
def test_empty_metadata_unchanged() -> None:
|
||||
doc = Document(**_minimal_doc_kwargs({}))
|
||||
assert doc.metadata == {}
|
||||
|
||||
|
||||
def test_mixed_metadata_values() -> None:
|
||||
raw = {
|
||||
"str_val": "hello",
|
||||
"int_val": 99,
|
||||
"list_val": [1, "two", 3.0],
|
||||
"dict_val": {"nested": True},
|
||||
}
|
||||
doc = Document(**_minimal_doc_kwargs(raw))
|
||||
assert doc.metadata == {
|
||||
"str_val": "hello",
|
||||
"int_val": "99",
|
||||
"list_val": ["1", "two", "3.0"],
|
||||
"dict_val": "{'nested': True}",
|
||||
}
|
||||
|
||||
|
||||
def test_coercion_works_on_base_class() -> None:
|
||||
kwargs = _minimal_doc_kwargs({"count": 42})
|
||||
kwargs.pop("source")
|
||||
kwargs.pop("id")
|
||||
doc = DocumentBase(**kwargs)
|
||||
assert doc.metadata == {"count": "42"}
|
||||
204
backend/tests/unit/onyx/onyxbot/test_handle_regular_answer.py
Normal file
204
backend/tests/unit/onyx/onyxbot/test_handle_regular_answer.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Tests for Slack channel reference resolution and tag filtering
|
||||
in handle_regular_answer.py."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
|
||||
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mock_client_with_channels(
|
||||
channel_map: dict[str, str],
|
||||
) -> MagicMock:
|
||||
"""Return a mock WebClient where conversations_info resolves IDs to names."""
|
||||
client = MagicMock()
|
||||
|
||||
def _conversations_info(channel: str) -> MagicMock:
|
||||
if channel in channel_map:
|
||||
resp = MagicMock()
|
||||
resp.validate = MagicMock()
|
||||
resp.__getitem__ = lambda _self, key: {
|
||||
"channel": {
|
||||
"name": channel_map[channel],
|
||||
"is_im": False,
|
||||
"is_mpim": False,
|
||||
}
|
||||
}[key]
|
||||
return resp
|
||||
raise SlackApiError("channel_not_found", response=MagicMock())
|
||||
|
||||
client.conversations_info = _conversations_info
|
||||
return client
|
||||
|
||||
|
||||
def _mock_logger() -> MagicMock:
|
||||
return MagicMock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SLACK_CHANNEL_REF_PATTERN regex tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlackChannelRefPattern:
|
||||
def test_matches_bare_channel_id(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
|
||||
assert matches == [("C097NBWMY8Y", "")]
|
||||
|
||||
def test_matches_channel_id_with_name(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
|
||||
assert matches == [("C097NBWMY8Y", "eng-infra")]
|
||||
|
||||
def test_matches_multiple_channels(self) -> None:
|
||||
msg = "compare <#C111AAA> and <#C222BBB|general>"
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
|
||||
assert len(matches) == 2
|
||||
assert ("C111AAA", "") in matches
|
||||
assert ("C222BBB", "general") in matches
|
||||
|
||||
def test_no_match_on_plain_text(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
|
||||
assert matches == []
|
||||
|
||||
def test_no_match_on_user_mention(self) -> None:
|
||||
matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
|
||||
assert matches == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_channel_references tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveChannelReferences:
    """Behavioral tests for resolve_channel_references: markup replacement and tags."""

    def test_resolves_bare_channel_id_via_api(self) -> None:
        slack_client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="summary of <#C097NBWMY8Y> this week",
            client=slack_client,
            logger=log,
        )

        assert resolved == "summary of #eng-infra this week"
        assert len(channel_tags) == 1
        assert channel_tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")

    def test_uses_name_from_pipe_format_without_api_call(self) -> None:
        slack_client = MagicMock()
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="check <#C097NBWMY8Y|eng-infra> for updates",
            client=slack_client,
            logger=log,
        )

        assert resolved == "check #eng-infra for updates"
        assert channel_tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
        # The name was embedded in the markup, so no API lookup should happen.
        slack_client.conversations_info.assert_not_called()

    def test_multiple_channels(self) -> None:
        slack_client = _mock_client_with_channels(
            {"C111AAA": "eng-infra", "C222BBB": "eng-general"}
        )
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="compare <#C111AAA> and <#C222BBB>",
            client=slack_client,
            logger=log,
        )

        for name in ("#eng-infra", "#eng-general"):
            assert name in resolved
        assert "<#" not in resolved
        assert len(channel_tags) == 2
        assert {t.tag_value for t in channel_tags} == {"eng-infra", "eng-general"}

    def test_no_channel_references_returns_unchanged(self) -> None:
        slack_client = MagicMock()
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="just a normal message with no channels",
            client=slack_client,
            logger=log,
        )

        assert resolved == "just a normal message with no channels"
        assert channel_tags == []

    def test_api_failure_skips_channel_gracefully(self) -> None:
        # No channels registered, so every lookup on this client fails.
        slack_client = _mock_client_with_channels({})
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="check <#CBADID123>",
            client=slack_client,
            logger=log,
        )

        # The unresolvable reference is left untouched in the text.
        assert "<#CBADID123>" in resolved
        assert channel_tags == []
        log.warning.assert_called_once()

    def test_partial_failure_resolves_what_it_can(self) -> None:
        # Only the first of the two referenced channels is known to the client.
        slack_client = _mock_client_with_channels({"C111AAA": "eng-infra"})
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="compare <#C111AAA> and <#CBADID123>",
            client=slack_client,
            logger=log,
        )

        assert "#eng-infra" in resolved
        assert "<#CBADID123>" in resolved  # failed one stays raw
        assert len(channel_tags) == 1
        assert channel_tags[0].tag_value == "eng-infra"

    def test_duplicate_channel_produces_single_tag(self) -> None:
        slack_client = _mock_client_with_channels({"C111AAA": "eng-infra"})
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="summarize <#C111AAA> and compare with <#C111AAA>",
            client=slack_client,
            logger=log,
        )

        assert resolved == "summarize #eng-infra and compare with #eng-infra"
        assert len(channel_tags) == 1
        assert channel_tags[0].tag_value == "eng-infra"

    def test_mixed_pipe_and_bare_formats(self) -> None:
        slack_client = _mock_client_with_channels({"C222BBB": "random"})
        log = _mock_logger()

        resolved, channel_tags = resolve_channel_references(
            message="see <#C111AAA|eng-infra> and <#C222BBB>",
            client=slack_client,
            logger=log,
        )

        assert "#eng-infra" in resolved
        assert "#random" in resolved
        assert len(channel_tags) == 2
|
||||
205
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
205
backend/tests/unit/onyx/onyxbot/test_slack_formatting.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
|
||||
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
|
||||
from onyx.onyxbot.slack.formatting import _sanitize_html
|
||||
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
|
||||
from onyx.onyxbot.slack.formatting import format_slack_message
|
||||
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
|
||||
from onyx.utils.text_processing import decode_escapes
|
||||
|
||||
|
||||
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
    """URLs containing parentheses get wrapped in <...> so they parse as one link."""
    raw = (
        "See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
    )
    expected = (
        "See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
    )

    assert _normalize_link_destinations(raw) == expected


def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
    """Destinations already wrapped in angle brackets pass through unchanged."""
    raw = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"

    assert _normalize_link_destinations(raw) == raw


def test_normalize_citation_link_handles_multiple_links() -> None:
    """Each link destination in a message is wrapped independently."""
    raw = (
        "[[1]](https://example.com/(USA)%20Guide.pdf) "
        "[[2]](https://example.com/Plan(s)%20Overview.pdf)"
    )

    result = _normalize_link_destinations(raw)

    assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in result
    assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in result
|
||||
|
||||
|
||||
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
    """A citation link whose URL contains parentheses survives the full pipeline."""
    raw = (
        "Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
    )

    rendered = decode_escapes(
        remove_slack_text_interactions(format_slack_message(raw))
    )

    assert (
        "<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
        in rendered
    )
    # The URL must not be truncated at the first closing parenthesis.
    assert "|[1]>%20Access%20ID%20Card" not in rendered
|
||||
|
||||
|
||||
def test_slack_style_links_converted_to_clickable_links() -> None:
    """Slack-style <url|label> links come through intact and un-escaped.

    The second assertion previously read ``assert "<" not in formatted``, which
    can never pass: the first assertion requires the literal ``<`` of the Slack
    link markup to be present. The intent (checking that no HTML-escaped angle
    bracket leaks into the output) was mangled by HTML-entity decoding of the
    original ``"&lt;"`` literal; restored here.
    """
    message = "Visit <https://example.com/page|Example Page> for details."

    formatted = format_slack_message(message)

    assert "<https://example.com/page|Example Page>" in formatted
    # No escaped angle bracket should appear alongside the real link markup.
    assert "&lt;" not in formatted
|
||||
|
||||
|
||||
def test_slack_style_links_preserved_inside_code_blocks() -> None:
    """Link markup inside a fenced code block is left untouched."""
    fenced = "```\n<https://example.com|click>\n```"

    assert "<https://example.com|click>" in _convert_slack_links_to_markdown(fenced)


def test_html_tags_stripped_outside_code_blocks() -> None:
    """HTML is sanitized outside fences but preserved verbatim inside them."""
    mixed = "Hello<br/>world ```<div>code</div>``` after"

    cleaned = _transform_outside_code_blocks(mixed, _sanitize_html)

    assert "<br" not in cleaned
    assert "<div>code</div>" in cleaned
|
||||
|
||||
|
||||
def test_format_slack_message_block_spacing() -> None:
    """A single blank line between paragraphs is preserved exactly."""
    raw = "Paragraph one.\n\nParagraph two."

    assert format_slack_message(raw) == "Paragraph one.\n\nParagraph two."


def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
    """No extra blank line is appended after a closing code fence."""
    raw = "```python\nprint('hi')\n```"

    assert format_slack_message(raw).endswith("print('hi')\n```")
|
||||
|
||||
|
||||
def test_format_slack_message_ampersand_not_double_escaped() -> None:
    """Ampersands survive formatting and quotes are not HTML-escaped.

    The final line previously read ``assert """ + '"""' + """ not in formatted`` —
    a syntax error (it opens an unterminated triple-quoted string). That text is
    the HTML-entity decoding of the original ``"&quot;"`` literal; restored here.
    """
    message = 'She said "hello" & goodbye.'

    formatted = format_slack_message(message)

    # The ampersand must still be present in the output.
    assert "&" in formatted
    # Plain double quotes must not be escaped into "&quot;".
    assert "&quot;" not in formatted
|
||||
|
||||
|
||||
# -- Table rendering tests --
|
||||
|
||||
|
||||
def test_table_renders_as_vertical_cards() -> None:
    """A markdown table becomes one bold-titled bullet card per data row."""
    table = (
        "| Feature | Status | Owner |\n"
        "|---------|--------|-------|\n"
        "| Auth | Done | Alice |\n"
        "| Search | In Progress | Bob |\n"
    )

    result = format_slack_message(table)

    for card in (
        "*Auth*\n • Status: Done\n • Owner: Alice",
        "*Search*\n • Status: In Progress\n • Owner: Bob",
    ):
        assert card in result
    # Cards separated by blank line
    assert "Owner: Alice\n\n*Search*" in result
    # No raw pipe-and-dash table syntax
    assert "---|" not in result


def test_table_single_column() -> None:
    """A single-column table renders each row as a bold title only."""
    table = "| Name |\n|------|\n| Alice |\n| Bob |\n"

    result = format_slack_message(table)

    assert "*Alice*" in result
    assert "*Bob*" in result


def test_table_embedded_in_text() -> None:
    """Prose before and after a table is kept around the rendered cards."""
    text = (
        "Here are the results:\n\n"
        "| Item | Count |\n"
        "|------|-------|\n"
        "| Apples | 5 |\n"
        "\n"
        "That's all."
    )

    result = format_slack_message(text)

    assert "Here are the results:" in result
    assert "*Apples*\n • Count: 5" in result
    assert "That's all." in result
|
||||
|
||||
|
||||
def test_table_with_formatted_cells() -> None:
    """Markdown inside table cells is converted, not double-wrapped."""
    table = (
        "| Name | Link |\n"
        "|------|------|\n"
        "| **Alice** | [profile](https://example.com) |\n"
    )

    result = format_slack_message(table)

    # A bold cell collapses to Slack bold: *Alice*, never **Alice**.
    assert "*Alice*" in result
    assert "**Alice**" not in result
    assert "<https://example.com|profile>" in result


def test_table_with_alignment_specifiers() -> None:
    """Alignment colons in the separator row are ignored when rendering."""
    table = (
        "| Left | Center | Right |\n" "|:-----|:------:|------:|\n" "| a | b | c |\n"
    )

    assert "*a*\n • Center: b\n • Right: c" in format_slack_message(table)


def test_two_tables_in_same_message_use_independent_headers() -> None:
    """Each table contributes its own header row to its cards."""
    text = (
        "| A | B |\n"
        "|---|---|\n"
        "| 1 | 2 |\n"
        "\n"
        "| X | Y | Z |\n"
        "|---|---|---|\n"
        "| p | q | r |\n"
    )

    result = format_slack_message(text)

    assert "*1*\n • B: 2" in result
    assert "*p*\n • Y: q\n • Z: r" in result


def test_table_empty_first_column_no_bare_asterisks() -> None:
    """An empty title cell must not render as bare asterisks ('**')."""
    table = "| Name | Status |\n" "|------|--------|\n" "| | Done |\n"

    result = format_slack_message(table)

    assert "**" not in result
    assert " • Status: Done" in result
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Test bulk invite limit for free trial tenants."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.manage.users import bulk_invite_users
|
||||
|
||||
|
||||
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: None) -> None:
    """Inviting past the trial cap raises a 403 that mentions the invite limit."""
    over_limit = [f"user{n}@example.com" for n in range(6)]

    with pytest.raises(HTTPException) as exc_info:
        bulk_invite_users(emails=over_limit)

    err = exc_info.value
    assert err.status_code == 403
    assert "invite limit" in err.detail.lower()


@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
@patch(
    "onyx.server.manage.users.fetch_ee_implementation_or_noop",
    return_value=lambda *_args: None,
)
def test_trial_tenant_can_invite_within_limit(*_mocks: None) -> None:
    """Inviting under the trial cap succeeds and returns the invited count."""
    within_limit = [f"user{n}@example.com" for n in range(1, 4)]

    assert bulk_invite_users(emails=within_limit) == 3
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user