Compare commits

...

85 Commits

Author SHA1 Message Date
Nikolas Garza
d676634f1c fix(slackbot): resolve channel references and filter search by channel tags (#9256) to release v2.12 (#9293) 2026-03-11 20:23:43 -07:00
Jamison Lahman
c49e9a93e8 fix(fe): fix API Key Role dropdown options (#9154) 2026-03-06 14:24:48 -08:00
Nikolas Garza
12b2ab2459 fix(billing): handle manual license users without Stripe subscription (#8787) 2026-03-05 15:44:34 -08:00
Nik
f8005eb90c fix(e2e): fix eeFeatures fixture login and ee_feature_redirect test for v2.12
- eeFeatures fixture: try fetching /api/settings before calling loginAs,
  avoiding timeout when user is already authenticated from a prior test
- ee_feature_redirect test: match v2.12 behavior (redirect to /app, no
  toast container assertion since v2.12 uses usePopup not useToast)
2026-03-05 15:01:27 -08:00
Nikolas Garza
152a950710 fix: EE route gating for upgrading CE users (#9026) 2026-03-05 15:01:27 -08:00
Justin Tahara
52eaf3d706 fix(permissions): Add file connector access control for global curators (#8990) 2026-03-05 11:20:14 -08:00
Jamison Lahman
96b7cc1711 chore(playwright): remove chromatic (#8339) 2026-03-05 11:00:36 -08:00
Justin Tahara
47d3c511d4 fix(ui): InputComboBox search for users/groups (#8928) 2026-03-04 12:43:49 -08:00
Justin Tahara
414944ac47 chore(ui): Update the Share Agent Modal (#8915) 2026-03-04 12:39:41 -08:00
Nikolas Garza
dd13f730da feat(slack): convert markdown tables to Slack-friendly format (#8999) 2026-03-04 11:54:44 -08:00
Nikolas Garza
bb71a91689 fix(llm): enforce persona restrictions on public LLM providers (#8846)
Co-authored-by: Dane <dane@onyx.app>
2026-03-02 15:29:45 -08:00
SubashMohan
af0721e063 feat: add mixed content handler for chat and image generation packets (#8494) 2026-03-02 10:12:07 +05:30
SubashMohan
567651a812 fix(popover): prevent viewport overflow with dynamic max-height and collision padding (#8675) 2026-03-02 10:11:55 +05:30
SubashMohan
7589767bb9 perf(open-url): parallelize URL fetching with split connect/read timeouts (#8580) 2026-03-02 10:11:42 +05:30
Justin Tahara
589d613f1e fix(celery): Guardrail for User File Processing (#8633) 2026-03-01 10:29:48 -08:00
Justin Tahara
b17d7e0033 fix(gong): Respecting Retry Timeout Header (#8866) 2026-02-27 13:57:10 -08:00
Nikolas Garza
131d418771 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 16:50:45 -08:00
Jamison Lahman
0be04391b3 chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 16:20:47 -08:00
Jamison Lahman
20351d9998 chore(gha): update helm/chart-testing-action version (#8536) 2026-02-25 14:30:27 -08:00
Jamison Lahman
22152ad871 chore(ods): Automated Cherry-pick backport (#8642) to release v2.12 (#8770)
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
2026-02-25 22:10:32 +00:00
Jamison Lahman
7caf197f98 chore(fe): update human message size (#8547) 2026-02-25 14:08:09 -08:00
Jamison Lahman
140bc82b36 fix(fe): inline code-blocks respect header font-size (#8691) 2026-02-25 12:11:33 -08:00
Jamison Lahman
e7ecbfafd1 fix(fe): middle align human chat message text (#8756) 2026-02-25 11:21:05 -08:00
Evan Lohn
2c2af369f5 chore: coerce doc metadata (#8703) 2026-02-23 17:54:13 -08:00
justin-tahara
2032b76fbf chore(release): Fixing Release Branch 2026-02-20 14:45:30 -08:00
Jamison Lahman
055b30b00e chore(fe): fix drop-down overflow in API Key modal (#8574) 2026-02-20 14:26:31 -08:00
Jamison Lahman
360a4cf591 chore(fe): remove close button from image gen tooltip (#8585) 2026-02-20 14:13:16 -08:00
Jamison Lahman
3d3cab9f91 fix(fe): popover width can fit trigger element (#8624) 2026-02-20 14:13:16 -08:00
Justin Tahara
6120d012ba feat(web): FE Changes for Brave Web Search 3/3 (#8597) 2026-02-20 11:29:02 -08:00
Evan Lohn
3e7e2e93f2 fix: search tool enabled when nothing selected 2026-02-20 11:05:46 -08:00
Justin Tahara
ccf482fa3b hotfix/web 2026-02-20 11:03:32 -08:00
Justin Tahara
fd45a612da feat(web): Initial Framework for Brave Web Search 1/3 (#8594) 2026-02-20 10:58:41 -08:00
Danelegend
c444d8883b fix: /llm/provider route returns all providers (#8545) 2026-02-20 10:48:56 -08:00
SubashMohan
9947837f9f fix: update SourceTag component to use variant prop for sizing (#8582) 2026-02-20 11:54:18 +05:30
SubashMohan
bc324a8070 fix(ui): fix few common ui bugs (#8425) 2026-02-20 11:54:04 +05:30
SubashMohan
26f648c24a fix(chatpage): Improve agent message layout, sidebar nesting, and icon fixes (#8224) 2026-02-20 10:49:23 +05:30
SubashMohan
638f20f5f3 fix(timeline): reduce agent message re-renders with referential stability in usePacedTurnGroups (#8265) 2026-02-20 10:49:04 +05:30
Jamison Lahman
f6ee57f523 chore(gha): rm nightly license scan workflow (#8541) 2026-02-19 20:03:58 -08:00
Justin Tahara
aae6fc7aac fix(desktop): Link clicking within App (#8493) 2026-02-19 17:44:32 -08:00
Justin Tahara
5d7a664250 fix(bedrock): Fixing toolConfig call (#8342) 2026-02-19 17:44:11 -08:00
Wenxi
e7386490bf fix(manage-users): exclude slack users from /users list (#8602) 2026-02-19 17:09:47 -08:00
Wenxi
106e10a143 fix: open_url broken on non-normalized urls and enable web crawl tests (#8508) 2026-02-19 17:09:47 -08:00
Wenxi
513f430a1b refactor: connector config refresh elements/cleanup (#8428) 2026-02-19 17:09:47 -08:00
Wenxi
696d73822f fix: remove log error when authtype is not set (#8399) 2026-02-19 17:09:47 -08:00
Wenxi
bfcc5a20a2 chore: make chatbackgrounds local assets for air-gapped envs (#8381) 2026-02-19 17:09:47 -08:00
Wenxi
efe3613354 fix: allow basic users to share agents (#8269) 2026-02-19 17:09:47 -08:00
Nikolas Garza
62405bdc42 fix(ee): small ux fixes for licensing (#8498) 2026-02-19 14:32:28 -08:00
Yuhong Sun
8f505dc45f chore: License update (No change, just touchup) (#8460) 2026-02-19 14:32:28 -08:00
Jessica Singh
75f0db4fe5 chore(bulk invite): free trial limit (#8378) 2026-02-19 14:32:28 -08:00
Nikolas Garza
f0a5c579a3 feat(auth): enforce seat limits on all user creation paths (#8401) 2026-02-19 14:32:28 -08:00
Nikolas Garza
293bf30847 fix(billing): exclude inactive users from seat counts and allow users page when gated (#8397) 2026-02-19 14:32:28 -08:00
Nikolas Garza
8774ca3b0f feat(ee): gate access only when legacy EE flag is set and no license exists (#8368) 2026-02-19 14:32:28 -08:00
Nikolas Garza
016a73f85f fix(ee): follow HTTP→HTTPS redirects in forward_to_control_plane (#8360) 2026-02-19 14:32:28 -08:00
Wenxi
2eddb4e23e fix: upgrade plan page nits (#8346) 2026-02-19 14:32:28 -08:00
Nikolas Garza
0a61660a59 fix(ee): copy license public key into Docker image (#8322) 2026-02-19 14:32:28 -08:00
Danelegend
a10599e76e fix: model config not populating flow during sync (#8542) 2026-02-18 17:11:52 -08:00
Nikolas Garza
b3d3f7af76 feat(ee): Enable license enforcement by default (#8270) 2026-02-09 20:43:33 -08:00
Jamison Lahman
03d919c918 chore(devtools): upgrade ods: 0.5.0->0.5.1 (#8279) 2026-02-09 23:31:54 +00:00
Justin Tahara
71d2ae563a fix(posthog): Chat metrics for Cloud (#8278) 2026-02-09 22:58:37 +00:00
Jamison Lahman
19f9c7357c chore(devtools): ods logs, ods pull, ods compose --force-recreate (#8277) 2026-02-09 22:51:01 +00:00
acaprau
f8fa5b243c chore(opensearch): Try to create OpenSearchTenantMigrationRecord earlier in check_for_documents_for_opensearch_migration_task (#8260) 2026-02-09 22:13:43 +00:00
dependabot[bot]
5f845c208f chore(deps-dev): bump pytest-xdist from 3.6.1 to 3.8.0 in /backend (#8120)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-09 22:06:38 +00:00
Raunak Bhagat
d8595f8de0 refactor(opal): add new Button component built on Interactive.Base (#8263)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 21:52:13 +00:00
dependabot[bot]
5b00d1ef9c chore(deps-dev): bump faker from 37.1.0 to 40.1.2 in /backend (#8126)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-09 21:46:06 +00:00
dependabot[bot]
41b6ed92a9 chore(deps): bump docker/login-action from 3.6.0 to 3.7.0 (#8275)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 21:26:07 +00:00
dependabot[bot]
07f35336ad chore(deps): bump @modelcontextprotocol/sdk from 1.25.3 to 1.26.0 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#8166)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 13:20:12 -08:00
dependabot[bot]
4728bb87c7 chore(deps): bump @isaacs/brace-expansion from 5.0.0 to 5.0.1 in /backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web (#8139)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 13:20:04 -08:00
dependabot[bot]
adfa2f30af chore(deps): bump actions/cache from 4.3.0 to 5.0.3 (#8273)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 13:19:41 -08:00
dependabot[bot]
9dac4165fb chore(deps): bump actions/setup-python from 6.1.0 to 6.2.0 (#8274)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 13:18:26 -08:00
dependabot[bot]
7d2ede5efc chore(deps): bump protobuf from 6.33.4 to 6.33.5 in /backend/requirements (#8182)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-09 21:04:34 +00:00
dependabot[bot]
4592f6885f chore(deps): bump python-multipart from 0.0.21 to 0.0.22 in /backend/requirements (#7831)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-09 20:44:54 +00:00
Evan Lohn
9dc14fad79 chore: disable hiernodes when opensearch not available (#8271) 2026-02-09 20:32:47 +00:00
dependabot[bot]
ff6e471cfb chore(deps): bump actions/setup-node from 4.4.0 to 6.2.0 (#8122)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 20:31:35 +00:00
dependabot[bot]
09b9443405 chore(deps): bump bytes from 1.11.0 to 1.11.1 in /desktop/src-tauri (#8138)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 12:34:07 -08:00
dependabot[bot]
14cd6d08e8 chore(deps): bump webpack from 5.102.1 to 5.105.0 in /web (#8199)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-09 12:18:29 -08:00
dependabot[bot]
5ee16697ce chore(deps): bump time from 0.3.44 to 0.3.47 in /desktop/src-tauri (#8187)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 12:17:40 -08:00
dependabot[bot]
b794f7e10d chore(deps): bump actions/upload-artifact from 4.6.2 to 6.0.0 (#8121)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 12:14:25 -08:00
dependabot[bot]
bb3275bb75 chore(deps): bump actions/checkout from 6.0.1 to 6.0.2 (#8123)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-09 12:13:37 -08:00
roshan
7644e225a5 fix(chrome extension): Simplify NRFPage ChatInputBar layout to use normal flex flow (#8267)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-09 18:12:54 +00:00
roshan
811600b84a fix(craft): snapshot restore (#8194) 2026-02-09 18:00:07 +00:00
Jamison Lahman
40ce8615ff fix(login): window undefined on login (#8266) 2026-02-09 17:55:05 +00:00
Justin Tahara
0cee3f6960 chore(llm): Introduce Scaffolding for Integration Tests (#8251) 2026-02-09 17:26:15 +00:00
acaprau
8883e5608f chore(chat frontend): Round up in formatDurationSeconds so we don't see "Thought for 0s" (#8259) 2026-02-09 07:54:39 +00:00
acaprau
7c2f3ded44 fix(opensearch): Tighten up task timing (#8256) 2026-02-09 07:53:44 +00:00
Raunak Bhagat
aa094ce1f0 refactor(opal): interactive base variant types + foreground color system (#8255)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 07:26:05 +00:00
260 changed files with 7365 additions and 4663 deletions

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Optional] Override Linear Check

View File

@@ -249,7 +249,7 @@ jobs:
xdg-utils
- name: setup node
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6.1.0
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6.2.0
with:
node-version: 24
package-manager-cache: false
@@ -409,7 +409,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -482,7 +482,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -542,7 +542,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -620,7 +620,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -701,7 +701,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -769,7 +769,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -844,7 +844,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -916,7 +916,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -975,7 +975,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -1053,7 +1053,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -1126,7 +1126,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -1187,7 +1187,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -1267,7 +1267,7 @@ jobs:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -1346,7 +1346,7 @@ jobs:
buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
@@ -1409,7 +1409,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}

View File

@@ -24,7 +24,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

View File

@@ -24,7 +24,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

View File

@@ -1,151 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0

View File

@@ -0,0 +1,79 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR body and check whether the helper checkbox is checked.
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
if: steps.gate.outputs.should_cherrypick == 'true'
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
run: |
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify

View File

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -40,13 +40,16 @@ jobs:
- name: Generate OpenAPI schema and Python client
shell: bash
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
env:
LICENSE_ENFORCEMENT_ENABLED: "false"
run: |
ods openapi all
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

View File

@@ -45,12 +45,12 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
with:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238
with:
node-version: 24
cache: "npm" # zizmor: ignore[cache-poisoning]
@@ -63,7 +63,7 @@ jobs:
targets: ${{ matrix.target }}
- name: Cache Cargo registry and build
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # zizmor: ignore[cache-poisoning]
uses: actions/cache@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # zizmor: ignore[cache-poisoning]
with:
path: |
~/.cargo/bin/
@@ -105,7 +105,7 @@ jobs:
- name: Upload build artifacts
if: always()
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: desktop-build-${{ matrix.platform }}-${{ github.run_id }}
path: |

View File

@@ -110,7 +110,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

View File

@@ -41,8 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"

View File

@@ -109,7 +109,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -169,7 +169,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -214,7 +214,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -287,7 +287,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -302,6 +302,8 @@ jobs:
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -466,7 +468,7 @@ jobs:
persist-credentials: false
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -478,6 +480,7 @@ jobs:
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
LICENSE_ENFORCEMENT_ENABLED=false \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \

View File

@@ -28,7 +28,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"

View File

@@ -101,7 +101,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -161,7 +161,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -220,7 +220,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -279,7 +279,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

View File

@@ -22,6 +22,9 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -90,7 +93,7 @@ jobs:
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -151,7 +154,7 @@ jobs:
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -212,7 +215,7 @@ jobs:
# needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -259,7 +262,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"
@@ -291,6 +294,8 @@ jobs:
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
@@ -305,7 +310,7 @@ jobs:
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # ratchet:docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -465,48 +470,3 @@ jobs:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
# chromatic-tests:
# name: Chromatic Tests
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
# with:
# fetch-depth: 0
# - name: Setup node
# uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v4
# with:
# node-version: 22
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
# - name: Download Playwright test results
# uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # ratchet:actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -42,6 +42,9 @@ jobs:
- name: Generate OpenAPI schema and Python client
shell: bash
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
env:
LICENSE_ENFORCEMENT_ENABLED: "false"
run: |
ods openapi all

View File

@@ -64,7 +64,7 @@ jobs:
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Login to Docker Hub
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}

View File

@@ -27,6 +27,8 @@ jobs:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
DISABLE_TELEMETRY: "true"
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED: "false"
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

View File

@@ -24,13 +24,13 @@ jobs:
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # ratchet:actions/setup-python@v6
- uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: "3.11"
- name: Setup Terraform
uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
- name: Setup node
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
with: # zizmor: ignore[cache-poisoning]
node-version: 22
cache: "npm"

View File

@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
Portions of this software are licensed as follows:
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
- backend/ee/LICENSE
- web/src/app/ee/LICENSE
- web/src/ee/LICENSE
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

View File

@@ -134,6 +134,7 @@ COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./keys /app/keys
# Escape hatch scripts
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging

View File

@@ -1,20 +1,20 @@
The DanswerAI Enterprise license (the Enterprise License)
The Onyx Enterprise License (the "Enterprise License")
Copyright (c) 2023-present DanswerAI, Inc.
With regard to the Onyx Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://onyx.app/terms (the Enterprise Terms), or other
and are in compliance with, the Onyx Subscription Terms of Service, available
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Onyx Enterprise license for the
and otherwise have a valid Onyx Enterprise License for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Onyx Enterprise license for the correct
exploited with a valid Onyx Enterprise License for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain

View File

@@ -134,7 +134,7 @@ GATED_TENANTS_KEY = "gated_tenants"
# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
)
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints

View File

@@ -263,9 +263,15 @@ def refresh_license_cache(
try:
payload = verify_license_signature(license_record.license_data)
# Derive source from payload: manual licenses lack stripe_customer_id
source: LicenseSource = (
LicenseSource.AUTO_FETCH
if payload.stripe_customer_id
else LicenseSource.MANUAL_UPLOAD
)
return update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
source=source,
tenant_id=tenant_id,
)
except ValueError as e:

View File

@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
@@ -151,12 +150,9 @@ def get_application() -> FastAPI:
# License management
include_router_with_global_prefix_prepended(application, license_router)
# Unified billing API - available when license system is enabled
# Works for both self-hosted and cloud deployments
# TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
# primary billing API and /tenants/* billing endpoints can be removed
if LICENSE_ENFORCEMENT_ENABLED:
include_router_with_global_prefix_prepended(application, billing_router)
# Unified billing API - always registered in EE.
# Each endpoint is protected by the `current_admin_user` dependency (admin auth).
include_router_with_global_prefix_prepended(application, billing_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -109,7 +109,9 @@ async def _make_billing_request(
headers = _get_headers(license_data)
try:
async with httpx.AsyncClient(timeout=_REQUEST_TIMEOUT) as client:
async with httpx.AsyncClient(
timeout=_REQUEST_TIMEOUT, follow_redirects=True
) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
else:

View File

@@ -1,9 +1,13 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -40,6 +44,14 @@ def check_ee_features_enabled() -> bool:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss — warm from DB so cold-start doesn't block EE features
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(f"Failed to load license from DB: {db_error}")
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
@@ -81,6 +93,18 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss (e.g. after TTL expiry). Fall back to DB so
# the /settings request doesn't falsely return GATED_ACCESS
# while the cache is cold.
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(
f"Failed to load license from DB for settings: {db_error}"
)
if metadata:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
@@ -89,7 +113,11 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license = community edition, disable EE features
# No license found in cache or DB.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.
settings.application_status = _BLOCKING_STATUS
settings.ee_features_enabled = False
except RedisError as e:
logger.warning(f"Failed to check license metadata for settings: {e}")

View File

@@ -177,7 +177,7 @@ async def forward_to_control_plane(
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
if method == "GET":
response = await client.get(url, headers=headers, params=params)
elif method == "POST":

View File

@@ -12,12 +12,14 @@ from ee.onyx.db.user_group import prepare_user_group_for_deletion
from ee.onyx.db.user_group import update_user_curator_relationship
from ee.onyx.db.user_group import update_user_group
from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest
from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot
from ee.onyx.server.user_group.models import SetCuratorRequest
from ee.onyx.server.user_group.models import UserGroup
from ee.onyx.server.user_group.models import UserGroupCreate
from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
@@ -45,6 +47,23 @@ def list_user_groups(
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,

View File

@@ -76,6 +76,18 @@ class UserGroup(BaseModel):
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
)
class UserGroupCreate(BaseModel):
name: str
user_ids: list[UUID]

View File

@@ -60,6 +60,7 @@ from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
@@ -110,6 +111,7 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -272,6 +274,22 @@ def verify_email_domain(email: str) -> None:
)
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
"""Raise HTTPException(402) if adding users would exceed the seat limit.
No-op for multi-tenant or CE deployments.
"""
if MULTI_TENANT:
return
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=seats_needed)
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -401,6 +419,12 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
):
user_create.role = UserRole.ADMIN
# Check seat availability for new users (single-tenant only)
with get_session_with_current_tenant() as sync_db:
existing = get_user_by_email(user_create.email, sync_db)
if existing is None:
enforce_seat_limit(sync_db)
user_created = False
try:
user = await super().create(user_create, safe=safe, request=request)
@@ -610,6 +634,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
# Check seat availability before creating (single-tenant only)
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
password = self.password_helper.generate()
user_dict = {
"email": account_email,

View File

@@ -217,9 +217,11 @@ if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
{
"name": "check-for-documents-for-opensearch-migration",
"task": OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
# Try to enqueue an invocation of this task with this frequency.
"schedule": timedelta(seconds=120), # 2 minutes
"options": {
"priority": OnyxCeleryPriority.LOW,
# If the task was not dequeued in this time, revoke it.
"expires": BEAT_EXPIRES_DEFAULT,
},
}
@@ -227,10 +229,18 @@ if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
beat_task_templates.append(
{
"name": "migrate-documents-from-vespa-to-opensearch",
"task": OnyxCeleryTask.MIGRATE_DOCUMENT_FROM_VESPA_TO_OPENSEARCH_TASK,
"task": OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
# Try to enqueue an invocation of this task with this frequency.
# NOTE: If MIGRATION_TASK_SOFT_TIME_LIMIT_S is greater than this
# value and the task is maximally busy, we can expect to see some
# enqueued tasks be revoked over time. This is ok; by erring on the
# side of "there will probably always be at least one task of this
# type in the queue", we are minimizing this task's idleness while
# still giving chances for other tasks to execute.
"schedule": timedelta(seconds=120), # 2 minutes
"options": {
"priority": OnyxCeleryPriority.LOW,
# If the task was not dequeued in this time, revoke it.
"expires": BEAT_EXPIRES_DEFAULT,
},
}

View File

@@ -0,0 +1,43 @@
# Tasks are expected to cease execution and do cleanup after the soft time
# limit. In principle they are also forceably terminated after the hard time
# limit, in practice this does not happen since we use threadpools for Celery
# task execution, and we simple hope that the total task time plus cleanup does
# not exceed this. Therefore tasks should regularly check their timeout and lock
# status. The lock timeout is the maximum time the lock manager (Redis in this
# case) will enforce the lock, independent of what is happening in the task. To
# reduce the chances that a task is still doing work while a lock has expired,
# make the lock timeout well above the task timeouts. In practice we should
# never see locks be held for this long anyway because a task should release the
# lock after its cleanup which happens at most after its soft timeout.
# Constants corresponding to migrate_documents_from_vespa_to_opensearch_task.
MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes.
MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes.
# The maximum time the lock can be held for. Will automatically be released
# after this time.
MIGRATION_TASK_LOCK_TIMEOUT_S = 60 * 7 # 7 minutes.
assert (
MIGRATION_TASK_SOFT_TIME_LIMIT_S < MIGRATION_TASK_TIME_LIMIT_S
), "The soft time limit must be less than the time limit."
assert (
MIGRATION_TASK_TIME_LIMIT_S < MIGRATION_TASK_LOCK_TIMEOUT_S
), "The time limit must be less than the lock timeout."
# Time to wait to acquire the lock.
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S = 60 * 2 # 2 minutes.
# Constants corresponding to check_for_documents_for_opensearch_migration_task.
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S = 60 # 60 seconds / 1 minute.
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S = 90 # 90 seconds.
# The maximum time the lock can be held for. Will automatically be released
# after this time.
CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S = 120 # 120 seconds / 2 minutes.
assert (
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S < CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S
), "The soft time limit must be less than the time limit."
assert (
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S < CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S
), "The time limit must be less than the lock timeout."
# Time to wait to acquire the lock.
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15

View File

@@ -1,5 +1,6 @@
"""Celery tasks for migrating documents from Vespa to OpenSearch."""
import time
import traceback
from datetime import datetime
from datetime import timezone
@@ -10,6 +11,30 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_TIMEOUT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_SOFT_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_TIME_LIMIT_S,
)
from onyx.background.celery.tasks.opensearch_migration.transformer import (
transform_vespa_chunks_to_opensearch_chunks,
)
@@ -31,6 +56,9 @@ from onyx.db.opensearch_migration import (
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit,
)
from onyx.db.opensearch_migration import should_document_migration_be_permanently_failed
from onyx.db.opensearch_migration import (
try_insert_opensearch_tenant_migration_record_with_commit,
)
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
@@ -92,10 +120,14 @@ def _migrate_single_document(
name=OnyxCeleryTask.CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK,
# Does not store the task's return value in the result backend.
ignore_result=True,
# When exceeded celery will raise a SoftTimeLimitExceeded in the task.
soft_time_limit=60 * 5, # 5 minutes.
# When exceeded the task will be forcefully terminated.
time_limit=60 * 6, # 6 minutes.
# WARNING: This is here just for rigor but since we use threads for Celery
# this config is not respected and timeout logic must be implemented in the
# task.
soft_time_limit=CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S,
# WARNING: This is here just for rigor but since we use threads for Celery
# this config is not respected and timeout logic must be implemented in the
# task.
time_limit=CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S,
# Passed in self to the task to get task metadata.
bind=True,
)
@@ -107,7 +139,11 @@ def check_for_documents_for_opensearch_migration_task(
table.
Should not execute meaningful logic at the same time as
migrate_document_from_vespa_to_opensearch_task.
migrate_documents_from_vespa_to_opensearch_task.
Effectively tries to populate as many migration records as possible within
CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of
1000 documents.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -121,29 +157,33 @@ def check_for_documents_for_opensearch_migration_task(
return None
task_logger.info("Checking for documents for OpenSearch migration.")
task_start_time = time.monotonic()
r = get_redis_client()
# Use a lock to prevent overlapping tasks. Only this task or
# migrate_document_from_vespa_to_opensearch_task can interact with the
# migrate_documents_from_vespa_to_opensearch_task can interact with the
# OpenSearchMigration table at once.
lock_beat: RedisLock = r.lock(
lock: RedisLock = r.lock(
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
# The maximum time the lock can be held for. Will automatically be
# released after this time.
timeout=60 * 6, # 6 minutes, same as the time limit for this task.
timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S,
# .acquire will block until the lock is acquired.
blocking=True,
# Wait for 2 minutes trying to acquire the lock.
blocking_timeout=60 * 2, # 2 minutes.
# Time to wait to acquire the lock.
blocking_timeout=CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
if not lock_beat.acquire():
if not lock.acquire():
task_logger.warning(
"The OpenSearch migration check task timed out waiting for the lock."
)
return None
else:
task_logger.info(
f"Acquired the OpenSearch migration check lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
f"Token: {lock.local.token}"
)
num_documents_found_for_record_creation = 0
try:
# Double check that tenant info is correct.
if tenant_id != get_current_tenant_id():
@@ -153,60 +193,84 @@ def check_for_documents_for_opensearch_migration_task(
)
task_logger.error(err_str)
return False
with get_session_with_current_tenant() as db_session:
# For pagination, get the last ID we've inserted into
# OpenSearchMigration.
last_opensearch_migration_document_id = (
get_last_opensearch_migration_document_id(db_session)
)
# Now get the next batch of doc IDs starting after the last ID.
document_ids = get_paginated_document_batch(
db_session,
prev_ending_document_id=last_opensearch_migration_document_id,
)
if not document_ids:
task_logger.info(
"No more documents to insert for OpenSearch migration."
while (
time.monotonic() - task_start_time
< CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
with get_session_with_current_tenant() as db_session:
# For pagination, get the last ID we've inserted into
# OpenSearchMigration.
last_opensearch_migration_document_id = (
get_last_opensearch_migration_document_id(db_session)
)
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
db_session
# Now get the next batch of doc IDs starting after the last ID.
# We'll do 1000 documents per transaction/timeout check.
document_ids = get_paginated_document_batch(
db_session,
limit=1000,
prev_ending_document_id=last_opensearch_migration_document_id,
)
# TODO(andrei): Once we've done this enough times and the number
# of documents matches the number of migration records, we can
# be done with this task and update
# document_migration_record_table_population_status.
return True
# Create the migration records for the next batch of documents with
# status PENDING.
create_opensearch_migration_records_with_commit(db_session, document_ids)
task_logger.info(
f"Created {len(document_ids)} migration records for the next batch of documents."
)
if not document_ids:
task_logger.info(
"No more documents to insert for OpenSearch migration."
)
increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit(
db_session
)
# TODO(andrei): Once we've done this enough times and the
# number of documents matches the number of migration
# records, we can be done with this task and update
# document_migration_record_table_population_status.
return True
# Create the migration records for the next batch of documents
# with status PENDING.
create_opensearch_migration_records_with_commit(
db_session, document_ids
)
num_documents_found_for_record_creation += len(document_ids)
# Try to create the singleton row in
# OpenSearchTenantMigrationRecord if it doesn't already exist.
# This is a reasonable place to put it because we already have a
# lock, a session, and error handling, at the cost of running
# this small set of logic for every batch.
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
except Exception:
task_logger.exception("Error in the OpenSearch migration check task.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
if lock.owned():
lock.release()
else:
task_logger.warning(
"The OpenSearch migration lock was not owned on completion of the check task."
)
task_logger.info(
f"Finished checking for documents for OpenSearch migration. Found {num_documents_found_for_record_creation} documents "
f"to create migration records for in {time.monotonic() - task_start_time:.3f} seconds. However, this may include "
"documents for which there already exist records."
)
return True
# shared_task allows this task to be shared across celery app instances.
@shared_task(
name=OnyxCeleryTask.MIGRATE_DOCUMENT_FROM_VESPA_TO_OPENSEARCH_TASK,
name=OnyxCeleryTask.MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK,
# Does not store the task's return value in the result backend.
ignore_result=True,
# When exceeded celery will raise a SoftTimeLimitExceeded in the task.
soft_time_limit=60 * 5, # 5 minutes.
# When exceeded the task will be forcefully terminated.
time_limit=60 * 6, # 6 minutes.
# WARNING: This is here just for rigor but since we use threads for Celery
# this config is not respected and timeout logic must be implemented in the
# task.
soft_time_limit=MIGRATION_TASK_SOFT_TIME_LIMIT_S,
# WARNING: This is here just for rigor but since we use threads for Celery
# this config is not respected and timeout logic must be implemented in the
# task.
time_limit=MIGRATION_TASK_TIME_LIMIT_S,
# Passed in self to the task to get task metadata.
bind=True,
)
@@ -220,10 +284,13 @@ def migrate_documents_from_vespa_to_opensearch_task(
Should not execute meaningful logic at the same time as
check_for_documents_for_opensearch_migration_task.
Effectively tries to migrate as many documents as possible within
MIGRATION_TASK_SOFT_TIME_LIMIT_S seconds. Does so in batches of 5 documents.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
acquired; effectively a no-op. True if the task completed
successfully. False if the task failed.
successfully. False if the task errored.
"""
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
task_logger.warning(
@@ -231,30 +298,36 @@ def migrate_documents_from_vespa_to_opensearch_task(
)
return None
task_logger.info("Trying to migrate documents from Vespa to OpenSearch.")
task_logger.info("Trying a migration batch from Vespa to OpenSearch.")
task_start_time = time.monotonic()
r = get_redis_client()
# Use a lock to prevent overlapping tasks. Only this task or
# check_for_documents_for_opensearch_migration_task can interact with the
# OpenSearchMigration table at once.
lock_beat: RedisLock = r.lock(
lock: RedisLock = r.lock(
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
# The maximum time the lock can be held for. Will automatically be
# released after this time.
timeout=60 * 6, # 6 minutes, same as the time limit for this task.
timeout=MIGRATION_TASK_LOCK_TIMEOUT_S,
# .acquire will block until the lock is acquired.
blocking=True,
# Wait for 2 minutes trying to acquire the lock.
blocking_timeout=60 * 2, # 2 minutes.
# Time to wait to acquire the lock.
blocking_timeout=MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
if not lock_beat.acquire():
if not lock.acquire():
task_logger.warning(
"The OpenSearch migration task timed out waiting for the lock."
)
return None
else:
task_logger.info(
f"Acquired the OpenSearch migration lock. Took {time.monotonic() - task_start_time:.3f} seconds. "
f"Token: {lock.local.token}"
)
num_documents_migrated = 0
num_chunks_migrated = 0
num_documents_failed = 0
try:
# Double check that tenant info is correct.
if tenant_id != get_current_tenant_id():
@@ -264,98 +337,111 @@ def migrate_documents_from_vespa_to_opensearch_task(
)
task_logger.error(err_str)
return False
with get_session_with_current_tenant() as db_session:
records_needing_migration = (
get_opensearch_migration_records_needing_migration(db_session)
)
if not records_needing_migration:
task_logger.info(
"No documents found that need to be migrated from Vespa to OpenSearch."
)
increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
db_session
)
# TODO(andrei): Once we've done this enough times and
# document_migration_record_table_population_status is done, we
# can be done with this task and update
# overall_document_migration_status accordingly. Note that this
# includes marking connectors as needing reindexing if some
# migrations failed.
return True
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,
tenant_state=tenant_state,
large_chunks_enabled=False,
)
task_logger.info(
f"Trying to migrate {len(records_needing_migration)} documents from Vespa to OpenSearch."
)
for record in records_needing_migration:
try:
# If the Document's chunk count is not known, it was
# probably just indexed so fail here to give it a chance to
# sync. If in the rare event this Document has not been
# re-indexed in a very long time and is still under the
# "old" embedding/indexing logic where chunk count was never
# stored, we will eventually permanently fail and thus force
# a re-index of this doc, which is a desireable outcome.
if record.document.chunk_count is None:
raise RuntimeError(
f"Document {record.document_id} has no chunk count."
)
chunks_migrated = _migrate_single_document(
document_id=record.document_id,
opensearch_document_index=opensearch_document_index,
vespa_document_index=vespa_document_index,
tenant_state=tenant_state,
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
with get_session_with_current_tenant() as db_session:
# We'll do 5 documents per transaction/timeout check.
records_needing_migration = (
get_opensearch_migration_records_needing_migration(
db_session, limit=5
)
# If the number of chunks in Vespa is not in sync with the
# Document table for this doc let's not consider this
# completed and let's let a subsequent run take care of it.
if chunks_migrated != record.document.chunk_count:
raise RuntimeError(
f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks in Vespa "
f"({record.document.chunk_count}) for document {record.document_id}."
)
record.status = OpenSearchDocumentMigrationStatus.COMPLETED
except Exception:
record.status = OpenSearchDocumentMigrationStatus.FAILED
record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
task_logger.exception(
f"Error migrating document {record.document_id} from Vespa to OpenSearch."
)
if not records_needing_migration:
task_logger.info(
"No documents found that need to be migrated from Vespa to OpenSearch."
)
finally:
record.attempts_count += 1
record.last_attempt_at = datetime.now(timezone.utc)
if should_document_migration_be_permanently_failed(record):
record.status = (
OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
)
# TODO(andrei): Not necessarily here but if this happens
# we'll need to mark the connector as needing reindex.
increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
db_session
)
# TODO(andrei): Once we've done this enough times and
# document_migration_record_table_population_status is done, we
# can be done with this task and update
# overall_document_migration_status accordingly. Note that this
# includes marking connectors as needing reindexing if some
# migrations failed.
return True
db_session.commit()
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(
tenant_id=tenant_id, multitenant=MULTI_TENANT
)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,
tenant_state=tenant_state,
large_chunks_enabled=False,
)
for record in records_needing_migration:
try:
# If the Document's chunk count is not known, it was
# probably just indexed so fail here to give it a chance to
# sync. If in the rare event this Document has not been
# re-indexed in a very long time and is still under the
# "old" embedding/indexing logic where chunk count was never
# stored, we will eventually permanently fail and thus force
# a re-index of this doc, which is a desireable outcome.
if record.document.chunk_count is None:
raise RuntimeError(
f"Document {record.document_id} has no chunk count."
)
chunks_migrated = _migrate_single_document(
document_id=record.document_id,
opensearch_document_index=opensearch_document_index,
vespa_document_index=vespa_document_index,
tenant_state=tenant_state,
)
# If the number of chunks in Vespa is not in sync with the
# Document table for this doc let's not consider this
# completed and let's let a subsequent run take care of it.
if chunks_migrated != record.document.chunk_count:
raise RuntimeError(
f"Number of chunks migrated ({chunks_migrated}) does not match number of expected chunks "
f"in Vespa ({record.document.chunk_count}) for document {record.document_id}."
)
record.status = OpenSearchDocumentMigrationStatus.COMPLETED
num_documents_migrated += 1
num_chunks_migrated += chunks_migrated
except Exception:
record.status = OpenSearchDocumentMigrationStatus.FAILED
record.error_message = f"Attempt {record.attempts_count + 1}:\n{traceback.format_exc()}"
task_logger.exception(
f"Error migrating document {record.document_id} from Vespa to OpenSearch."
)
num_documents_failed += 1
finally:
record.attempts_count += 1
record.last_attempt_at = datetime.now(timezone.utc)
if should_document_migration_be_permanently_failed(record):
record.status = (
OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
)
# TODO(andrei): Not necessarily here but if this happens
# we'll need to mark the connector as needing reindex.
db_session.commit()
except Exception:
task_logger.exception("Error in the OpenSearch migration task.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
if lock.owned():
lock.release()
else:
task_logger.warning(
"The OpenSearch migration lock was not owned on completion of the migration task."
)
task_logger.info(
f"Finished a migration batch from Vespa to OpenSearch. Migrated {num_chunks_migrated} chunks "
f"from {num_documents_migrated} documents in {time.monotonic() - task_start_time:.3f} seconds. "
f"Failed to migrate {num_documents_failed} documents."
)
return True

View File

@@ -12,6 +12,7 @@ from retry import retry
from sqlalchemy import select
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import MANAGED_VESPA
@@ -19,12 +20,14 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -54,6 +57,17 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -117,7 +131,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -132,7 +163,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -145,12 +190,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -158,7 +226,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -175,6 +244,12 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -22,11 +22,13 @@ from onyx.chat.prompt_utils import build_system_prompt
from onyx.chat.prompt_utils import (
get_default_base_system_prompt,
)
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.models import Persona
from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
@@ -36,6 +38,7 @@ from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ToolCallDebug
from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
@@ -57,6 +60,28 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
"""Bedrock requires tool config when history includes toolUse/toolResult blocks."""
model_provider = getattr(getattr(llm, "config", None), "model_provider", None)
if model_provider not in {
LlmProviderNames.BEDROCK,
LlmProviderNames.BEDROCK_CONVERSE,
}:
return False
return any(
(
msg.message_type == MessageType.ASSISTANT
and msg.tool_calls
and len(msg.tool_calls) > 0
)
or msg.message_type == MessageType.TOOL_CALL_RESPONSE
for msg in simple_chat_history
)
def _try_fallback_tool_extraction(
llm_step_result: LlmStepResult,
tool_choice: ToolChoiceOptions,
@@ -452,7 +477,12 @@ def run_llm_loop(
elif out_of_cycles or ran_image_gen:
# Last cycle, no tools allowed, just answer!
tool_choice = ToolChoiceOptions.NONE
final_tools = []
# Bedrock requires tool config in requests that include toolUse/toolResult history.
final_tools = (
tools
if _should_keep_bedrock_tool_definitions(llm, simple_chat_history)
else []
)
else:
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
@@ -601,6 +631,19 @@ def run_llm_loop(
tool_responses: list[ToolResponse] = []
tool_calls = llm_step_result.tool_calls or []
if INTEGRATION_TESTS_MODE and tool_calls:
for tool_call in tool_calls:
emitter.emit(
Packet(
placement=tool_call.placement,
obj=ToolCallDebug(
tool_call_id=tool_call.tool_call_id,
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_args,
),
)
)
if len(tool_calls) > 1:
emitter.emit(
Packet(

View File

@@ -7,6 +7,7 @@ import re
import time
import traceback
from collections.abc import Callable
from contextvars import Token
from uuid import UUID
from redis.client import Redis
@@ -43,6 +44,7 @@ from onyx.chat.prompt_utils import calculate_reserved_tokens
from onyx.chat.save_chat import save_chat_turn
from onyx.chat.stop_signal_checker import is_connected as check_stop_signal
from onyx.chat.stop_signal_checker import reset_cancel_status
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
@@ -69,6 +71,8 @@ from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.request_context import reset_llm_mock_response
from onyx.llm.request_context import set_llm_mock_response
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
@@ -90,10 +94,6 @@ from onyx.tools.tool_constructor import SearchToolConfig
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -318,6 +318,7 @@ def handle_stream_message_objects(
) -> AnswerStream:
tenant_id = get_current_tenant_id()
processing_start_time = time.monotonic()
mock_response_token: Token[str | None] | None = None
llm: LLM | None = None
chat_session: ChatSession | None = None
@@ -328,6 +329,14 @@ def handle_stream_message_objects(
llm_user_identifier = "anonymous_user"
else:
llm_user_identifier = user.email or str(user_id)
if new_msg_req.mock_llm_response is not None:
if not INTEGRATION_TESTS_MODE:
raise ValueError(
"mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true"
)
mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response)
try:
if not new_msg_req.chat_session_id:
if not new_msg_req.chat_session_info:
@@ -361,21 +370,16 @@ def handle_stream_message_objects(
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
# Track user message in PostHog for analytics
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email if not user.is_anonymous else tenant_id,
event="user_message_sent",
event=MilestoneRecordType.USER_MESSAGE_SENT,
properties={
"origin": new_msg_req.origin.value,
"has_files": len(new_msg_req.file_descriptors) > 0,
"has_project": chat_session.project_id is not None,
"has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID,
"deep_research": new_msg_req.deep_research,
"tenant_id": tenant_id,
},
)
@@ -723,6 +727,9 @@ def handle_stream_message_objects(
db_session.rollback()
finally:
if mock_response_token is not None:
reset_llm_mock_response(mock_response_token)
try:
if redis_client is not None and chat_session is not None:
set_processing_status(
@@ -839,6 +846,7 @@ def stream_chat_message_objects(
translated_new_msg_req = SendMessageRequest(
message=new_msg_req.message,
llm_override=new_msg_req.llm_override,
mock_llm_response=new_msg_req.mock_llm_response,
allowed_tool_ids=new_msg_req.allowed_tool_ids,
forced_tool_id=forced_tool_id,
file_descriptors=new_msg_req.file_descriptors,

View File

@@ -75,7 +75,7 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
# Auth Configs
#####
# Upgrades users from disabled auth to basic auth and shows warning.
_auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower()
_auth_type_str = (os.environ.get("AUTH_TYPE") or "basic").lower()
if _auth_type_str == "disabled":
logger.warning(
"AUTH_TYPE='disabled' is no longer supported. "
@@ -900,6 +900,9 @@ MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Limit on number of users a free trial tenant can invite (cloud only)
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))
# Security and authentication
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""

View File

@@ -158,6 +158,17 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -351,6 +362,7 @@ class MilestoneRecordType(str, Enum):
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
RAN_QUERY = "ran_query"
USER_MESSAGE_SENT = "user_message_sent"
MULTIPLE_ASSISTANTS = "multiple_assistants"
CREATED_ASSISTANT = "created_assistant"
CREATED_ONYX_BOT = "created_onyx_bot"
@@ -434,6 +446,9 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
@@ -573,8 +588,8 @@ class OnyxCeleryTask:
CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK = (
"check_for_documents_for_opensearch_migration_task"
)
MIGRATE_DOCUMENT_FROM_VESPA_TO_OPENSEARCH_TASK = (
"migrate_document_from_vespa_to_opensearch_task"
MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK = (
"migrate_documents_from_vespa_to_opensearch_task"
)

View File

@@ -32,6 +32,8 @@ class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
# Gong API limit is 3 calls/sec — stay safely under it
MIN_REQUEST_INTERVAL = 0.5 # seconds between requests
def __init__(
self,
@@ -45,9 +47,13 @@ class GongConnector(LoadConnector, PollConnector):
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
self._last_request_time: float = 0.0
# urllib3 Retry already respects the Retry-After header by default
# (respect_retry_after_header=True), so on 429 it will sleep for the
# duration Gong specifies before retrying.
retry_strategy = Retry(
total=5,
total=10,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
@@ -61,8 +67,24 @@ class GongConnector(LoadConnector, PollConnector):
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
def _throttled_request(
self, method: str, url: str, **kwargs: Any
) -> requests.Response:
"""Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between
calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s."""
now = time.monotonic()
elapsed = now - self._last_request_time
if elapsed < self.MIN_REQUEST_INTERVAL:
time.sleep(self.MIN_REQUEST_INTERVAL - elapsed)
response = self._session.request(method, url, **kwargs)
self._last_request_time = time.monotonic()
return response
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
response = self._throttled_request(
"GET", GongConnector.make_url("/v2/workspaces")
)
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -106,8 +128,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/transcript"), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -142,8 +164,8 @@ class GongConnector(LoadConnector, PollConnector):
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
response = self._throttled_request(
"POST", GongConnector.make_url("/v2/calls/extensive"), json=body
)
response.raise_for_status()
@@ -194,7 +216,8 @@ class GongConnector(LoadConnector, PollConnector):
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the id's
# retry with exponential backoff has been observed to mitigate this
# in ~2 minutes
# in ~2 minutes. After max attempts, proceed with whatever we have —
# the per-call loop below will skip missing IDs gracefully.
current_attempt = 0
while True:
current_attempt += 1
@@ -213,11 +236,14 @@ class GongConnector(LoadConnector, PollConnector):
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
logger.error(
f"Giving up on missing call id's after "
f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: "
f"missing_call_ids={missing_call_ids}"
f"proceeding with {len(call_details_map)} of "
f"{len(transcript_call_ids)} calls"
)
break
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(

View File

@@ -6,6 +6,7 @@ from typing import cast
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from onyx.access.models import ExternalAccess
@@ -167,6 +168,14 @@ class DocumentBase(BaseModel):
# list of strings.
metadata: dict[str, str | list[str]]
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
    """Coerce each metadata value to str (or list[str]) before validation."""
    coerced: dict[str, str | list[str]] = {}
    for key, value in v.items():
        if isinstance(value, list):
            coerced[key] = [str(element) for element in value]
        else:
            coerced[key] = str(value)
    return coerced
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None

View File

@@ -228,14 +228,13 @@ class BuildSessionStatus(str, PyEnum):
class SandboxStatus(str, PyEnum):
PROVISIONING = "provisioning"
RUNNING = "running"
IDLE = "idle"
SLEEPING = "sleeping" # Pod terminated, snapshots saved to S3
TERMINATED = "terminated"
FAILED = "failed"
def is_active(self) -> bool:
    """Check if sandbox is in an active state (running)."""
    # NOTE: the overlaid old variant also counted IDLE as active; with the
    # SLEEPING state added, only a RUNNING sandbox is considered active.
    return self == SandboxStatus.RUNNING
def is_terminal(self) -> bool:
"""Check if sandbox is in a terminal state."""

View File

@@ -109,45 +109,38 @@ def can_user_access_llm_provider(
is_admin: If True, bypass user group restrictions but still respect persona restrictions
Access logic:
1. If is_public=True → everyone has access (public override)
2. If is_public=False:
- Both groups AND personas set → must satisfy BOTH (AND logic, admins bypass group check)
- Only groups set → must be in one of the groups (OR across groups, admins bypass)
- Only personas set → must use one of the personas (OR across personas, applies to admins)
- Neither set → NOBODY has access unless admin (locked, admin-only)
- is_public controls USER access (group bypass): when True, all users can access
regardless of group membership. When False, user must be in a whitelisted group
(or be admin).
- Persona restrictions are ALWAYS enforced when set, regardless of is_public.
This allows admins to make a provider available to all users while still
restricting which personas (assistants) can use it.
Decision matrix:
1. is_public=True, no personas set → everyone has access
2. is_public=True, personas set → all users, but only whitelisted personas
3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups)
4. is_public=False, only groups set → must be in group (admins bypass)
5. is_public=False, only personas set → must use whitelisted persona
6. is_public=False, neither set → admin-only (locked)
"""
# Public override - everyone has access
if provider.is_public:
return True
# Extract IDs once to avoid multiple iterations
provider_group_ids = (
{group.id for group in provider.groups} if provider.groups else set()
)
provider_persona_ids = (
{p.id for p in provider.personas} if provider.personas else set()
)
provider_group_ids = {g.id for g in (provider.groups or [])}
provider_persona_ids = {p.id for p in (provider.personas or [])}
has_groups = bool(provider_group_ids)
has_personas = bool(provider_persona_ids)
# Both groups AND personas set → AND logic (must satisfy both)
if has_groups and has_personas:
# Admins bypass group check but still must satisfy persona restrictions
user_in_group = is_admin or bool(user_group_ids & provider_group_ids)
persona_allowed = persona.id in provider_persona_ids if persona else False
return user_in_group and persona_allowed
# Persona restrictions are always enforced when set, regardless of is_public
if has_personas and not (persona and persona.id in provider_persona_ids):
return False
if provider.is_public:
return True
# Only groups set → user must be in one of the groups (admins bypass)
if has_groups:
return is_admin or bool(user_group_ids & provider_group_ids)
# Only personas set → persona must be in allowed list (applies to admins too)
if has_personas:
return persona.id in provider_persona_ids if persona else False
# Neither groups nor personas set, and not public → admins can access
return is_admin
# No groups: either persona-whitelisted (already passed) or admin-only if locked
return has_personas or is_admin
def validate_persona_ids_exist(
@@ -428,7 +421,7 @@ def fetch_existing_models(
def fetch_existing_llm_providers(
db_session: Session,
flow_types: list[LLMModelFlowType],
flow_type_filter: list[LLMModelFlowType],
only_public: bool = False,
exclude_image_generation_providers: bool = True,
) -> list[LLMProviderModel]:
@@ -436,30 +429,27 @@ def fetch_existing_llm_providers(
Args:
db_session: Database session
flow_types: List of flow types to filter by
flow_type_filter: List of flow types to filter by, empty list for no filter
only_public: If True, only return public providers
exclude_image_generation_providers: If True, exclude providers that are
used for image generation configs
"""
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_types))
.distinct()
)
stmt = select(LLMProviderModel)
if flow_type_filter:
providers_with_flows = (
select(ModelConfiguration.llm_provider_id)
.join(LLMModelFlow)
.where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter))
.distinct()
)
stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows))
if exclude_image_generation_providers:
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
)
else:
image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join(
ImageGenerationConfig
)
stmt = select(LLMProviderModel).where(
LLMProviderModel.id.in_(providers_with_flows)
| LLMProviderModel.id.in_(image_gen_provider_ids)
)
stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids))
stmt = stmt.options(
selectinload(LLMProviderModel.model_configurations),
@@ -722,13 +712,15 @@ def sync_auto_mode_models(
changes += 1
else:
# Add new model - all models from GitHub config are visible
new_model = ModelConfiguration(
insert_new_model_configuration__no_commit(
db_session=db_session,
llm_provider_id=provider.id,
name=model_config.name,
display_name=model_config.display_name,
model_name=model_config.name,
supported_flows=[LLMModelFlowType.CHAT],
is_visible=True,
max_input_tokens=None,
display_name=model_config.display_name,
)
db_session.add(new_model)
changes += 1
# In Auto mode, default model is always set from GitHub config

View File

@@ -9,6 +9,9 @@ from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
from onyx.db.enums import OpenSearchDocumentMigrationStatus
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
@@ -18,18 +21,21 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_MIGRATE = 500
DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_CHECK_FOR_MIGRATION = 2000
def get_paginated_document_batch(
db_session: Session,
limit: int = DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_CHECK_FOR_MIGRATION,
limit: int,
prev_ending_document_id: str | None = None,
) -> list[str]:
"""Gets a paginated batch of document IDs from the Document table.
We need some deterministic ordering to ensure that we don't miss any
documents when paginating. This function uses the document ID. It is
possible a document is inserted above a spot this function has already
passed. In that event we assume that the document will be indexed into
OpenSearch anyway and we don't need to migrate.
TODO(andrei): Consider ordering on last_modified in addition to ID to better
match get_opensearch_migration_records_needing_migration.
Args:
db_session: SQLAlchemy session.
limit: Number of document IDs to fetch.
@@ -91,7 +97,7 @@ def create_opensearch_migration_records_with_commit(
def get_opensearch_migration_records_needing_migration(
db_session: Session,
limit: int = DEFAULT_BATCH_SIZE_OF_DOCUMENTS_TO_MIGRATE,
limit: int,
) -> list[OpenSearchDocumentMigrationRecord]:
"""Gets records of documents that need to be migrated.
@@ -165,6 +171,20 @@ def get_total_document_count(db_session: Session) -> int:
return db_session.query(Document).count()
def try_insert_opensearch_tenant_migration_record_with_commit(
    db_session: Session,
) -> None:
    """Tries to insert the singleton row on OpenSearchTenantMigrationRecord.

    If the row already exists, does nothing.
    """
    # ON CONFLICT DO NOTHING keyed on the constant-true expression index keeps
    # this idempotent: concurrent callers cannot create a second row.
    insert_stmt = insert(OpenSearchTenantMigrationRecord).on_conflict_do_nothing(
        index_elements=[text("(true)")]
    )
    db_session.execute(insert_stmt)
    db_session.commit()
def increment_num_times_observed_no_additional_docs_to_migrate_with_commit(
db_session: Session,
) -> None:

View File

@@ -21,6 +21,7 @@ from onyx.llm.model_response import ModelResponseStream
from onyx.llm.model_response import Usage
from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET
from onyx.llm.models import OPENAI_REASONING_EFFORT
from onyx.llm.request_context import get_llm_mock_response
from onyx.llm.utils import build_litellm_passthrough_kwargs
from onyx.llm.utils import is_true_openai_model
from onyx.llm.utils import model_is_reasoning_model
@@ -378,7 +379,7 @@ class LitellmLLM(LLM):
passthrough_kwargs["api_key"] = self._api_key or None
response = litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,

View File

@@ -0,0 +1,18 @@
import contextvars
_LLM_MOCK_RESPONSE_CONTEXTVAR: contextvars.ContextVar[str | None] = (
contextvars.ContextVar("llm_mock_response", default=None)
)
def get_llm_mock_response() -> str | None:
return _LLM_MOCK_RESPONSE_CONTEXTVAR.get()
def set_llm_mock_response(mock_response: str | None) -> contextvars.Token[str | None]:
return _LLM_MOCK_RESPONSE_CONTEXTVAR.set(mock_response)
def reset_llm_mock_response(token: contextvars.Token[str | None]) -> None:
_LLM_MOCK_RESPONSE_CONTEXTVAR.reset(token)

View File

@@ -592,11 +592,8 @@ def build_slack_response_blocks(
)
citations_blocks = []
document_blocks = []
if answer.citation_info:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
@@ -608,7 +605,6 @@ def build_slack_response_blocks(
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block

View File

@@ -1,5 +1,9 @@
import re
from enum import Enum
# Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name>
# Group 1 is the channel ID; group 2 (optional) is the embedded channel name.
SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>")

# action_id values for Slack interactive block elements (feedback buttons etc.)
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
SHOW_EVERYONE_ACTION_ID = "show-everyone"

View File

@@ -1,29 +1,163 @@
import re
from collections.abc import Callable
from typing import Any
from mistune import create_markdown
from mistune import HTMLRenderer
# Tags that should be replaced with a newline (line-break and block-level elements)
_HTML_NEWLINE_TAG_PATTERN = re.compile(
    r"<br\s*/?>|</(?:p|div|li|h[1-6]|tr|blockquote|section|article)>",
    re.IGNORECASE,
)

# Strips HTML tags but excludes autolinks like <https://...> and <mailto:...>
# (the negative lookahead rejects scheme-prefixed autolink bodies)
_HTML_TAG_PATTERN = re.compile(
    r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>",
)

# Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them
_FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```")

# Matches the start of any markdown link: [text]( or [[n]](
# The inner group handles nested brackets for citation links like [[1]](.
# The URL itself is parsed separately by _extract_link_destination().
_MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(")

# Matches Slack-style links <url|text> that LLMs sometimes output directly.
# Mistune doesn't recognise this syntax, so text() would escape the angle
# brackets and Slack would render them as literal text instead of links.
_SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>")
def _sanitize_html(text: str) -> str:
    """Strip HTML tags from a text fragment.

    Block-level closing tags and <br> are converted to newlines so document
    structure survives; every other HTML tag is removed entirely. Autolinks
    (<https://...>, <mailto:...>) are preserved by the tag pattern itself.
    """
    with_newlines = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text)
    return _HTML_TAG_PATTERN.sub("", with_newlines)
def _transform_outside_code_blocks(
    message: str, transform: Callable[[str], str]
) -> str:
    """Apply *transform* only to text outside fenced code blocks.

    split() yields the non-code segments and findall() yields the fenced
    blocks between them; the fenced blocks are re-interleaved untouched.
    """
    segments = _FENCED_CODE_BLOCK_PATTERN.split(message)
    fenced_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message)
    pieces: list[str] = []
    for idx, segment in enumerate(segments):
        pieces.append(transform(segment))
        if idx < len(fenced_blocks):
            pieces.append(fenced_blocks[idx])
    return "".join(pieces)
def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]:
"""Extract markdown link destination, allowing nested parentheses in the URL."""
depth = 0
i = start_idx
while i < len(message):
curr = message[i]
if curr == "\\":
i += 2
continue
if curr == "(":
depth += 1
elif curr == ")":
if depth == 0:
return message[start_idx:i], i
depth -= 1
i += 1
return message[start_idx:], None
def _normalize_link_destinations(message: str) -> str:
    """Wrap markdown link URLs in angle brackets so the parser handles special chars safely.

    Markdown link syntax [text](url) breaks when the URL contains unescaped
    parentheses, spaces, or other special characters. Wrapping the URL in angle
    brackets — [text](<url>) — tells the parser to treat everything inside as
    a literal URL. This applies to all links, not just citations.
    """
    # Cheap bail-out: no "](" means no markdown links at all.
    if "](" not in message:
        return message
    normalized_parts: list[str] = []
    cursor = 0
    # Walk link-by-link; everything between links is copied through verbatim.
    while match := _MARKDOWN_LINK_PATTERN.search(message, cursor):
        # Copy up to and including the "[text](" opener.
        normalized_parts.append(message[cursor : match.end()])
        destination_start = match.end()
        destination, end_idx = _extract_link_destination(message, destination_start)
        if end_idx is None:
            # Unterminated destination: emit the tail unchanged and stop.
            normalized_parts.append(message[destination_start:])
            return "".join(normalized_parts)
        already_wrapped = destination.startswith("<") and destination.endswith(">")
        if destination and not already_wrapped:
            destination = f"<{destination}>"
        normalized_parts.append(destination)
        normalized_parts.append(")")
        # Resume scanning just past the closing paren.
        cursor = end_idx + 1
    normalized_parts.append(message[cursor:])
    return "".join(normalized_parts)
def _convert_slack_links_to_markdown(message: str) -> str:
    """Convert Slack-style <url|text> links to standard markdown [text](url).

    LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't
    recognise it, so the angle brackets would be escaped by text() and Slack
    would render the link as literal text instead of a clickable link.
    """

    def _rewrite(segment: str) -> str:
        # \1 is the URL, \2 the label: <url|text> -> [text](url)
        return _SLACK_LINK_PATTERN.sub(r"[\2](\1)", segment)

    return _transform_outside_code_blocks(message, _rewrite)
def format_slack_message(message: str | None) -> str:
    """Render an LLM markdown message as Slack mrkdwn.

    Pipeline: strip raw HTML (outside code fences), convert Slack-style
    <url|text> links to markdown, wrap link destinations in angle brackets,
    then render with mistune using SlackRenderer (tables enabled).
    """
    if message is None:
        return ""
    # Fix: removed the stale diff-overlay lines that built a throwaway
    # markdown instance without the "table" plugin and returned before rstrip.
    message = _transform_outside_code_blocks(message, _sanitize_html)
    message = _convert_slack_links_to_markdown(message)
    normalized_message = _normalize_link_destinations(message)
    md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"])
    result = md(normalized_message)
    # With HTMLRenderer, result is always str (not AST list)
    assert isinstance(result, str)
    # Trim trailing blank lines added by block renderers.
    return result.rstrip("\n")
class SlackRenderer(HTMLRenderer):
    """Renders markdown as Slack mrkdwn format instead of HTML.

    Overrides all HTMLRenderer methods that produce HTML tags to ensure
    no raw HTML ever appears in Slack messages.
    """

    # Only the three entities Slack's mrkdwn actually recognizes
    SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}

    def __init__(self) -> None:
        super().__init__()
        # Mutable state used by the table_* callbacks while mistune walks a table
        self._table_headers: list[str] = []
        self._current_row_cells: list[str] = []

    def escape_special(self, text: str) -> str:
        """Escape &, <, > so Slack renders them as literal characters."""
        for special, replacement in self.SPECIALS.items():
            text = text.replace(special, replacement)
        return text
def heading(self, text: str, level: int, **attrs: Any) -> str:  # noqa: ARG002
    """Slack has no heading sizes; render any heading level as bold text.

    The trailing blank line separates the heading from the following block.
    """
    return f"*{text}*\n\n"
def emphasis(self, text: str) -> str:
    """Markdown *emphasis* -> Slack italics (_text_)."""
    return f"_{text}_"
@@ -42,7 +176,7 @@ class SlackRenderer(HTMLRenderer):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
return "\n".join(lines) + "\n"
def list_item(self, text: str) -> str:
    # "li: " is a 4-character sentinel prefix; the list() renderer strips it
    # (via line[4:]) and substitutes a bullet or "N. " numbering.
    return f"li: {text}\n"
@@ -64,7 +198,73 @@ class SlackRenderer(HTMLRenderer):
return f"`{text}`"
def block_code(self, code: str, info: str | None = None) -> str:  # noqa: ARG002
    """Render a fenced code block; the language info string is ignored.

    The trailing newline mistune includes in *code* is stripped so no blank
    line appears inside the fence; chr(10) avoids a backslash in the f-string.
    """
    return f"```\n{code.rstrip(chr(10))}\n```\n\n"
def linebreak(self) -> str:
    """Hard line break -> plain newline (Slack has no <br>)."""
    return "\n"

def thematic_break(self) -> str:
    """Horizontal rule -> literal '---' line (Slack has no hr element)."""
    return "---\n\n"
def block_quote(self, text: str) -> str:
    """Prefix every line with '>' to form a Slack quote block."""
    quoted_lines = [f">{line}" for line in text.strip().split("\n")]
    return "\n".join(quoted_lines) + "\n\n"
def block_html(self, html: str) -> str:
    """Raw HTML blocks are stripped of tags rather than passed through."""
    return _sanitize_html(html) + "\n\n"

def block_error(self, text: str) -> str:
    """Render parser error text inside a code fence so it is visible but inert."""
    return f"```\n{text}\n```\n\n"

def text(self, text: str) -> str:
    # Only escape the three entities Slack recognizes: & < >
    # HTMLRenderer.text() also escapes " to &quot; which Slack renders
    # as literal &quot; text since Slack doesn't recognize that entity.
    return self.escape_special(text)
# -- Table rendering (converts markdown tables to vertical cards) --
# mistune's table plugin calls these bottom-up: table_cell for each cell,
# table_row per row, then table_head/table_body/table. Cell text is buffered
# on the instance and the real output is assembled in table_row().

def table_cell(
    self, text: str, align: str | None = None, head: bool = False  # noqa: ARG002
) -> str:
    """Buffer the cell text; header cells and body cells go to separate lists."""
    if head:
        self._table_headers.append(text.strip())
    else:
        self._current_row_cells.append(text.strip())
    return ""

def table_head(self, text: str) -> str:  # noqa: ARG002
    """Reset the row buffer; header cells were already captured via table_cell()."""
    self._current_row_cells = []
    return ""

def table_row(self, text: str) -> str:  # noqa: ARG002
    """Emit one row as a vertical card: bold title plus 'Header: value' lines."""
    cells = self._current_row_cells
    self._current_row_cells = []
    # First column becomes the bold title, remaining columns are bulleted fields
    lines: list[str] = []
    if cells:
        title = cells[0]
        if title:
            # Avoid double-wrapping if cell already contains bold markup
            if title.startswith("*") and title.endswith("*") and len(title) > 1:
                lines.append(title)
            else:
                lines.append(f"*{title}*")
        for i, cell in enumerate(cells[1:], start=1):
            if i < len(self._table_headers):
                # Label the value with its column header
                lines.append(f"{self._table_headers[i]}: {cell}")
            else:
                # More cells than headers: emit the bare value
                lines.append(f"{cell}")
    return "\n".join(lines) + "\n\n"

def table_body(self, text: str) -> str:
    """Rows are already fully rendered; pass the concatenation through."""
    return text

def table(self, text: str) -> str:
    """Finish a table: clear the per-table buffers and add a trailing newline."""
    self._table_headers = []
    self._current_row_cells = []
    return text + "\n"
def paragraph(self, text: str) -> str:
    """Paragraphs are separated by a blank line in Slack mrkdwn."""
    return f"{text}\n\n"

View File

@@ -18,15 +18,18 @@ from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS
from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import Tag
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.users import get_user_by_email
from onyx.onyxbot.slack.blocks import build_slack_response_blocks
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.models import ThreadMessage
from onyx.onyxbot.slack.utils import get_channel_from_id
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import SlackRateLimiter
@@ -41,6 +44,51 @@ srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
def resolve_channel_references(
    message: str,
    client: WebClient,
    logger: OnyxLoggingAdapter,
) -> tuple[str, list[Tag]]:
    """Parse Slack channel references from a message, resolve IDs to names,
    replace the raw markup with readable #channel-name, and return channel tags
    for search filtering.

    Every occurrence of channel markup is rewritten — both the bare <#ID> and
    the named <#ID|name> forms — while at most one Tag is emitted per channel.
    (The previous dedupe-by-ID loop skipped later markup forms of the same
    channel entirely, leaving them raw in the message.) Markup whose channel
    name cannot be resolved is left untouched.
    """
    tags: list[Tag] = []
    tagged_channel_ids: set[str] = set()
    # Cache API lookups so each channel ID hits the Slack API at most once
    resolved_names: dict[str, str | None] = {}

    def _resolve_name(channel_id: str, markup_name: str) -> str | None:
        # Prefer the name embedded in the markup; fall back to the Slack API
        if markup_name:
            return markup_name
        if channel_id not in resolved_names:
            try:
                channel_info = get_channel_from_id(
                    client=client, channel_id=channel_id
                )
                resolved_names[channel_id] = channel_info.get("name") or None
            except Exception:
                logger.warning(f"Failed to resolve channel name for ID: {channel_id}")
                resolved_names[channel_id] = None
        return resolved_names[channel_id]

    def _replace_markup(match: "re.Match[str]") -> str:
        channel_id = match.group(1)
        channel_name = _resolve_name(channel_id, match.group(2) or "")
        if not channel_name:
            # Leave unresolvable markup as-is
            return match.group(0)
        if channel_id not in tagged_channel_ids:
            tagged_channel_ids.add(channel_id)
            tags.append(Tag(tag_key="Channel", tag_value=channel_name))
        return f"#{channel_name}"

    # Replace raw Slack markup with readable channel names in one pass
    message = SLACK_CHANNEL_REF_PATTERN.sub(_replace_markup, message)
    return message, tags
def rate_limits(
client: WebClient, channel: str, thread_ts: Optional[str]
) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
@@ -157,6 +205,20 @@ def handle_regular_answer(
user_message = messages[-1]
history_messages = messages[:-1]
# Resolve any <#CHANNEL_ID> references in the user message to readable
# channel names and extract channel tags for search filtering
resolved_message, channel_tags = resolve_channel_references(
message=user_message.message,
client=client,
logger=logger,
)
user_message = ThreadMessage(
message=resolved_message,
sender=user_message.sender,
role=user_message.role,
)
channel_name, _ = get_channel_name_from_id(
client=client,
channel_id=channel,
@@ -207,6 +269,7 @@ def handle_regular_answer(
source_type=None,
document_set=document_set_names,
time_cutoff=None,
tags=channel_tags if channel_tags else None,
)
new_message_request = SendMessageRequest(
@@ -231,6 +294,16 @@ def handle_regular_answer(
slack_context_str=slack_context_str,
)
# If a channel filter was applied but no results were found, override
# the LLM response to avoid hallucinated answers about unindexed channels
if channel_tags and not answer.citation_info and not answer.top_documents:
channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags)
answer.answer = (
f"No indexed data found for {channel_names}. "
"This channel may not be indexed, or there may be no messages "
"matching your query within it."
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
@@ -285,6 +358,7 @@ def handle_regular_answer(
only_respond_if_citations
and not answer.citation_info
and not message_info.bypass_filters
and not channel_tags
):
logger.error(
f"Unable to find citations to answer: '{answer.answer}' - not answering!"

View File

@@ -109,6 +109,7 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -92,6 +92,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pairs_for
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
@@ -556,6 +557,43 @@ def _normalize_file_names_for_backwards_compatibility(
return file_names + file_locations[len(file_names) :]
def _fetch_and_check_file_connector_cc_pair_permissions(
    connector_id: int,
    user: User,
    db_session: Session,
    require_editable: bool,
) -> ConnectorCredentialPair:
    """Fetch the cc-pair for a connector and verify the user may access it.

    Raises 404 when the connector has no cc-pair, 403 when the user lacks the
    requested (read or edit) access. Global curators get a carve-out for
    editing public file connectors they did not create.
    """
    cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
    if cc_pair is None:
        raise HTTPException(
            status_code=404,
            detail="No Connector-Credential Pair found for this connector",
        )

    if verify_user_has_access_to_cc_pair(
        cc_pair_id=cc_pair.id,
        db_session=db_session,
        user=user,
        get_editable=require_editable,
    ):
        return cc_pair

    # Special case: global curators should be able to manage files
    # for public file connectors even when they are not the creator.
    curator_public_edit_allowed = (
        require_editable
        and user.role == UserRole.GLOBAL_CURATOR
        and cc_pair.access_type == AccessType.PUBLIC
    )
    if curator_public_edit_allowed:
        return cc_pair

    raise HTTPException(
        status_code=403,
        detail="Access denied. User cannot manage files for this connector.",
    )
@router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS)
def upload_files_api(
files: list[UploadFile],
@@ -567,7 +605,7 @@ def upload_files_api(
@router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS)
def list_connector_files(
connector_id: int,
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> ConnectorFilesResponse:
"""List all files in a file connector."""
@@ -580,6 +618,13 @@ def list_connector_files(
status_code=400, detail="This endpoint only works with file connectors"
)
_ = _fetch_and_check_file_connector_cc_pair_permissions(
connector_id=connector_id,
user=user,
db_session=db_session,
require_editable=False,
)
file_locations = connector.connector_specific_config.get("file_locations", [])
file_names = connector.connector_specific_config.get("file_names", [])
@@ -629,7 +674,7 @@ def update_connector_files(
connector_id: int,
files: list[UploadFile] | None = File(None),
file_ids_to_remove: str = Form("[]"),
user: User = Depends(current_curator_or_admin_user), # noqa: ARG001
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
"""
@@ -647,12 +692,13 @@ def update_connector_files(
)
# Get the connector-credential pair for indexing/pruning triggers
cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id)
if cc_pair is None:
raise HTTPException(
status_code=404,
detail="No Connector-Credential Pair found for this connector",
)
# and validate user permissions for file management.
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions(
connector_id=connector_id,
user=user,
db_session=db_session,
require_editable=True,
)
# Parse file IDs to remove
try:

View File

@@ -243,6 +243,7 @@ class WebappInfo(BaseModel):
has_webapp: bool # Whether a webapp exists in outputs/web
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
status: str # Sandbox status (running, terminated, etc.)
ready: bool # Whether the NextJS dev server is actually responding
# ===== File Upload Models =====

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.models import BuildMessage
from onyx.db.models import User
@@ -32,6 +33,8 @@ from onyx.server.features.build.api.models import SuggestionBubble
from onyx.server.features.build.api.models import SuggestionTheme
from onyx.server.features.build.api.models import UploadResponse
from onyx.server.features.build.api.models import WebappInfo
from onyx.server.features.build.configs import SANDBOX_BACKEND
from onyx.server.features.build.configs import SandboxBackend
from onyx.server.features.build.db.build_session import allocate_nextjs_port
from onyx.server.features.build.db.build_session import get_build_session
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
@@ -362,14 +365,13 @@ def restore_session(
lock_key = f"sandbox_restore:{sandbox.id}"
lock = redis_client.lock(lock_key, timeout=RESTORE_LOCK_TIMEOUT_SECONDS)
# blocking=True means wait if another restore is in progress
acquired = lock.acquire(
blocking=True, blocking_timeout=RESTORE_LOCK_TIMEOUT_SECONDS
)
# Non-blocking: if another restore is already running, return 409 immediately
# instead of making the user wait. The frontend will retry.
acquired = lock.acquire(blocking=False)
if not acquired:
raise HTTPException(
status_code=503,
detail="Restore operation timed out waiting for lock",
status_code=409,
detail="Restore already in progress",
)
try:
@@ -379,15 +381,11 @@ def restore_session(
# Also re-check if session workspace exists (another request may have
# restored it while we were waiting)
if sandbox.status == SandboxStatus.RUNNING:
# Verify pod is healthy before proceeding
is_healthy = sandbox_manager.health_check(sandbox.id, timeout=10.0)
if is_healthy and sandbox_manager.session_workspace_exists(
sandbox.id, session_id
):
logger.info(
f"Session {session_id} workspace was restored by another request"
)
# Update heartbeat to mark sandbox as active
session.status = BuildSessionStatus.ACTIVE
update_sandbox_heartbeat(db_session, sandbox.id)
base_response = SessionResponse.from_model(session, sandbox)
return DetailedSessionResponse.from_session_response(
@@ -410,69 +408,82 @@ def restore_session(
# Fall through to TERMINATED handling below
session_manager = SessionManager(db_session)
llm_config = session_manager._get_llm_config(None, None)
if sandbox.status in (SandboxStatus.SLEEPING, SandboxStatus.TERMINATED):
# 1. Re-provision the pod
logger.info(f"Re-provisioning {sandbox.status.value} sandbox {sandbox.id}")
llm_config = session_manager._get_llm_config(None, None)
# Mark as PROVISIONING before the long-running provision() call
# so other requests know work is in progress
update_sandbox_status__no_commit(
db_session, sandbox.id, SandboxStatus.PROVISIONING
)
db_session.commit()
sandbox_manager.provision(
sandbox_id=sandbox.id,
user_id=user.id,
tenant_id=tenant_id,
llm_config=llm_config,
)
# Mark as RUNNING after successful provision
update_sandbox_status__no_commit(
db_session, sandbox.id, SandboxStatus.RUNNING
)
db_session.commit()
db_session.refresh(sandbox)
# 2. Check if session workspace needs to be loaded
if sandbox.status == SandboxStatus.RUNNING:
if not sandbox_manager.session_workspace_exists(sandbox.id, session_id):
# Get latest snapshot and restore it
snapshot = get_latest_snapshot_for_session(db_session, session_id)
if snapshot:
# Allocate a new port for the restored session
new_port = allocate_nextjs_port(db_session)
session.nextjs_port = new_port
workspace_exists = sandbox_manager.session_workspace_exists(
sandbox.id, session_id
)
if not workspace_exists:
# Allocate port if not already set (needed for both snapshot restore and fresh setup)
if not session.nextjs_port:
session.nextjs_port = allocate_nextjs_port(db_session)
# Commit port allocation before long-running operations
db_session.commit()
logger.info(
f"Restoring snapshot for session {session_id} "
f"from {snapshot.storage_path} with port {new_port}"
)
# Only Kubernetes backend supports snapshot restoration
snapshot = None
if SANDBOX_BACKEND == SandboxBackend.KUBERNETES:
snapshot = get_latest_snapshot_for_session(db_session, session_id)
if snapshot:
try:
sandbox_manager.restore_snapshot(
sandbox_id=sandbox.id,
session_id=session_id,
snapshot_storage_path=snapshot.storage_path,
tenant_id=tenant_id,
nextjs_port=new_port,
nextjs_port=session.nextjs_port,
llm_config=llm_config,
use_demo_data=session.demo_data_enabled,
)
session.status = BuildSessionStatus.ACTIVE
db_session.commit()
except Exception as e:
# Clear the port allocation on failure so it can be reused
logger.error(
f"Failed to restore session {session_id}, "
f"clearing port {new_port}: {e}"
f"Snapshot restore failed for session {session_id}: {e}"
)
session.nextjs_port = None
db_session.commit()
raise
else:
# No snapshot - set up fresh workspace
logger.info(
f"No snapshot found for session {session_id}, "
f"setting up fresh workspace"
)
llm_config = session_manager._get_llm_config(None, None)
sandbox_manager.setup_session_workspace(
sandbox_id=sandbox.id,
session_id=session_id,
llm_config=llm_config,
nextjs_port=session.nextjs_port or 3010,
nextjs_port=session.nextjs_port,
)
session.status = BuildSessionStatus.ACTIVE
db_session.commit()
else:
logger.warning(
f"Sandbox {sandbox.id} status is {sandbox.status} after "
f"re-provision, expected RUNNING"
)
except Exception as e:
logger.error(f"Failed to restore session {session_id}: {e}", exc_info=True)

View File

@@ -18,7 +18,6 @@ from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import Sandbox
from onyx.db.models import Snapshot
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_END
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_START
from onyx.server.manage.llm.models import LLMProviderView
@@ -269,27 +268,6 @@ def update_artifact(
logger.info(f"Updated artifact {artifact_id}")
# Snapshot operations
def create_snapshot(
session_id: UUID,
storage_path: str,
size_bytes: int,
db_session: Session,
) -> Snapshot:
"""Create a new snapshot record."""
snapshot = Snapshot(
session_id=session_id,
storage_path=storage_path,
size_bytes=size_bytes,
)
db_session.add(snapshot)
db_session.commit()
db_session.refresh(snapshot)
logger.info(f"Created snapshot {snapshot.id} for session {session_id}")
return snapshot
# Message operations
def create_message(
session_id: UUID,
@@ -501,6 +479,32 @@ def allocate_nextjs_port(db_session: Session) -> int:
)
def mark_user_sessions_idle__no_commit(db_session: Session, user_id: UUID) -> int:
"""Mark all ACTIVE sessions for a user as IDLE.
Called when a sandbox goes to sleep so the frontend knows these sessions
need restoration before they can be used again.
Args:
db_session: Database session
user_id: The user whose sessions should be marked idle
Returns:
Number of sessions updated
"""
result = (
db_session.query(BuildSession)
.filter(
BuildSession.user_id == user_id,
BuildSession.status == BuildSessionStatus.ACTIVE,
)
.update({BuildSession.status: BuildSessionStatus.IDLE})
)
db_session.flush()
logger.info(f"Marked {result} sessions as IDLE for user {user_id}")
return result
def clear_nextjs_ports_for_user(db_session: Session, user_id: UUID) -> int:
"""Clear nextjs_port for all sessions belonging to a user.

View File

@@ -126,7 +126,7 @@ def get_idle_sandboxes(
)
stmt = select(Sandbox).where(
Sandbox.status.in_([SandboxStatus.RUNNING, SandboxStatus.IDLE]),
Sandbox.status == SandboxStatus.RUNNING,
or_(
Sandbox.last_heartbeat < threshold_time,
and_(
@@ -147,27 +147,30 @@ def get_running_sandbox_count_by_tenant(
since Sandbox model no longer has tenant_id. This function returns
the count of all running sandboxes.
"""
stmt = select(func.count(Sandbox.id)).where(
Sandbox.status.in_([SandboxStatus.RUNNING, SandboxStatus.IDLE])
)
stmt = select(func.count(Sandbox.id)).where(Sandbox.status == SandboxStatus.RUNNING)
result = db_session.execute(stmt).scalar()
return result or 0
def create_snapshot(
def create_snapshot__no_commit(
db_session: Session,
session_id: UUID,
storage_path: str,
size_bytes: int,
) -> Snapshot:
"""Create a snapshot record for a session."""
"""Create a snapshot record for a session.
NOTE: Uses flush() instead of commit(). The caller (cleanup task) is
responsible for committing after all snapshots + status updates are done,
so the entire operation is atomic.
"""
snapshot = Snapshot(
session_id=session_id,
storage_path=storage_path,
size_bytes=size_bytes,
)
db_session.add(snapshot)
db_session.commit()
db_session.flush()
return snapshot

View File

@@ -183,13 +183,14 @@ class SandboxManager(ABC):
session_id: UUID,
tenant_id: str,
) -> SnapshotResult | None:
"""Create a snapshot of a session's outputs directory.
"""Create a snapshot of a session's outputs and attachments directories.
Captures only the session-specific outputs:
sessions/$session_id/outputs/
Captures session-specific user data:
- sessions/$session_id/outputs/ (generated artifacts, web apps)
- sessions/$session_id/attachments/ (user uploaded files)
Does NOT include: venv, skills, AGENTS.md, opencode.json, attachments
Does NOT include: shared files/ directory
Does NOT include: venv, skills, AGENTS.md, opencode.json, files symlink
(these are regenerated during restore)
Args:
sandbox_id: The sandbox ID
@@ -197,14 +198,45 @@ class SandboxManager(ABC):
tenant_id: Tenant identifier for storage path
Returns:
SnapshotResult with storage path and size, or None if
snapshots are disabled for this backend
SnapshotResult with storage path and size, or None if:
- Snapshots are disabled for this backend
- No outputs directory exists (nothing to snapshot)
Raises:
RuntimeError: If snapshot creation fails
"""
...
@abstractmethod
def restore_snapshot(
self,
sandbox_id: UUID,
session_id: UUID,
snapshot_storage_path: str,
tenant_id: str,
nextjs_port: int,
llm_config: LLMProviderConfig,
use_demo_data: bool = False,
) -> None:
"""Restore a session workspace from a snapshot.
For Kubernetes: Downloads and extracts the snapshot, regenerates config files.
For Local: No-op since workspaces persist on disk (no snapshots).
Args:
sandbox_id: The sandbox ID
session_id: The session ID to restore
snapshot_storage_path: Path to the snapshot in storage
tenant_id: Tenant identifier for storage access
nextjs_port: Port number for the NextJS dev server
llm_config: LLM provider configuration for opencode.json
use_demo_data: If True, symlink files/ to demo data
Raises:
RuntimeError: If snapshot restoration fails
"""
...
@abstractmethod
def session_workspace_exists(
self,
@@ -225,36 +257,6 @@ class SandboxManager(ABC):
"""
...
@abstractmethod
def restore_snapshot(
self,
sandbox_id: UUID,
session_id: UUID,
snapshot_storage_path: str,
tenant_id: str,
nextjs_port: int,
) -> None:
"""Restore a snapshot into a session's workspace directory.
Downloads the snapshot from storage, extracts it into
sessions/$session_id/outputs/, and starts the NextJS server.
For Kubernetes backend, this downloads from S3 and streams
into the pod via kubectl exec (since the pod has no S3 access).
Args:
sandbox_id: The sandbox ID
session_id: The session ID to restore
snapshot_storage_path: Path to the snapshot in storage
tenant_id: Tenant identifier for storage access
nextjs_port: Port number for the NextJS dev server
Raises:
RuntimeError: If snapshot restoration fails
FileNotFoundError: If snapshot does not exist
"""
...
@abstractmethod
def health_check(self, sandbox_id: UUID, timeout: float = 60.0) -> bool:
"""Check if the sandbox is healthy.

View File

@@ -1583,9 +1583,9 @@
}
},
"node_modules/@isaacs/brace-expansion": {
"version": "5.0.0",
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.0.tgz",
"integrity": "sha512-ZT55BDLV0yv0RBm2czMiZ+SqCGO7AvmOM3G/w2xhVPH+te0aKgFjmBvGlL1dH+ql2tgGO3MVrbb3jCKyvpgnxA==",
"version": "5.0.1",
"resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.1.tgz",
"integrity": "sha512-WMz71T1JS624nWj2n2fnYAuPovhv7EUhk69R6i9dsVyzxt5eM3bjwvgk9L+APE1TRscGysAVMANkB0jh0LQZrQ==",
"license": "MIT",
"dependencies": {
"@isaacs/balanced-match": "^4.0.1"
@@ -1640,9 +1640,9 @@
}
},
"node_modules/@modelcontextprotocol/sdk": {
"version": "1.25.3",
"resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.25.3.tgz",
"integrity": "sha512-vsAMBMERybvYgKbg/l4L1rhS7VXV1c0CtyJg72vwxONVX0l4ZfKVAnZEWTQixJGTzKnELjQ59e4NbdFDALRiAQ==",
"version": "1.26.0",
"resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.26.0.tgz",
"integrity": "sha512-Y5RmPncpiDtTXDbLKswIJzTqu2hyBKxTNsgKqKclDbhIgg1wgtf1fRuvxgTnRfcnxtvvgbIEcqUOzZrJ6iSReg==",
"license": "MIT",
"dependencies": {
"@hono/node-server": "^1.19.9",
@@ -1653,14 +1653,15 @@
"cross-spawn": "^7.0.5",
"eventsource": "^3.0.2",
"eventsource-parser": "^3.0.0",
"express": "^5.0.1",
"express-rate-limit": "^7.5.0",
"jose": "^6.1.1",
"express": "^5.2.1",
"express-rate-limit": "^8.2.1",
"hono": "^4.11.4",
"jose": "^6.1.3",
"json-schema-typed": "^8.0.2",
"pkce-challenge": "^5.0.0",
"raw-body": "^3.0.0",
"zod": "^3.25 || ^4.0",
"zod-to-json-schema": "^3.25.0"
"zod-to-json-schema": "^3.25.1"
},
"engines": {
"node": ">=18"
@@ -6757,10 +6758,13 @@
}
},
"node_modules/express-rate-limit": {
"version": "7.5.1",
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-7.5.1.tgz",
"integrity": "sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==",
"version": "8.2.1",
"resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz",
"integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==",
"license": "MIT",
"dependencies": {
"ip-address": "10.0.1"
},
"engines": {
"node": ">= 16"
},
@@ -7424,7 +7428,6 @@
"resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz",
"integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=16.9.0"
}
@@ -7552,6 +7555,15 @@
"node": ">=12"
}
},
"node_modules/ip-address": {
"version": "10.0.1",
"resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz",
"integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==",
"license": "MIT",
"engines": {
"node": ">= 12"
}
},
"node_modules/ipaddr.js": {
"version": "1.9.1",
"resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",

View File

@@ -65,7 +65,6 @@ from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_END
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_START
from onyx.server.features.build.configs import SANDBOX_S3_BUCKET
from onyx.server.features.build.configs import SANDBOX_SERVICE_ACCOUNT_NAME
from onyx.server.features.build.s3.s3_client import build_s3_client
from onyx.server.features.build.sandbox.base import SandboxManager
from onyx.server.features.build.sandbox.kubernetes.internal.acp_exec_client import (
ACPEvent,
@@ -409,6 +408,10 @@ done
],
volume_mounts=[
client.V1VolumeMount(name="files", mount_path="/workspace/files"),
# Mount sessions directory so file-sync can create snapshots
client.V1VolumeMount(
name="workspace", mount_path="/workspace/sessions"
),
],
resources=client.V1ResourceRequirements(
# Reduced resources since sidecar is mostly idle (sleeping)
@@ -442,6 +445,10 @@ done
client.V1VolumeMount(
name="files", mount_path="/workspace/files", read_only=True
),
# Mount sessions directory (shared with file-sync for snapshots)
client.V1VolumeMount(
name="workspace", mount_path="/workspace/sessions"
),
],
resources=client.V1ResourceRequirements(
requests={"cpu": "500m", "memory": "1Gi"},
@@ -583,6 +590,60 @@ done
),
)
def _ensure_service_exists(
self,
sandbox_id: UUID,
tenant_id: str,
) -> None:
"""Ensure a ClusterIP service exists for the sandbox pod.
Handles the case where a service is in Terminating state (has a
deletion_timestamp) by waiting for deletion and recreating it.
This prevents a race condition where provision reuses an existing pod
but the old service is still being deleted.
"""
service_name = self._get_service_name(str(sandbox_id))
try:
svc = self._core_api.read_namespaced_service(
name=service_name,
namespace=self._namespace,
)
# Service exists - check if it's being deleted
if svc.metadata.deletion_timestamp:
logger.info(
f"Service {service_name} is terminating, waiting for deletion"
)
self._wait_for_resource_deletion("service", service_name)
# Now create a fresh service
service = self._create_sandbox_service(sandbox_id, tenant_id)
self._core_api.create_namespaced_service(
namespace=self._namespace,
body=service,
)
logger.info(f"Recreated Service {service_name} after termination")
else:
logger.debug(f"Service {service_name} already exists and is active")
except ApiException as e:
if e.status == 404:
# Service doesn't exist, create it
logger.info(f"Creating missing Service {service_name}")
service = self._create_sandbox_service(sandbox_id, tenant_id)
try:
self._core_api.create_namespaced_service(
namespace=self._namespace,
body=service,
)
except ApiException as svc_e:
if svc_e.status != 409: # Ignore AlreadyExists
raise
logger.debug(
f"Service {service_name} was created by another request"
)
else:
raise
def _get_init_container_logs(self, pod_name: str, container_name: str) -> str:
"""Get logs from an init container.
@@ -798,34 +859,14 @@ done
)
pod_name = self._get_pod_name(str(sandbox_id))
service_name = self._get_service_name(str(sandbox_id))
# Check if pod already exists and is healthy (idempotency check)
if self._pod_exists_and_healthy(pod_name):
logger.info(
f"Pod {pod_name} already exists and is healthy, reusing existing pod"
)
# Ensure service exists too
try:
self._core_api.read_namespaced_service(
name=service_name,
namespace=self._namespace,
)
except ApiException as e:
if e.status == 404:
# Service doesn't exist, create it
logger.debug(f"Creating missing Service {service_name}")
service = self._create_sandbox_service(sandbox_id, tenant_id)
try:
self._core_api.create_namespaced_service(
namespace=self._namespace,
body=service,
)
except ApiException as svc_e:
if svc_e.status != 409: # Ignore AlreadyExists
raise
else:
raise
# Ensure service exists and is not terminating
self._ensure_service_exists(sandbox_id, tenant_id)
# Wait for pod to be ready if it's still pending
logger.info(f"Waiting for existing pod {pod_name} to become ready...")
@@ -880,20 +921,8 @@ done
else:
raise
# 2. Create Service (idempotent - ignore 409)
logger.debug(f"Creating Service {service_name}")
service = self._create_sandbox_service(sandbox_id, tenant_id)
try:
self._core_api.create_namespaced_service(
namespace=self._namespace,
body=service,
)
except ApiException as e:
if e.status != 409: # Ignore AlreadyExists
raise
logger.warning(
f"During provisioning, discovered that service {service_name} already exists. Reusing"
)
# 2. Create Service (handles terminating services)
self._ensure_service_exists(sandbox_id, tenant_id)
# 3. Wait for pod to be ready
logger.info(f"Waiting for pod {pod_name} to become ready...")
@@ -1335,10 +1364,12 @@ echo "Session cleanup complete"
session_id: UUID,
tenant_id: str,
) -> SnapshotResult | None:
"""Create a snapshot of a session's outputs directory.
"""Create a snapshot of a session's outputs and attachments directories.
For Kubernetes backend, we exec into the pod to create the snapshot.
Only captures sessions/$session_id/outputs/
For Kubernetes backend, we exec into the file-sync container to create
the snapshot and upload to S3. Captures:
- sessions/$session_id/outputs/ (generated artifacts, web apps)
- sessions/$session_id/attachments/ (user uploaded files)
Args:
sandbox_id: The sandbox ID
@@ -1346,7 +1377,7 @@ echo "Session cleanup complete"
tenant_id: Tenant identifier for storage path
Returns:
SnapshotResult with storage path and size
SnapshotResult with storage path and size, or None if nothing to snapshot
Raises:
RuntimeError: If snapshot creation fails
@@ -1356,26 +1387,40 @@ echo "Session cleanup complete"
pod_name = self._get_pod_name(sandbox_id_str)
snapshot_id = str(uuid4())
session_path = f"/workspace/sessions/{session_id_str}"
# Use shlex.quote for safety (UUIDs are safe but good practice)
safe_session_path = shlex.quote(f"/workspace/sessions/{session_id_str}")
s3_path = (
f"s3://{self._s3_bucket}/{tenant_id}/snapshots/"
f"{session_id_str}/{snapshot_id}.tar.gz"
)
# Exec into pod to create and upload snapshot (session outputs only)
# Exec into pod to create and upload snapshot (outputs + attachments)
# Uses s5cmd pipe to stream tar.gz directly to S3
# Only snapshot if outputs/ exists. Include attachments/ only if non-empty.
exec_command = [
"/bin/sh",
"-c",
f'tar -czf - -C {session_path} outputs | aws s3 cp - {s3_path} --tagging "Type=snapshot"',
f"""
set -eo pipefail
cd {safe_session_path}
if [ ! -d outputs ]; then
echo "EMPTY_SNAPSHOT"
exit 0
fi
dirs="outputs"
[ -d attachments ] && [ "$(ls -A attachments 2>/dev/null)" ] && dirs="$dirs attachments"
tar -czf - $dirs | /s5cmd pipe {s3_path}
echo "SNAPSHOT_CREATED"
""",
]
try:
# Use exec to run snapshot command in sandbox container
# Use exec to run snapshot command in file-sync container (has s5cmd)
resp = k8s_stream(
self._stream_core_api.connect_get_namespaced_pod_exec,
name=pod_name,
namespace=self._namespace,
container="sandbox",
container="file-sync",
command=exec_command,
stderr=True,
stdin=False,
@@ -1385,6 +1430,17 @@ echo "Session cleanup complete"
logger.debug(f"Snapshot exec output: {resp}")
# Check if nothing was snapshotted
if "EMPTY_SNAPSHOT" in resp:
logger.info(
f"No outputs or attachments to snapshot for session {session_id}"
)
return None
# Verify upload succeeded
if "SNAPSHOT_CREATED" not in resp:
raise RuntimeError(f"Snapshot upload may have failed. Output: {resp}")
except ApiException as e:
raise RuntimeError(f"Failed to create snapshot: {e}") from e
@@ -1392,9 +1448,8 @@ echo "Session cleanup complete"
# In production, you might want to query S3 for the actual size
size_bytes = 0
storage_path = (
f"sandbox-snapshots/{tenant_id}/{session_id_str}/{snapshot_id}.tar.gz"
)
# Storage path must match the S3 upload path (without s3://bucket/ prefix)
storage_path = f"{tenant_id}/snapshots/{session_id_str}/{snapshot_id}.tar.gz"
logger.info(f"Created snapshot for session {session_id}")
@@ -1426,7 +1481,7 @@ echo "Session cleanup complete"
exec_command = [
"/bin/sh",
"-c",
f'[ -d "{session_path}" ] && echo "EXISTS" || echo "NOT_EXISTS"',
f'[ -d "{session_path}" ] && echo "WORKSPACE_FOUND" || echo "WORKSPACE_MISSING"',
]
try:
@@ -1442,7 +1497,12 @@ echo "Session cleanup complete"
tty=False,
)
return "EXISTS" in resp
result = "WORKSPACE_FOUND" in resp
logger.info(
f"[WORKSPACE_CHECK] session={session_id}, "
f"path={session_path}, raw_resp={resp!r}, result={result}"
)
return result
except ApiException as e:
logger.warning(
@@ -1457,14 +1517,21 @@ echo "Session cleanup complete"
snapshot_storage_path: str,
tenant_id: str, # noqa: ARG002
nextjs_port: int,
llm_config: LLMProviderConfig,
use_demo_data: bool = False,
) -> None:
"""Download snapshot from S3, extract into session workspace, and start NextJS.
"""Download snapshot from S3 via s5cmd, extract, regenerate config, and start NextJS.
Since the sandbox pod doesn't have S3 access, this method:
1. Downloads snapshot from S3 (using boto3 directly)
2. Creates the session directory structure in pod
3. Streams the tar.gz into the pod via kubectl exec
4. Starts the NextJS dev server
Uses the file-sync sidecar container (which has s5cmd + S3 credentials
via IRSA) to stream the snapshot directly from S3 into the session
directory. This avoids downloading to the backend server and the
base64 encoding overhead of piping through kubectl exec.
Steps:
1. Exec s5cmd cat in file-sync container to stream snapshot from S3
2. Pipe directly to tar for extraction in the shared workspace volume
3. Regenerate configuration files (AGENTS.md, opencode.json, files symlink)
4. Start the NextJS dev server
Args:
sandbox_id: The sandbox ID
@@ -1472,87 +1539,56 @@ echo "Session cleanup complete"
snapshot_storage_path: Path to the snapshot in S3 (relative path)
tenant_id: Tenant identifier for storage access
nextjs_port: Port number for the NextJS dev server
llm_config: LLM provider configuration for opencode.json
use_demo_data: If True, symlink files/ to demo data; else to user files
Raises:
RuntimeError: If snapshot restoration fails
FileNotFoundError: If snapshot does not exist
"""
import tempfile
pod_name = self._get_pod_name(str(sandbox_id))
session_path = f"/workspace/sessions/{session_id}"
safe_session_path = shlex.quote(session_path)
# Build full S3 path
s3_key = snapshot_storage_path
s3_path = f"s3://{self._s3_bucket}/{snapshot_storage_path}"
logger.info(f"Restoring snapshot for session {session_id} from {s3_key}")
# Download snapshot from S3 - uses IAM roles (IRSA)
s3_client = build_s3_client()
tmp_path: str | None = None
# Stream snapshot directly from S3 via s5cmd in file-sync container.
# Mirrors the upload pattern: upload uses `tar | s5cmd pipe`,
# restore uses `s5cmd cat | tar`. Both run in file-sync container
# which has s5cmd and S3 credentials (IRSA). The shared workspace
# volume makes extracted files immediately visible to the sandbox
# container.
restore_script = f"""
set -eo pipefail
mkdir -p {safe_session_path}
/s5cmd cat {s3_path} | tar -xzf - -C {safe_session_path}
echo "SNAPSHOT_RESTORED"
"""
try:
with tempfile.NamedTemporaryFile(
suffix=".tar.gz", delete=False
) as tmp_file:
tmp_path = tmp_file.name
try:
s3_client.download_file(self._s3_bucket, s3_key, tmp_path)
except s3_client.exceptions.NoSuchKey:
raise FileNotFoundError(
f"Snapshot not found: s3://{self._s3_bucket}/{s3_key}"
)
# Create session directory structure in pod
# Use shlex.quote to prevent shell injection
safe_session_path = shlex.quote(session_path)
setup_script = f"""
set -e
mkdir -p {safe_session_path}/outputs
"""
k8s_stream(
self._stream_core_api.connect_get_namespaced_pod_exec,
name=pod_name,
namespace=self._namespace,
container="sandbox",
command=["/bin/sh", "-c", setup_script],
stderr=True,
stdin=False,
stdout=True,
tty=False,
)
# Stream tar.gz into pod and extract
# We use kubectl exec with stdin to pipe the tar file
with open(tmp_path, "rb") as tar_file:
tar_data = tar_file.read()
# Use base64 encoding to safely transfer binary data
import base64
tar_b64 = base64.b64encode(tar_data).decode("ascii")
# Extract in the session directory (tar was created with outputs/ as root)
extract_script = f"""
set -e
cd {safe_session_path}
echo '{tar_b64}' | base64 -d | tar -xzf -
"""
resp = k8s_stream(
self._stream_core_api.connect_get_namespaced_pod_exec,
name=pod_name,
namespace=self._namespace,
container="sandbox",
command=["/bin/sh", "-c", extract_script],
container="file-sync",
command=["/bin/sh", "-c", restore_script],
stderr=True,
stdin=False,
stdout=True,
tty=False,
)
logger.debug(f"Snapshot restore output: {resp}")
logger.info(f"Restored snapshot for session {session_id}")
if "SNAPSHOT_RESTORED" not in resp:
raise RuntimeError(f"Snapshot restore may have failed. Output: {resp}")
# Regenerate configuration files that aren't in the snapshot
# These are regenerated to ensure they match the current system state
self._regenerate_session_config(
pod_name=pod_name,
session_path=safe_session_path,
llm_config=llm_config,
nextjs_port=nextjs_port,
use_demo_data=use_demo_data,
)
# Start NextJS dev server (check node_modules since restoring from snapshot)
start_script = _build_nextjs_start_script(
@@ -1569,23 +1605,95 @@ echo '{tar_b64}' | base64 -d | tar -xzf -
stdout=True,
tty=False,
)
logger.info(
f"Started NextJS server for session {session_id} on port {nextjs_port}"
)
except ApiException as e:
raise RuntimeError(f"Failed to restore snapshot: {e}") from e
finally:
# Cleanup temp file
if tmp_path:
try:
import os
os.unlink(tmp_path)
except Exception as cleanup_error:
logger.warning(
f"Failed to cleanup temp file {tmp_path}: {cleanup_error}"
)
def _regenerate_session_config(
self,
pod_name: str,
session_path: str,
llm_config: LLMProviderConfig,
nextjs_port: int,
use_demo_data: bool,
) -> None:
"""Regenerate session configuration files after snapshot restore.
Creates:
- AGENTS.md (agent instructions)
- opencode.json (LLM configuration)
- files symlink (to demo data or user files)
Args:
pod_name: The pod name to exec into
session_path: Path to the session directory (already shlex.quoted)
llm_config: LLM provider configuration
nextjs_port: Port for NextJS (used in AGENTS.md)
use_demo_data: Whether to use demo data or user files
"""
# Generate AGENTS.md content
agent_instructions = self._load_agent_instructions(
files_path=None, # Container script handles this at runtime
provider=llm_config.provider,
model_name=llm_config.model_name,
nextjs_port=nextjs_port,
disabled_tools=OPENCODE_DISABLED_TOOLS,
user_name=None, # Not stored, regenerate without personalization
user_role=None,
use_demo_data=use_demo_data,
include_org_info=False, # Don't include org_info for restored sessions
)
# Generate opencode.json
opencode_config = build_opencode_config(
provider=llm_config.provider,
model_name=llm_config.model_name,
api_key=llm_config.api_key if llm_config.api_key else None,
api_base=llm_config.api_base,
disabled_tools=OPENCODE_DISABLED_TOOLS,
)
opencode_json = json.dumps(opencode_config)
# Escape for shell (single quotes)
opencode_json_escaped = opencode_json.replace("'", "'\\''")
agent_instructions_escaped = agent_instructions.replace("'", "'\\''")
# Build files symlink setup
if use_demo_data:
symlink_target = "/workspace/demo_data"
else:
symlink_target = "/workspace/files"
config_script = f"""
set -e
# Create files symlink
echo "Creating files symlink to {symlink_target}"
ln -sf {symlink_target} {session_path}/files
# Write agent instructions
echo "Writing AGENTS.md"
printf '%s' '{agent_instructions_escaped}' > {session_path}/AGENTS.md
# Write opencode config
echo "Writing opencode.json"
printf '%s' '{opencode_json_escaped}' > {session_path}/opencode.json
echo "Session config regeneration complete"
"""
logger.info("Regenerating session configuration files")
k8s_stream(
self._stream_core_api.connect_get_namespaced_pod_exec,
name=pod_name,
namespace=self._namespace,
container="sandbox",
command=["/bin/sh", "-c", config_script],
stderr=True,
stdin=False,
stdout=True,
tty=False,
)
logger.info("Session configuration files regenerated")
def health_check(self, sandbox_id: UUID, timeout: float = 60.0) -> bool:
"""Check if the sandbox pod is healthy (can exec into it).

View File

@@ -608,34 +608,14 @@ class LocalSandboxManager(SandboxManager):
session_id: UUID,
tenant_id: str,
) -> SnapshotResult | None:
"""Create a snapshot of a session's outputs directory.
"""Not implemented for local backend - workspaces persist on disk.
Returns None if snapshots are disabled (local backend).
Args:
sandbox_id: The sandbox ID
session_id: The session ID to snapshot
tenant_id: Tenant identifier for storage path
Returns:
SnapshotResult with storage path and size, or None if
snapshots are disabled for this backend
Local sandboxes don't use snapshots since the filesystem persists.
This should never be called for local backend.
"""
session_path = self._get_session_path(sandbox_id, session_id)
# SnapshotManager expects string session_id for storage path
_, storage_path, size_bytes = self._snapshot_manager.create_snapshot(
session_path,
str(session_id),
tenant_id,
)
logger.info(
f"Created snapshot for session {session_id}, size: {size_bytes} bytes"
)
return SnapshotResult(
storage_path=storage_path,
size_bytes=size_bytes,
raise NotImplementedError(
"create_snapshot is not supported for local backend. "
"Local sandboxes persist on disk and don't use snapshots."
)
def session_workspace_exists(
@@ -663,52 +643,23 @@ class LocalSandboxManager(SandboxManager):
snapshot_storage_path: str,
tenant_id: str, # noqa: ARG002
nextjs_port: int,
llm_config: LLMProviderConfig,
use_demo_data: bool = False,
) -> None:
"""Restore a snapshot into a session's workspace directory and start NextJS.
"""Not implemented for local backend - workspaces persist on disk.
Args:
sandbox_id: The sandbox ID
session_id: The session ID to restore
snapshot_storage_path: Path to the snapshot in storage
tenant_id: Tenant identifier for storage access
nextjs_port: Port number for the NextJS dev server
Raises:
RuntimeError: If snapshot restoration fails
FileNotFoundError: If snapshot does not exist
Local sandboxes don't use snapshots since the filesystem persists.
This should never be called for local backend.
"""
session_path = self._get_session_path(sandbox_id, session_id)
# Ensure session directory exists
session_path.mkdir(parents=True, exist_ok=True)
# Use SnapshotManager to restore
self._snapshot_manager.restore_snapshot(
storage_path=snapshot_storage_path,
target_path=session_path,
raise NotImplementedError(
"restore_snapshot is not supported for local backend. "
"Local sandboxes persist on disk and don't use snapshots."
)
logger.info(f"Restored snapshot for session {session_id}")
# Start NextJS dev server
web_dir = session_path / "outputs" / "web"
if web_dir.exists():
logger.info(f"Starting Next.js server at {web_dir} on port {nextjs_port}")
nextjs_process = self._process_manager.start_nextjs_server(
web_dir, nextjs_port
)
# Store process for clean shutdown on session delete
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
logger.info(
f"Started NextJS server for session {session_id} on port {nextjs_port}"
)
else:
logger.warning(
f"Web directory not found at {web_dir}, skipping NextJS startup"
)
def health_check(
self, sandbox_id: UUID, timeout: float = 60.0 # noqa: ARG002
self,
sandbox_id: UUID,
timeout: float = 60.0, # noqa: ARG002
) -> bool:
"""Check if the sandbox is healthy (folder exists).

View File

@@ -16,6 +16,9 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
from onyx.server.features.build.configs import SANDBOX_IDLE_TIMEOUT_SECONDS
from onyx.server.features.build.configs import SandboxBackend
from onyx.server.features.build.db.build_session import clear_nextjs_ports_for_user
from onyx.server.features.build.db.build_session import (
mark_user_sessions_idle__no_commit,
)
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.sandbox.base import get_sandbox_manager
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
@@ -75,12 +78,11 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
try:
# Import here to avoid circular imports
from onyx.db.enums import SandboxStatus
from onyx.server.features.build.db.sandbox import create_snapshot
from onyx.server.features.build.db.sandbox import create_snapshot__no_commit
from onyx.server.features.build.db.sandbox import get_idle_sandboxes
from onyx.server.features.build.db.sandbox import (
update_sandbox_status__no_commit,
)
from onyx.server.features.build.sandbox import get_sandbox_manager
sandbox_manager = get_sandbox_manager()
@@ -128,7 +130,7 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
)
if snapshot_result:
# Create DB record for the snapshot
create_snapshot(
create_snapshot__no_commit(
db_session,
session_id,
snapshot_result.storage_path,
@@ -154,7 +156,15 @@ def cleanup_idle_sandboxes_task(self: Task, *, tenant_id: str) -> None: # noqa:
f"{sandbox.user_id}"
)
# Mark sandbox as SLEEPING (not TERMINATED)
# Mark all active sessions as IDLE
idled = mark_user_sessions_idle__no_commit(
db_session, sandbox.user_id
)
task_logger.debug(
f"Marked {idled} sessions as IDLE for user "
f"{sandbox.user_id}"
)
update_sandbox_status__no_commit(
db_session, sandbox_id, SandboxStatus.SLEEPING
)
@@ -272,7 +282,7 @@ def sync_sandbox_files(
task_logger.debug(f"No sandbox found for user {user_id}, skipping sync")
return False
if sandbox.status not in [SandboxStatus.RUNNING, SandboxStatus.IDLE]:
if sandbox.status != SandboxStatus.RUNNING:
task_logger.debug(
f"Sandbox {sandbox.id} not running (status={sandbox.status}), "
f"skipping sync"

View File

@@ -1675,7 +1675,8 @@ class SessionManager:
user_id: The user ID to verify ownership
Returns:
Dict with has_webapp, webapp_url, and status, or None if session not found
Dict with has_webapp, webapp_url, status, and ready,
or None if session not found
"""
# Verify session ownership
session = get_build_session(session_id, user_id, self._db_session)
@@ -1684,20 +1685,51 @@ class SessionManager:
sandbox = get_sandbox_by_user_id(self._db_session, user_id)
if sandbox is None:
return {"has_webapp": False, "webapp_url": None, "status": "no_sandbox"}
return {
"has_webapp": False,
"webapp_url": None,
"status": "no_sandbox",
"ready": False,
}
# Return the proxy URL - the proxy handles routing to the correct sandbox
# for both local and Kubernetes environments
webapp_url = None
ready = False
if session.nextjs_port:
webapp_url = f"{WEB_DOMAIN}/api/build/sessions/{session_id}/webapp"
# Quick health check: can the API server reach the NextJS dev server?
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
return {
"has_webapp": session.nextjs_port is not None,
"webapp_url": webapp_url,
"status": sandbox.status.value,
"ready": ready,
}
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:
    """Check if the NextJS dev server is responding.

    Does a quick HTTP GET to the sandbox's internal URL with a short timeout.

    Args:
        sandbox_id: Sandbox whose internal webapp URL should be probed.
        port: Port the NextJS dev server is expected to listen on.

    Returns:
        True if the server responds with a non-5xx status code (a 4xx still
        proves the dev server is up); False on timeout, connection error, or
        any other failure.
    """
    # Local imports kept as in the original — presumably to avoid import
    # cycles / heavy imports at module load; confirm before hoisting.
    import httpx
    from onyx.server.features.build.sandbox.base import get_sandbox_manager

    try:
        sandbox_manager = get_sandbox_manager()
        internal_url = sandbox_manager.get_webapp_url(sandbox_id, port)
        with httpx.Client(timeout=2.0) as client:
            resp = client.get(internal_url)
        # 4xx responses still mean the dev server answered; only 5xx (or no
        # response at all) is treated as "not ready".
        return resp.status_code < 500
    except Exception:
        # The original caught (TimeoutException, ConnectError, Exception);
        # the trailing Exception made the specific entries dead code, so a
        # plain Exception catch is behaviorally identical. A health probe
        # must never propagate an error.
        return False
def download_webapp_zip(
self,
session_id: UUID,

View File

@@ -1,9 +1,12 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.access.hierarchy_access import get_user_external_group_ids
from onyx.auth.users import current_user
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
from onyx.configs.constants import DocumentSource
from onyx.db.document import get_accessible_documents_for_hierarchy_node_paginated
from onyx.db.engine.sql_engine import get_session
@@ -22,10 +25,25 @@ from onyx.server.features.hierarchy.models import HierarchyNodeDocumentsResponse
from onyx.server.features.hierarchy.models import HierarchyNodesResponse
from onyx.server.features.hierarchy.models import HierarchyNodeSummary
OPENSEARCH_NOT_ENABLED_MESSAGE = (
"Per-source knowledge selection is coming soon in v3.0! "
"OpenSearch indexing must be enabled to use this feature."
)
router = APIRouter(prefix=HIERARCHY_NODES_PREFIX)
def _require_opensearch() -> None:
    """Raise a 403 unless both OpenSearch indexing and retrieval are enabled."""
    opensearch_fully_enabled = (
        ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX
    )
    if opensearch_fully_enabled:
        return
    raise HTTPException(
        status_code=403,
        detail=OPENSEARCH_NOT_ENABLED_MESSAGE,
    )
def _get_user_access_info(
user: User | None, db_session: Session
) -> tuple[str | None, list[str]]:
@@ -40,6 +58,7 @@ def list_accessible_hierarchy_nodes(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodesResponse:
_require_opensearch()
user_email, external_group_ids = _get_user_access_info(user, db_session)
nodes = get_accessible_hierarchy_nodes_for_source(
db_session=db_session,
@@ -66,6 +85,7 @@ def list_accessible_hierarchy_node_documents(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> HierarchyNodeDocumentsResponse:
_require_opensearch()
user_email, external_group_ids = _get_user_access_info(user, db_session)
cursor = documents_request.cursor
sort_field = documents_request.sort_field

View File

@@ -255,7 +255,7 @@ def list_llm_providers(
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(
db_session=db_session,
flow_types=[LLMModelFlowType.CHAT, LLMModelFlowType.VISION],
flow_type_filter=[],
exclude_image_generation_providers=not include_image_gen,
):
from_model_start = datetime.now(timezone.utc)
@@ -503,9 +503,7 @@ def list_llm_provider_basics(
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch user-accessible LLM providers")
all_providers = fetch_existing_llm_providers(
db_session, [LLMModelFlowType.CHAT, LLMModelFlowType.VISION]
)
all_providers = fetch_existing_llm_providers(db_session, [])
user_group_ids = fetch_user_group_ids(db_session, user)
is_admin = user.role == UserRole.ADMIN
@@ -514,9 +512,9 @@ def list_llm_provider_basics(
for provider in all_providers:
# Use centralized access control logic with persona=None since we're
# listing providers without a specific persona context. This correctly:
# - Includes all public providers
# - Includes public providers WITHOUT persona restrictions
# - Includes providers user can access via group membership
# - Excludes persona-only restricted providers (requires specific persona)
# - Excludes providers with persona restrictions (requires specific persona)
# - Excludes non-public providers with no restrictions (admin-only)
if can_user_access_llm_provider(
provider, user_group_ids, persona=None, is_admin=is_admin
@@ -541,7 +539,7 @@ def get_valid_model_names_for_persona(
Returns a list of model names (e.g., ["gpt-4o", "claude-3-5-sonnet"]) that are
available to the user when using this persona, respecting all RBAC restrictions.
Public providers are always included.
Public providers are included unless they have persona restrictions that exclude this persona.
"""
persona = fetch_persona_with_groups(db_session, persona_id)
if not persona:
@@ -555,7 +553,7 @@ def get_valid_model_names_for_persona(
valid_models = []
for llm_provider_model in all_providers:
# Public providers always included, restricted checked via RBAC
# Check access with persona context — respects all RBAC restrictions
if can_user_access_llm_provider(
llm_provider_model, user_group_ids, persona, is_admin=is_admin
):
@@ -576,7 +574,7 @@ def list_llm_providers_for_persona(
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
- All public providers (is_public=True) - ALWAYS included
- Public providers (respecting persona restrictions if set)
- Restricted providers user can access via group/persona restrictions
This endpoint is used for background fetching of restricted providers
@@ -605,7 +603,7 @@ def list_llm_providers_for_persona(
llm_provider_list: list[LLMProviderDescriptor] = []
for llm_provider_model in all_providers:
# Use simplified access check - public providers always included
# Check access with persona context — respects persona restrictions
if can_user_access_llm_provider(
llm_provider_model, user_group_ids, persona, is_admin=is_admin
):

View File

@@ -30,12 +30,14 @@ from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import enforce_seat_limit
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import NUM_FREE_TRIAL_USER_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import USER_AUTH_SECRET
@@ -90,6 +92,7 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.server.usage_limits import is_tenant_on_trial_fn
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -391,14 +394,20 @@ def bulk_invite_users(
if e not in existing_users and e not in already_invited
]
# Limit bulk invites for trial tenants to prevent email spam
# Only count new invites, not re-invites of existing users
if MULTI_TENANT and is_tenant_on_trial_fn(tenant_id):
current_invited = len(already_invited)
if current_invited + len(emails_needing_seats) > NUM_FREE_TRIAL_USER_INVITES:
raise HTTPException(
status_code=403,
detail="You have hit your invite limit. "
"Please upgrade for unlimited invites.",
)
# Check seat availability for new users
# Only for self-hosted (non-multi-tenant) deployments
if not MULTI_TENANT and emails_needing_seats:
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=len(emails_needing_seats))
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
if emails_needing_seats:
enforce_seat_limit(db_session, seats_needed=len(emails_needing_seats))
if MULTI_TENANT:
try:
@@ -414,10 +423,10 @@ def bulk_invite_users(
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations if enabled
# send out email invitations only to new users (not already invited or existing)
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
for email in emails_needing_seats:
send_user_email_invite(email, current_user, AUTH_TYPE)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
@@ -564,12 +573,7 @@ def activate_user_api(
# Check seat availability before activating
# Only for self-hosted (non-multi-tenant) deployments
if not MULTI_TENANT:
result = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)(db_session, seats_needed=1)
if result is not None and not result.available:
raise HTTPException(status_code=402, detail=result.error_message)
enforce_seat_limit(db_session)
activate_user(user_to_activate, db_session)
@@ -593,11 +597,17 @@ def get_valid_domains(
@router.get("/users", tags=PUBLIC_API_TAGS)
def list_all_users_basic_info(
include_api_keys: bool = False,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserSnapshot]:
users = get_all_users(db_session)
return [MinimalUserSnapshot(id=user.id, email=user.email) for user in users]
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.role != UserRole.SLACK_USER
and (include_api_keys or not is_api_key_email_address(user.email))
]
@router.get("/get-user-role", tags=PUBLIC_API_TAGS)

View File

@@ -87,6 +87,8 @@ class SendMessageRequest(BaseModel):
message: str
llm_override: LLMOverride | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
allowed_tool_ids: list[int] | None = None
forced_tool_id: int | None = None
@@ -191,6 +193,8 @@ class CreateChatMessageRequest(ChunkContext):
# allows the caller to override the Persona / Prompt
# these do not persist in the chat thread details
llm_override: LLMOverride | None = None
# Test-only override for deterministic LiteLLM mock responses.
mock_llm_response: str | None = None
prompt_override: PromptOverride | None = None
# Allows the caller to override the temperature for the chat session

View File

@@ -1,5 +1,6 @@
from enum import Enum
from typing import Annotated
from typing import Any
from typing import Literal
from typing import Union
@@ -37,6 +38,7 @@ class StreamingType(Enum):
REASONING_DELTA = "reasoning_delta"
REASONING_DONE = "reasoning_done"
CITATION_INFO = "citation_info"
TOOL_CALL_DEBUG = "tool_call_debug"
DEEP_RESEARCH_PLAN_START = "deep_research_plan_start"
DEEP_RESEARCH_PLAN_DELTA = "deep_research_plan_delta"
@@ -127,6 +129,14 @@ class CitationInfo(BaseObj):
document_id: str
class ToolCallDebug(BaseObj):
    """Streaming packet that surfaces a tool invocation for debugging."""

    # Discriminator matching StreamingType.TOOL_CALL_DEBUG ("tool_call_debug").
    type: Literal["tool_call_debug"] = StreamingType.TOOL_CALL_DEBUG.value
    # Identifier for this tool call — presumably the LLM-assigned call id;
    # confirm against the producer.
    tool_call_id: str
    # Name of the tool being invoked.
    tool_name: str
    # Raw keyword arguments passed to the tool.
    tool_args: dict[str, Any]
################################################
# Tool Packets
################################################
@@ -318,6 +328,7 @@ PacketObj = Union[
ReasoningDone,
# Citation Packets
CitationInfo,
ToolCallDebug,
# Deep Research Packets
DeepResearchPlanStart,
DeepResearchPlanDelta,

View File

@@ -57,9 +57,11 @@ class Settings(BaseModel):
anonymous_user_enabled: bool | None = None
deep_research_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
# When LICENSE_ENFORCEMENT_ENABLED=false, defaults to False
# Whether EE features are unlocked for use.
# Depends on license status: True when the user has a valid license
# (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER), False when there's no license
# or the license is expired (GATED_ACCESS).
# This controls UI visibility of EE features (user groups, analytics, RBAC, etc.).
ee_features_enabled: bool = False
temperature_override_enabled: bool | None = False

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
@@ -21,10 +22,22 @@ from onyx.utils.web_content import title_from_url
logger = setup_logger()
DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_READ_TIMEOUT_SECONDS = 15
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
DEFAULT_MAX_WORKERS = 5
def _failed_result(url: str) -> WebContent:
    """Build the canonical WebContent marking *url* as a failed scrape."""
    failed = WebContent(
        title="",
        link=url,
        full_content="",
        published_date=None,
        scrape_successful=False,
    )
    return failed
class OnyxWebCrawler(WebContentProvider):
@@ -37,12 +50,14 @@ class OnyxWebCrawler(WebContentProvider):
def __init__(
self,
*,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
user_agent: str = DEFAULT_USER_AGENT,
max_pdf_size_bytes: int | None = None,
max_html_size_bytes: int | None = None,
) -> None:
self._timeout_seconds = timeout_seconds
self._read_timeout_seconds = timeout_seconds
self._connect_timeout_seconds = connect_timeout_seconds
self._max_pdf_size_bytes = max_pdf_size_bytes
self._max_html_size_bytes = max_html_size_bytes
self._headers = {
@@ -51,75 +66,68 @@ class OnyxWebCrawler(WebContentProvider):
}
def contents(self, urls: Sequence[str]) -> list[WebContent]:
results: list[WebContent] = []
for url in urls:
results.append(self._fetch_url(url))
return results
if not urls:
return []
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._fetch_url_safe, urls))
def _fetch_url_safe(self, url: str) -> WebContent:
    """Fetch one URL, converting any exception into a failed result.

    Keeps a single bad URL from aborting the whole batch.
    """
    try:
        content = self._fetch_url(url)
    except Exception as exc:
        logger.warning(
            "Onyx crawler unexpected error for %s (%s)",
            url,
            exc.__class__.__name__,
        )
        return _failed_result(url)
    return content
def _fetch_url(self, url: str) -> WebContent:
try:
# Use SSRF-safe request to prevent DNS rebinding attacks
response = ssrf_safe_get(
url, headers=self._headers, timeout=self._timeout_seconds
url,
headers=self._headers,
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
)
except SSRFException as exc:
logger.error(
"SSRF protection blocked request to %s: %s",
"SSRF protection blocked request to %s (%s)",
url,
str(exc),
exc.__class__.__name__,
)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
except Exception as exc: # pragma: no cover - network failures vary
return _failed_result(url)
except Exception as exc:
logger.warning(
"Onyx crawler failed to fetch %s (%s)",
url,
exc.__class__.__name__,
)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
return _failed_result(url)
if response.status_code >= 400:
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
return _failed_result(url)
content_type = response.headers.get("Content-Type", "")
content_sniff = response.content[:1024] if response.content else None
content = response.content
content_sniff = content[:1024] if content else None
if is_pdf_resource(url, content_type, content_sniff):
if (
self._max_pdf_size_bytes is not None
and len(response.content) > self._max_pdf_size_bytes
and len(content) > self._max_pdf_size_bytes
):
logger.warning(
"PDF content too large (%d bytes) for %s, max is %d",
len(response.content),
len(content),
url,
self._max_pdf_size_bytes,
)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
text_content, metadata = extract_pdf_text(response.content)
return _failed_result(url)
text_content, metadata = extract_pdf_text(content)
title = title_from_pdf_metadata(metadata) or title_from_url(url)
return WebContent(
title=title,
@@ -131,25 +139,19 @@ class OnyxWebCrawler(WebContentProvider):
if (
self._max_html_size_bytes is not None
and len(response.content) > self._max_html_size_bytes
and len(content) > self._max_html_size_bytes
):
logger.warning(
"HTML content too large (%d bytes) for %s, max is %d",
len(response.content),
len(content),
url,
self._max_html_size_bytes,
)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
return _failed_result(url)
try:
decoded_html = decode_html_bytes(
response.content,
content,
content_type=content_type,
fallback_encoding=response.apparent_encoding or response.encoding,
)

View File

@@ -47,6 +47,7 @@ from onyx.tools.tool_implementations.web_search.utils import (
from onyx.tools.tool_implementations.web_search.utils import MAX_CHARS_PER_URL
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.url import normalize_url as normalize_web_content_url
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -791,7 +792,9 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
for url in all_urls:
doc_id = url_to_doc_id.get(url)
indexed_section = indexed_by_doc_id.get(doc_id) if doc_id else None
crawled_section = crawled_by_url.get(url)
# WebContent.link is normalized (query/fragment stripped). Match on the
# same normalized form to avoid dropping successful crawl results.
crawled_section = crawled_by_url.get(normalize_web_content_url(url))
if indexed_section and indexed_section.combined_content:
# Prefer indexed

View File

@@ -0,0 +1,260 @@
from __future__ import annotations
from typing import Any
import requests
from fastapi import HTTPException
from onyx.tools.tool_implementations.web_search.models import (
WebSearchProvider,
)
from onyx.tools.tool_implementations.web_search.models import WebSearchResult
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
BRAVE_WEB_SEARCH_URL = "https://api.search.brave.com/res/v1/web/search"
BRAVE_MAX_RESULTS_PER_REQUEST = 20
BRAVE_SAFESEARCH_OPTIONS = {"off", "moderate", "strict"}
BRAVE_FRESHNESS_OPTIONS = {"pd", "pw", "pm", "py"}
class RetryableBraveSearchError(Exception):
    """Raised for transient Brave search failures that should be retried."""
class BraveClient(WebSearchProvider):
    """Web search provider backed by the Brave Search REST API.

    Construction validates and normalizes all optional search filters so a
    bad configuration fails fast rather than on the first query.
    """

    def __init__(
        self,
        api_key: str,
        *,
        num_results: int = 10,
        timeout_seconds: int = 10,
        country: str | None = None,
        search_lang: str | None = None,
        ui_lang: str | None = None,
        safesearch: str | None = None,
        freshness: str | None = None,
    ) -> None:
        if timeout_seconds <= 0:
            raise ValueError("Brave provider config 'timeout_seconds' must be > 0.")
        self._headers = {
            "Accept": "application/json",
            "X-Subscription-Token": api_key,
        }
        logger.debug(f"Count of results passed to BraveClient: {num_results}")
        # Clamp to [1, BRAVE_MAX_RESULTS_PER_REQUEST] (Brave caps count at 20).
        self._num_results = max(1, min(num_results, BRAVE_MAX_RESULTS_PER_REQUEST))
        self._timeout_seconds = timeout_seconds
        # Optional filters; each normalizer returns None for missing/blank
        # input and raises ValueError on invalid values.
        self._country = _normalize_country(country)
        self._search_lang = _normalize_language_code(
            search_lang, field_name="search_lang"
        )
        self._ui_lang = _normalize_language_code(ui_lang, field_name="ui_lang")
        self._safesearch = _normalize_option(
            safesearch,
            field_name="safesearch",
            allowed_values=BRAVE_SAFESEARCH_OPTIONS,
        )
        self._freshness = _normalize_option(
            freshness,
            field_name="freshness",
            allowed_values=BRAVE_FRESHNESS_OPTIONS,
        )

    def _build_search_params(self, query: str) -> dict[str, str]:
        """Build the query-string params, including only configured filters."""
        params = {
            "q": query,
            "count": str(self._num_results),
        }
        if self._country:
            params["country"] = self._country
        if self._search_lang:
            params["search_lang"] = self._search_lang
        if self._ui_lang:
            params["ui_lang"] = self._ui_lang
        if self._safesearch:
            params["safesearch"] = self._safesearch
        if self._freshness:
            params["freshness"] = self._freshness
        return params

    @retry_builder(
        tries=3,
        delay=1,
        backoff=2,
        exceptions=(RetryableBraveSearchError,),
    )
    def _search_with_retries(self, query: str) -> list[WebSearchResult]:
        """Execute one Brave search request, mapping transport/HTTP errors.

        Transient failures (network errors, 429, 5xx) are raised as
        RetryableBraveSearchError so retry_builder retries them; permanent
        HTTP errors are raised as ValueError.
        """
        params = self._build_search_params(query)
        try:
            response = requests.get(
                BRAVE_WEB_SEARCH_URL,
                headers=self._headers,
                params=params,
                timeout=self._timeout_seconds,
            )
        except requests.RequestException as exc:
            # Any transport-level failure is considered transient.
            raise RetryableBraveSearchError(
                f"Brave search request failed: {exc}"
            ) from exc
        try:
            response.raise_for_status()
        except requests.HTTPError as exc:
            error_msg = _build_error_message(response)
            if _is_retryable_status(response.status_code):
                raise RetryableBraveSearchError(error_msg) from exc
            raise ValueError(error_msg) from exc
        data = response.json()
        # Defensive access: "web" and "results" may be missing or null.
        web_results = (data.get("web") or {}).get("results") or []
        results: list[WebSearchResult] = []
        for result in web_results:
            if not isinstance(result, dict):
                continue
            # A result without a URL is unusable; skip it.
            link = _clean_string(result.get("url"))
            if not link:
                continue
            title = _clean_string(result.get("title"))
            description = _clean_string(result.get("description"))
            results.append(
                WebSearchResult(
                    title=title,
                    link=link,
                    snippet=description,
                    author=None,
                    published_date=None,
                )
            )
        return results

    def search(self, query: str) -> list[WebSearchResult]:
        """Search Brave for *query*; raises ValueError once retries are exhausted."""
        try:
            return self._search_with_retries(query)
        except RetryableBraveSearchError as exc:
            # Retries exhausted — surface as the provider's standard error type.
            raise ValueError(str(exc)) from exc

    def test_connection(self) -> dict[str, str]:
        """Validate the API key with a probe search.

        Returns {"status": "ok"} on success; raises HTTPException(400) with a
        message classified by the underlying failure (auth, rate limit, other).
        """
        try:
            test_results = self.search("test")
            if not test_results or not any(result.link for result in test_results):
                raise HTTPException(
                    status_code=400,
                    detail="Brave API key validation failed: search returned no results.",
                )
        except HTTPException:
            raise
        except (ValueError, requests.RequestException) as e:
            # Heuristic classification based on the error text — brittle by
            # nature; keep the substring checks in sync with
            # _build_error_message's "status NNN" format.
            error_msg = str(e)
            lower = error_msg.lower()
            if (
                "status 401" in lower
                or "status 403" in lower
                or "api key" in lower
                or "auth" in lower
            ):
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid Brave API key: {error_msg}",
                ) from e
            if "status 429" in lower or "rate limit" in lower:
                raise HTTPException(
                    status_code=400,
                    detail=f"Brave API rate limit exceeded: {error_msg}",
                ) from e
            raise HTTPException(
                status_code=400,
                detail=f"Brave API key validation failed: {error_msg}",
            ) from e
        logger.info("Web search provider test succeeded for Brave.")
        return {"status": "ok"}
def _build_error_message(response: requests.Response) -> str:
    """Format a human-readable error message for a failed Brave response."""
    detail = _extract_error_detail(response)
    return f"Brave search failed (status {response.status_code}): {detail}"
def _extract_error_detail(response: requests.Response) -> str:
try:
payload: Any = response.json()
except Exception:
text = response.text.strip()
return text[:200] if text else "No error details"
if isinstance(payload, dict):
error = payload.get("error")
if isinstance(error, dict):
detail = error.get("detail") or error.get("message")
if isinstance(detail, str):
return detail
if isinstance(error, str):
return error
message = payload.get("message")
if isinstance(message, str):
return message
return str(payload)[:200]
def _is_retryable_status(status_code: int) -> bool:
return status_code == 429 or status_code >= 500
def _clean_string(value: Any) -> str:
return value.strip() if isinstance(value, str) else ""
def _normalize_country(country: str | None) -> str | None:
if country is None:
return None
normalized = country.strip().upper()
if not normalized:
return None
if len(normalized) != 2 or not normalized.isalpha():
raise ValueError(
"Brave provider config 'country' must be a 2-letter ISO country code."
)
return normalized
def _normalize_language_code(value: str | None, *, field_name: str) -> str | None:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) > 20:
raise ValueError(f"Brave provider config '{field_name}' is too long.")
return normalized
def _normalize_option(
value: str | None,
*,
field_name: str,
allowed_values: set[str],
) -> str | None:
if value is None:
return None
normalized = value.strip().lower()
if not normalized:
return None
if normalized not in allowed_values:
allowed = ", ".join(sorted(allowed_values))
raise ValueError(
f"Brave provider config '{field_name}' must be one of: {allowed}."
)
return normalized

View File

@@ -13,6 +13,9 @@ from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
DEFAULT_MAX_PDF_SIZE_BYTES,
)
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
from onyx.tools.tool_implementations.web_search.clients.brave_client import (
BraveClient,
)
from onyx.tools.tool_implementations.web_search.clients.exa_client import (
ExaClient,
)
@@ -35,16 +38,76 @@ from shared_configs.enums import WebSearchProviderType
logger = setup_logger()
def _parse_positive_int_config(
*,
raw_value: str | None,
default: int,
provider_name: str,
config_key: str,
) -> int:
if not raw_value:
return default
try:
value = int(raw_value)
except ValueError as exc:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be an integer."
) from exc
if value <= 0:
raise ValueError(
f"{provider_name} provider config '{config_key}' must be greater than 0."
)
return value
def provider_requires_api_key(provider_type: WebSearchProviderType) -> bool:
    """Return True if the given provider type requires an API key.

    SearXNG is the lone exception: it aggregates public search engines and
    does not itself require a key (an individual instance can be configured
    to demand one, but that is outside SearXNG's own requirements).
    """
    if provider_type == WebSearchProviderType.SEARXNG:
        return False
    return True
def build_search_provider_from_config(
provider_type: WebSearchProviderType,
api_key: str,
api_key: str | None,
config: dict[str, str] | None, # TODO use a typed object
) -> WebSearchProvider:
config = config or {}
num_results = int(config.get("num_results") or DEFAULT_MAX_RESULTS)
# SearXNG does not require an API key
if provider_type == WebSearchProviderType.SEARXNG:
searxng_base_url = config.get("searxng_base_url")
if not searxng_base_url:
raise ValueError("Please provide a URL for your private SearXNG instance.")
return SearXNGClient(
searxng_base_url,
num_results=num_results,
)
# All other providers require an API key
if not api_key:
raise ValueError(f"API key is required for {provider_type.value} provider.")
if provider_type == WebSearchProviderType.EXA:
return ExaClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.BRAVE:
return BraveClient(
api_key=api_key,
num_results=num_results,
timeout_seconds=_parse_positive_int_config(
raw_value=config.get("timeout_seconds"),
default=10,
provider_name="Brave",
config_key="timeout_seconds",
),
country=config.get("country"),
search_lang=config.get("search_lang"),
ui_lang=config.get("ui_lang"),
safesearch=config.get("safesearch"),
freshness=config.get("freshness"),
)
if provider_type == WebSearchProviderType.SERPER:
return SerperClient(api_key=api_key, num_results=num_results)
if provider_type == WebSearchProviderType.GOOGLE_PSE:
@@ -64,20 +127,13 @@ def build_search_provider_from_config(
num_results=num_results,
timeout_seconds=int(config.get("timeout_seconds") or 10),
)
if provider_type == WebSearchProviderType.SEARXNG:
searxng_base_url = config.get("searxng_base_url")
if not searxng_base_url:
raise ValueError("Please provide a URL for your private SearXNG instance.")
return SearXNGClient(
searxng_base_url,
num_results=num_results,
)
raise ValueError(f"Unknown provider type: {provider_type.value}")
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
    """Construct a WebSearchProvider client from a persisted provider row."""
    provider_type = WebSearchProviderType(provider_model.provider_type)
    provider_config = provider_model.config or {}
    return build_search_provider_from_config(
        provider_type=provider_type,
        api_key=provider_model.api_key,
        config=provider_config,
    )

View File

@@ -146,7 +146,7 @@ MAX_REDIRECTS = 10
def _make_ssrf_safe_request(
url: str,
headers: dict[str, str] | None = None,
timeout: int = 15,
timeout: float | tuple[float, float] = 15,
**kwargs: Any,
) -> requests.Response:
"""
@@ -204,7 +204,7 @@ def _make_ssrf_safe_request(
def ssrf_safe_get(
url: str,
headers: dict[str, str] | None = None,
timeout: int = 15,
timeout: float | tuple[float, float] = 15,
follow_redirects: bool = True,
**kwargs: Any,
) -> requests.Response:

View File

@@ -36,7 +36,7 @@ global_version = OnyxVersion()
# Eventually, ENABLE_PAID_ENTERPRISE_EDITION_FEATURES will be removed
# and license enforcement will be the only mechanism for EE features.
_LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
)

View File

@@ -265,7 +265,7 @@ fastapi==0.128.0
# onyx
fastapi-limiter==0.1.6
# via onyx
fastapi-users==15.0.2
fastapi-users==15.0.4
# via
# fastapi-users-db-sqlalchemy
# onyx
@@ -362,23 +362,14 @@ greenlet==3.2.4
# sqlalchemy
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1 ; python_full_version < '3.14'
grpcio==1.76.0
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
grpcio-status==1.76.0
# via google-api-core
h11==0.16.0
# via
@@ -762,19 +753,7 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# ddtrace
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# onnxruntime
# opentelemetry-proto
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
protobuf==6.33.5
# via
# ddtrace
# google-api-core
@@ -850,7 +829,7 @@ pygithub==2.5.0
# via onyx
pygments==2.19.2
# via rich
pyjwt==2.10.1
pyjwt==2.11.0
# via
# fastapi-users
# mcp
@@ -919,7 +898,7 @@ python-json-logger==4.0.0
# via pydocket
python-magic==0.4.27
# via unstructured
python-multipart==0.0.21
python-multipart==0.0.22
# via
# fastapi-users
# mcp

View File

@@ -123,7 +123,7 @@ execnet==2.1.2
# via pytest-xdist
executing==2.2.1
# via stack-data
faker==37.1.0
faker==40.1.2
# via onyx
fastapi==0.128.0
# via
@@ -195,23 +195,14 @@ greenlet==3.2.4 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or
# via sqlalchemy
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1 ; python_full_version < '3.14'
grpcio==1.76.0
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
grpcio-status==1.76.0
# via google-api-core
h11==0.16.0
# via
@@ -326,7 +317,7 @@ oauthlib==3.2.2
# via
# kubernetes
# requests-oauthlib
onyx-devtools==0.5.0
onyx-devtools==0.6.2
# via onyx
openai==2.14.0
# via
@@ -388,16 +379,7 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
protobuf==6.33.5
# via
# google-api-core
# google-cloud-aiplatform
@@ -442,7 +424,7 @@ pygments==2.19.2
# via
# ipython
# ipython-pygments-lexers
pyjwt==2.10.1
pyjwt==2.11.0
# via mcp
pyparsing==3.2.5
# via matplotlib
@@ -462,7 +444,7 @@ pytest-dotenv==0.5.2
# via onyx
pytest-repeat==0.9.4
# via onyx
pytest-xdist==3.6.1
pytest-xdist==3.8.0
# via onyx
python-dateutil==2.8.2
# via
@@ -477,7 +459,7 @@ python-dotenv==1.1.1
# litellm
# pydantic-settings
# pytest-dotenv
python-multipart==0.0.21
python-multipart==0.0.22
# via mcp
pywin32==311 ; sys_platform == 'win32'
# via mcp
@@ -640,7 +622,7 @@ typing-inspection==0.4.2
# mcp
# pydantic
# pydantic-settings
tzdata==2025.2
tzdata==2025.2 ; sys_platform == 'win32'
# via faker
urllib3==2.6.3
# via

View File

@@ -152,23 +152,14 @@ googleapis-common-protos==1.72.0
# grpcio-status
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1 ; python_full_version < '3.14'
grpcio==1.76.0
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
grpcio-status==1.76.0
# via google-api-core
h11==0.16.0
# via
@@ -265,16 +256,7 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
protobuf==6.33.5
# via
# google-api-core
# google-cloud-aiplatform
@@ -309,7 +291,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pyjwt==2.10.1
pyjwt==2.11.0
# via mcp
python-dateutil==2.8.2
# via
@@ -322,7 +304,7 @@ python-dotenv==1.1.1
# via
# litellm
# pydantic-settings
python-multipart==0.0.21
python-multipart==0.0.22
# via mcp
pywin32==311 ; sys_platform == 'win32'
# via mcp

View File

@@ -177,23 +177,14 @@ googleapis-common-protos==1.72.0
# grpcio-status
grpc-google-iam-v1==0.14.3
# via google-cloud-resource-manager
grpcio==1.67.1 ; python_full_version < '3.14'
grpcio==1.76.0
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio==1.76.0 ; python_full_version >= '3.14'
# via
# google-api-core
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
grpcio-status==1.67.1 ; python_full_version < '3.14'
# via google-api-core
grpcio-status==1.76.0 ; python_full_version >= '3.14'
grpcio-status==1.76.0
# via google-api-core
h11==0.16.0
# via
@@ -351,16 +342,7 @@ proto-plus==1.26.1
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
protobuf==5.29.5 ; python_full_version < '3.14'
# via
# google-api-core
# google-cloud-aiplatform
# google-cloud-resource-manager
# googleapis-common-protos
# grpc-google-iam-v1
# grpcio-status
# proto-plus
protobuf==6.33.4 ; python_full_version >= '3.14'
protobuf==6.33.5
# via
# google-api-core
# google-cloud-aiplatform
@@ -397,7 +379,7 @@ pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.12.0
# via mcp
pyjwt==2.10.1
pyjwt==2.11.0
# via mcp
python-dateutil==2.8.2
# via
@@ -410,7 +392,7 @@ python-dotenv==1.1.1
# via
# litellm
# pydantic-settings
python-multipart==0.0.21
python-multipart==0.0.22
# via mcp
pywin32==311 ; sys_platform == 'win32'
# via mcp

View File

@@ -26,6 +26,7 @@ class WebSearchProviderType(str, Enum):
SERPER = "serper"
EXA = "exa"
SEARXNG = "searxng"
BRAVE = "brave"
class WebContentProviderType(str, Enum):

View File

@@ -0,0 +1,281 @@
"""
External dependency unit tests for user file processing queue protections.
Verifies that the three mechanisms added to check_user_file_processing work
correctly:
1. Queue depth backpressure when the broker queue exceeds
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH, no new tasks are enqueued.
2. Per-file Redis guard key if the guard key for a file already exists in
Redis, that file is skipped even though it is still in PROCESSING status.
3. Task expiry every send_task call carries expires=
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES so that stale queued tasks are
discarded by workers automatically.
Also verifies that process_single_user_file clears the guard key the moment
it is picked up by a worker.
Uses real Redis (DB 0 via get_redis_client) and real PostgreSQL for UserFile
rows. The Celery app is provided as a MagicMock injected via a PropertyMock
on the task class so no real broker is needed.
"""
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import PropertyMock
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_lock_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_user_file_queued_key,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
check_user_file_processing,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
process_single_user_file,
)
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.db.enums import UserFileStatus
from onyx.db.models import UserFile
from onyx.redis.redis_pool import get_redis_client
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PATCH_QUEUE_LEN = (
"onyx.background.celery.tasks.user_file_processing.tasks.celery_get_queue_length"
)
def _create_processing_user_file(db_session: Session, user_id: object) -> UserFile:
"""Insert a UserFile in PROCESSING status and return it."""
uf = UserFile(
id=uuid4(),
user_id=user_id,
file_id=f"test_file_{uuid4().hex[:8]}",
name=f"test_{uuid4().hex[:8]}.txt",
file_type="text/plain",
status=UserFileStatus.PROCESSING,
)
db_session.add(uf)
db_session.commit()
db_session.refresh(uf)
return uf
@contextmanager
def _patch_task_app(task: Any, mock_app: MagicMock) -> Generator[None, None, None]:
"""Patch the ``app`` property on *task*'s class so that ``self.app``
inside the task function returns *mock_app*.
With ``bind=True``, ``task.run`` is a bound method whose ``__self__`` is
the actual task instance. We patch ``app`` on that instance's class
(a unique Celery-generated Task subclass) so the mock is scoped to this
task only.
"""
task_instance = task.run.__self__
with patch.object(
type(task_instance), "app", new_callable=PropertyMock, return_value=mock_app
):
yield
# ---------------------------------------------------------------------------
# Test classes
# ---------------------------------------------------------------------------
class TestQueueDepthBackpressure:
"""Protection 1: skip all enqueuing when the broker queue is too deep."""
def test_no_tasks_enqueued_when_queue_over_limit(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""When the queue depth exceeds the limit the beat cycle is skipped."""
user = create_test_user(db_session, "bp_user")
_create_processing_user_file(db_session, user.id)
mock_app = MagicMock()
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(
_PATCH_QUEUE_LEN, return_value=USER_FILE_PROCESSING_MAX_QUEUE_DEPTH + 1
),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
mock_app.send_task.assert_not_called()
class TestPerFileGuardKey:
"""Protection 2: per-file Redis guard key prevents duplicate enqueue."""
def test_guarded_file_not_re_enqueued(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""A file whose guard key is already set in Redis is skipped."""
user = create_test_user(db_session, "guard_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# send_task must not have been called with this specific file's ID
for call in mock_app.send_task.call_args_list:
kwargs = call.kwargs.get("kwargs", {})
assert kwargs.get("user_file_id") != str(
uf.id
), f"File {uf.id} should have been skipped because its guard key exists"
finally:
redis_client.delete(guard_key)
def test_guard_key_exists_in_redis_after_enqueue(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""After a file is enqueued its guard key is present in Redis with a TTL."""
user = create_test_user(db_session, "guard_set_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key) # clean slate
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
assert redis_client.exists(
guard_key
), "Guard key should be set in Redis after enqueue"
ttl = int(redis_client.ttl(guard_key)) # type: ignore[arg-type]
assert 0 < ttl <= CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, (
f"Guard key TTL {ttl}s is outside the expected range "
f"(0, {CELERY_USER_FILE_PROCESSING_TASK_EXPIRES}]"
)
finally:
redis_client.delete(guard_key)
class TestTaskExpiry:
"""Protection 3: every send_task call includes an expires value."""
def test_send_task_called_with_expires(
self,
db_session: Session,
tenant_context: None, # noqa: ARG002
) -> None:
"""send_task is called with the correct queue, task name, and expires."""
user = create_test_user(db_session, "expires_user")
uf = _create_processing_user_file(db_session, user.id)
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(uf.id)
redis_client.delete(guard_key)
mock_app = MagicMock()
try:
with (
_patch_task_app(check_user_file_processing, mock_app),
patch(_PATCH_QUEUE_LEN, return_value=0),
):
check_user_file_processing.run(tenant_id=TEST_TENANT_ID)
# At least one task should have been submitted (for our file)
assert (
mock_app.send_task.call_count >= 1
), "Expected at least one task to be submitted"
# Every submitted task must carry expires
for call in mock_app.send_task.call_args_list:
assert call.args[0] == OnyxCeleryTask.PROCESS_SINGLE_USER_FILE
assert call.kwargs.get("queue") == OnyxCeleryQueues.USER_FILE_PROCESSING
assert (
call.kwargs.get("expires")
== CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
), (
"Task must be submitted with the correct expires value to prevent "
"stale task accumulation"
)
finally:
redis_client.delete(guard_key)
class TestWorkerClearsGuardKey:
"""process_single_user_file removes the guard key when it picks up a task."""
def test_guard_key_deleted_on_pickup(
self,
tenant_context: None, # noqa: ARG002
) -> None:
"""The guard key is deleted before the worker does any real work.
We simulate an already-locked file so process_single_user_file returns
early but crucially, after the guard key deletion.
"""
user_file_id = str(uuid4())
redis_client = get_redis_client(tenant_id=TEST_TENANT_ID)
guard_key = _user_file_queued_key(user_file_id)
# Simulate the guard key set when the beat enqueued the task
redis_client.setex(guard_key, CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, 1)
assert redis_client.exists(guard_key), "Guard key must exist before pickup"
# Hold the per-file processing lock so the worker exits early without
# touching the database or file store.
lock_key = _user_file_lock_key(user_file_id)
processing_lock = redis_client.lock(lock_key, timeout=10)
acquired = processing_lock.acquire(blocking=False)
assert acquired, "Should be able to acquire the processing lock for this test"
try:
process_single_user_file.run(
user_file_id=user_file_id,
tenant_id=TEST_TENANT_ID,
)
finally:
if processing_lock.owned():
processing_lock.release()
assert not redis_client.exists(
guard_key
), "Guard key should be deleted when the worker picks up the task"

View File

@@ -553,7 +553,7 @@ class TestDefaultProviderEndpoint:
try:
existing_providers = fetch_existing_llm_providers(
db_session, flow_types=[LLMModelFlowType.CHAT]
db_session, flow_type_filter=[LLMModelFlowType.CHAT]
)
provider_names_to_restore: list[str] = []

View File

@@ -14,9 +14,12 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
from onyx.db.models import UserRole
from onyx.llm.constants import LlmProviderNames
@@ -606,3 +609,95 @@ class TestAutoModeSyncFeature:
db_session.rollback()
_cleanup_provider(db_session, provider_1_name)
_cleanup_provider(db_session, provider_2_name)
class TestAutoModeMissingFlows:
"""Regression test: sync_auto_mode_models must create LLMModelFlow rows
for every ModelConfiguration it inserts, otherwise the provider vanishes
from listing queries that join through LLMModelFlow."""
def test_sync_auto_mode_creates_flow_rows(
self,
db_session: Session,
provider_name: str,
) -> None:
"""
Steps:
1. Create a provider with no model configs (empty shell).
2. Call sync_auto_mode_models to add models from a mock config.
3. Assert every new ModelConfiguration has at least one LLMModelFlow.
4. Assert fetch_existing_llm_providers (which joins through
LLMModelFlow) returns the provider.
"""
mock_recommendations = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create provider with no model configs
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Run sync_auto_mode_models (simulating the periodic sync)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=mock_recommendations,
)
# Step 3: Every ModelConfiguration must have at least one LLMModelFlow
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
synced_model_names = {mc.name for mc in provider.model_configurations}
assert "gpt-4o" in synced_model_names
assert "gpt-4o-mini" in synced_model_names
for mc in provider.model_configurations:
assert len(mc.llm_model_flows) > 0, (
f"ModelConfiguration '{mc.name}' (id={mc.id}) has no "
f"LLMModelFlow rows — it will be invisible to listing queries"
)
flow_types = {f.llm_model_flow_type for f in mc.llm_model_flows}
assert (
LLMModelFlowType.CHAT in flow_types
), f"ModelConfiguration '{mc.name}' is missing a CHAT flow"
# Step 4: The provider must appear in fetch_existing_llm_providers
listed_providers = fetch_existing_llm_providers(
db_session=db_session,
flow_type_filter=[LLMModelFlowType.CHAT],
)
listed_provider_names = {p.name for p in listed_providers}
assert provider_name in listed_provider_names, (
f"Provider '{provider_name}' not returned by "
f"fetch_existing_llm_providers — models are missing flow rows"
)
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)

View File

@@ -14,6 +14,9 @@ from unittest.mock import patch
import pytest
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
from onyx.background.celery.tasks.opensearch_migration.tasks import (
check_for_documents_for_opensearch_migration_task,
)
@@ -25,14 +28,12 @@ from onyx.configs.constants import SOURCE_TYPE
from onyx.context.search.models import IndexFilters
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import OpenSearchDocumentMigrationStatus
from onyx.db.enums import OpenSearchTenantMigrationStatus
from onyx.db.models import Document
from onyx.db.models import OpenSearchDocumentMigrationRecord
from onyx.db.models import OpenSearchTenantMigrationRecord
from onyx.db.opensearch_migration import create_opensearch_migration_records_with_commit
from onyx.db.opensearch_migration import get_last_opensearch_migration_document_id
from onyx.db.opensearch_migration import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
@@ -525,6 +526,40 @@ class TestCheckForDocumentsForOpenSearchMigrationTask:
>= 1
)
def test_creates_singleton_migration_record(
self,
db_session: Session,
clean_migration_tables: None, # noqa: ARG002
enable_opensearch_indexing_for_onyx: None, # noqa: ARG002
) -> None:
"""Tests that singleton migration record is created."""
# Under test.
result = check_for_documents_for_opensearch_migration_task(
tenant_id=get_current_tenant_id()
)
# Postcondition.
assert result is True
# Expire the session cache to see the committed changes from the task.
db_session.expire_all()
# Verify the singleton migration record was created.
tenant_record = db_session.query(OpenSearchTenantMigrationRecord).first()
assert tenant_record is not None
assert (
tenant_record.document_migration_record_table_population_status
== OpenSearchTenantMigrationStatus.PENDING
)
assert (
tenant_record.num_times_observed_no_additional_docs_to_populate_migration_table
== 1
)
assert (
tenant_record.overall_document_migration_status
== OpenSearchTenantMigrationStatus.PENDING
)
assert tenant_record.num_times_observed_no_additional_docs_to_migrate == 0
assert tenant_record.last_updated_at is not None
class TestMigrateDocumentsFromVespaToOpenSearchTask:
"""Tests migrate_documents_from_vespa_to_opensearch_task."""
@@ -665,7 +700,11 @@ class TestMigrateDocumentsFromVespaToOpenSearchTask:
.first()
)
assert record is not None
assert record.status == OpenSearchDocumentMigrationStatus.FAILED
# In practice the task keeps trying docs until it either runs out of
# time or the lock is lost, which will not happen during this test.
# Because of this the migration record will just shift to permanently
# failed. Let's just test for that here.
assert record.status == OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
# Verify chunks were indexed in OpenSearch.
for document_id in doc_ids_that_have_chunks:
chunks = _get_document_chunks_from_opensearch(
@@ -764,7 +803,11 @@ class TestMigrateDocumentsFromVespaToOpenSearchTask:
.first()
)
assert record is not None
assert record.status == OpenSearchDocumentMigrationStatus.FAILED
# In practice the task keeps trying docs until it either runs out of
# time or the lock is lost, which will not happen during this test.
# Because of this the migration record will just shift to permanently
# failed. Let's just test for that here.
assert record.status == OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED
assert record.error_message is not None
assert "no chunk count" in record.error_message.lower()

View File

@@ -17,7 +17,8 @@ COPY ./tests/* /app/tests/
FROM base AS openapi-schema
COPY ./scripts/onyx_openapi_schema.py /app/scripts/onyx_openapi_schema.py
RUN python scripts/onyx_openapi_schema.py --filename openapi.json
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
RUN LICENSE_ENFORCEMENT_ENABLED=false python scripts/onyx_openapi_schema.py --filename openapi.json
FROM openapitools/openapi-generator-cli:latest AS openapi-client
WORKDIR /local

View File

@@ -24,6 +24,7 @@ from tests.integration.common_utils.test_models import DATestChatSession
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import ErrorResponse
from tests.integration.common_utils.test_models import StreamedResponse
from tests.integration.common_utils.test_models import ToolCallDebug
from tests.integration.common_utils.test_models import ToolName
from tests.integration.common_utils.test_models import ToolResult
@@ -40,6 +41,7 @@ class StreamPacketObj(TypedDict, total=False):
"image_generation_start",
"image_generation_heartbeat",
"image_generation_final",
"tool_call_debug",
]
content: str
final_documents: list[dict[str, Any]]
@@ -47,6 +49,9 @@ class StreamPacketObj(TypedDict, total=False):
images: list[dict[str, Any]]
queries: list[str]
documents: list[dict[str, Any]]
tool_call_id: str
tool_name: str
tool_args: dict[str, Any]
class PlacementData(TypedDict, total=False):
@@ -109,6 +114,7 @@ class ChatSessionManager:
use_existing_user_message: bool = False,
forced_tool_ids: list[int] | None = None,
chat_session: DATestChatSession | None = None,
mock_llm_response: str | None = None,
) -> StreamedResponse:
chat_message_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
@@ -120,6 +126,7 @@ class ChatSessionManager:
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
mock_llm_response=mock_llm_response,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
@@ -179,6 +186,7 @@ class ChatSessionManager:
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
forced_tool_ids: list[int] | None = None,
mock_llm_response: str | None = None,
) -> None:
"""
Send a message and simulate client disconnect before stream completes.
@@ -210,6 +218,7 @@ class ChatSessionManager:
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
mock_llm_response=mock_llm_response,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
@@ -253,6 +262,7 @@ class ChatSessionManager:
],
)
ind_to_tool_use: dict[int, ToolResult] = {}
tool_call_debug: list[ToolCallDebug] = []
top_documents: list[SearchDoc] = []
heartbeat_packets: list[StreamPacketData] = []
full_message = ""
@@ -330,6 +340,16 @@ class ChatSessionManager:
SavedSearchDoc.from_search_doc(search_doc, db_doc_id=0)
)
ind_to_tool_use[ind].documents.extend(docs)
elif packet_type_str == StreamingType.TOOL_CALL_DEBUG.value:
tool_call_debug.append(
ToolCallDebug(
tool_call_id=str(data_obj.get("tool_call_id", "")),
tool_name=str(data_obj.get("tool_name", "")),
tool_args=cast(
dict[str, Any], data_obj.get("tool_args") or {}
),
)
)
# If there's an error, assistant_message_id might not be present
if not assistant_message_id and not error:
raise ValueError("Assistant message id not found")
@@ -338,6 +358,7 @@ class ChatSessionManager:
assistant_message_id=assistant_message_id or -1, # Use -1 for error cases
top_documents=top_documents,
used_tools=list(ind_to_tool_use.values()),
tool_call_debug=tool_call_debug,
heartbeat_packets=[dict(packet) for packet in heartbeat_packets],
error=error,
)

View File

@@ -202,6 +202,12 @@ class ToolResult(BaseModel):
images: list[GeneratedImage] = Field(default_factory=list)
class ToolCallDebug(BaseModel):
tool_call_id: str
tool_name: str
tool_args: dict[str, Any]
class ErrorResponse(BaseModel):
error: str
stack_trace: str
@@ -212,6 +218,7 @@ class StreamedResponse(BaseModel):
assistant_message_id: int
top_documents: list[SearchDoc]
used_tools: list[ToolResult]
tool_call_debug: list[ToolCallDebug] = Field(default_factory=list)
error: ErrorResponse | None = None
# Track heartbeat packets for image generation and other tools

View File

@@ -34,9 +34,7 @@ def _schema_exists(schema_name: str) -> bool:
class TestTenantProvisioningRollback:
"""Integration tests for provisioning failure and rollback."""
def test_failed_provisioning_cleans_up_schema(
self, reset_multitenant: None
) -> None:
def test_failed_provisioning_cleans_up_schema(self) -> None:
"""
When setup_tenant fails after schema creation, rollback should
clean up the orphaned schema.
@@ -79,9 +77,7 @@ class TestTenantProvisioningRollback:
created_tenant_id
), f"Schema {created_tenant_id} should have been rolled back"
def test_drop_schema_works_with_uuid_tenant_id(
self, reset_multitenant: None
) -> None:
def test_drop_schema_works_with_uuid_tenant_id(self) -> None:
"""
drop_schema should work with UUID-format tenant IDs.

View File

@@ -240,6 +240,116 @@ def test_can_user_access_llm_provider_or_logic(
)
def test_public_provider_with_persona_restrictions(
users: tuple[DATestUser, DATestUser],
) -> None:
"""Public providers should still enforce persona restrictions.
Regression test for the bug where is_public=True caused
can_user_access_llm_provider() to return True immediately,
bypassing persona whitelist checks entirely.
"""
admin_user, _basic_user = users
with get_session_with_current_tenant() as db_session:
# Public provider with persona restrictions
public_restricted = _create_llm_provider(
db_session,
name="public-persona-restricted",
default_model_name="gpt-4o",
is_public=True,
is_default=True,
)
whitelisted_persona = _create_persona(
db_session,
name="whitelisted-persona",
provider_name=public_restricted.name,
)
non_whitelisted_persona = _create_persona(
db_session,
name="non-whitelisted-persona",
provider_name=public_restricted.name,
)
# Only whitelist one persona
db_session.add(
LLMProvider__Persona(
llm_provider_id=public_restricted.id,
persona_id=whitelisted_persona.id,
)
)
db_session.flush()
db_session.refresh(public_restricted)
admin_model = db_session.get(User, admin_user.id)
assert admin_model is not None
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
# Whitelisted persona — should be allowed
assert can_user_access_llm_provider(
public_restricted,
admin_group_ids,
whitelisted_persona,
)
# Non-whitelisted persona — should be denied despite is_public=True
assert not can_user_access_llm_provider(
public_restricted,
admin_group_ids,
non_whitelisted_persona,
)
# No persona context (e.g. global provider list) — should be denied
# because provider has persona restrictions set
assert not can_user_access_llm_provider(
public_restricted,
admin_group_ids,
persona=None,
)
def test_public_provider_without_persona_restrictions(
users: tuple[DATestUser, DATestUser],
) -> None:
"""Public providers with no persona restrictions remain accessible to all."""
admin_user, basic_user = users
with get_session_with_current_tenant() as db_session:
public_unrestricted = _create_llm_provider(
db_session,
name="public-unrestricted",
default_model_name="gpt-4o",
is_public=True,
is_default=True,
)
any_persona = _create_persona(
db_session,
name="any-persona",
provider_name=public_unrestricted.name,
)
admin_model = db_session.get(User, admin_user.id)
basic_model = db_session.get(User, basic_user.id)
assert admin_model is not None
assert basic_model is not None
admin_group_ids = fetch_user_group_ids(db_session, admin_model)
basic_group_ids = fetch_user_group_ids(db_session, basic_model)
# Any user, any persona — all allowed
assert can_user_access_llm_provider(
public_unrestricted, admin_group_ids, any_persona
)
assert can_user_access_llm_provider(
public_unrestricted, basic_group_ids, any_persona
)
assert can_user_access_llm_provider(
public_unrestricted, admin_group_ids, persona=None
)
def test_get_llm_for_persona_falls_back_when_access_denied(
users: tuple[DATestUser, DATestUser],
) -> None:

View File

@@ -0,0 +1,234 @@
import io
import json
import os
import pytest
import requests
from onyx.db.enums import AccessType
from onyx.db.models import UserRole
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.user import DATestUser
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
def _upload_connector_file(
*,
user_performing_action: DATestUser,
file_name: str,
content: bytes,
) -> tuple[str, str]:
headers = user_performing_action.headers.copy()
headers.pop("Content-Type", None)
response = requests.post(
f"{API_SERVER_URL}/manage/admin/connector/file/upload",
files=[("files", (file_name, io.BytesIO(content), "text/plain"))],
headers=headers,
)
response.raise_for_status()
payload = response.json()
return payload["file_paths"][0], payload["file_names"][0]
def _update_connector_files(
*,
connector_id: int,
user_performing_action: DATestUser,
file_ids_to_remove: list[str],
new_file_name: str,
new_file_content: bytes,
) -> requests.Response:
headers = user_performing_action.headers.copy()
headers.pop("Content-Type", None)
return requests.post(
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files/update",
data={"file_ids_to_remove": json.dumps(file_ids_to_remove)},
files=[("files", (new_file_name, io.BytesIO(new_file_content), "text/plain"))],
headers=headers,
)
def _list_connector_files(
*,
connector_id: int,
user_performing_action: DATestUser,
) -> requests.Response:
return requests.get(
f"{API_SERVER_URL}/manage/admin/connector/{connector_id}/files",
headers=user_performing_action.headers,
)
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Curator and user group tests are enterprise only",
)
@pytest.mark.usefixtures("reset")
def test_only_global_curator_can_update_public_file_connector_files() -> None:
    """File-update access control for file connectors.

    Covers three outcomes:
      * a group-scoped curator can LIST a public connector's files but gets
        403 when trying to UPDATE them;
      * a different global curator (not the creator) CAN update them;
      * a global curator outside the owning group of a PRIVATE connector
        gets 403 on update.
    """
    # --- users: one admin, two global curators, one group-scoped curator ---
    admin_user = UserManager.create(name="admin_user")
    global_curator_creator = UserManager.create(name="global_curator_creator")
    global_curator_creator = UserManager.set_role(
        user_to_set=global_curator_creator,
        target_role=UserRole.GLOBAL_CURATOR,
        user_performing_action=admin_user,
    )
    global_curator_editor = UserManager.create(name="global_curator_editor")
    global_curator_editor = UserManager.set_role(
        user_to_set=global_curator_editor,
        target_role=UserRole.GLOBAL_CURATOR,
        user_performing_action=admin_user,
    )
    # Group-scoped curator: member of curator_group, promoted to its curator.
    curator_user = UserManager.create(name="curator_user")
    curator_group = UserGroupManager.create(
        name="curator_group",
        user_ids=[curator_user.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[curator_group],
        user_performing_action=admin_user,
    )
    UserGroupManager.set_curator_status(
        test_user_group=curator_group,
        user_to_set_as_curator=curator_user,
        user_performing_action=admin_user,
    )
    # --- public connector owned by the creating global curator --------------
    initial_file_id, initial_file_name = _upload_connector_file(
        user_performing_action=global_curator_creator,
        file_name="initial-file.txt",
        content=b"initial file content",
    )
    connector = ConnectorManager.create(
        user_performing_action=global_curator_creator,
        name="public_file_connector",
        source=DocumentSource.FILE,
        connector_specific_config={
            "file_locations": [initial_file_id],
            "file_names": [initial_file_name],
            "zip_metadata_file_id": None,
        },
        access_type=AccessType.PUBLIC,
        groups=[],
    )
    credential = CredentialManager.create(
        user_performing_action=global_curator_creator,
        source=DocumentSource.FILE,
        curator_public=True,
        groups=[],
        name="public_file_connector_credential",
    )
    CCPairManager.create(
        connector_id=connector.id,
        credential_id=credential.id,
        user_performing_action=global_curator_creator,
        access_type=AccessType.PUBLIC,
        groups=[],
        name="public_file_connector_cc_pair",
    )
    # Both curator kinds may LIST the public connector's files.
    curator_list_response = _list_connector_files(
        connector_id=connector.id,
        user_performing_action=curator_user,
    )
    curator_list_response.raise_for_status()
    curator_list_payload = curator_list_response.json()
    assert any(f["file_id"] == initial_file_id for f in curator_list_payload["files"])
    global_curator_list_response = _list_connector_files(
        connector_id=connector.id,
        user_performing_action=global_curator_editor,
    )
    global_curator_list_response.raise_for_status()
    global_curator_list_payload = global_curator_list_response.json()
    assert any(
        f["file_id"] == initial_file_id for f in global_curator_list_payload["files"]
    )
    # Group-scoped curator may NOT update a public connector's files.
    denied_response = _update_connector_files(
        connector_id=connector.id,
        user_performing_action=curator_user,
        file_ids_to_remove=[initial_file_id],
        new_file_name="curator-file.txt",
        new_file_content=b"curator updated file",
    )
    assert denied_response.status_code == 403
    # A different global curator (not the creator) CAN update them.
    allowed_response = _update_connector_files(
        connector_id=connector.id,
        user_performing_action=global_curator_editor,
        file_ids_to_remove=[initial_file_id],
        new_file_name="global-curator-file.txt",
        new_file_content=b"global curator updated file",
    )
    allowed_response.raise_for_status()
    payload = allowed_response.json()
    assert initial_file_id not in payload["file_paths"]
    assert "global-curator-file.txt" in payload["file_names"]
    # --- private connector scoped to the creator's group -------------------
    creator_group = UserGroupManager.create(
        name="creator_group",
        user_ids=[global_curator_creator.id],
        cc_pair_ids=[],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[creator_group],
        user_performing_action=admin_user,
    )
    private_file_id, private_file_name = _upload_connector_file(
        user_performing_action=global_curator_creator,
        file_name="private-initial-file.txt",
        content=b"private initial file content",
    )
    private_connector = ConnectorManager.create(
        user_performing_action=global_curator_creator,
        name="private_file_connector",
        source=DocumentSource.FILE,
        connector_specific_config={
            "file_locations": [private_file_id],
            "file_names": [private_file_name],
            "zip_metadata_file_id": None,
        },
        access_type=AccessType.PRIVATE,
        groups=[creator_group.id],
    )
    private_credential = CredentialManager.create(
        user_performing_action=global_curator_creator,
        source=DocumentSource.FILE,
        curator_public=False,
        groups=[creator_group.id],
        name="private_file_connector_credential",
    )
    CCPairManager.create(
        connector_id=private_connector.id,
        credential_id=private_credential.id,
        user_performing_action=global_curator_creator,
        access_type=AccessType.PRIVATE,
        groups=[creator_group.id],
        name="private_file_connector_cc_pair",
    )
    # A global curator OUTSIDE the owning group cannot update a private
    # connector's files.
    private_denied_response = _update_connector_files(
        connector_id=private_connector.id,
        user_performing_action=global_curator_editor,
        file_ids_to_remove=[private_file_id],
        new_file_name="global-curator-private-file.txt",
        new_file_content=b"global curator private update",
    )
    assert private_denied_response.status_code == 403

View File

@@ -0,0 +1,155 @@
"""Integration tests for seat limit enforcement on user creation paths.
Verifies that when a license with a seat limit is active, new user
creation (registration, invite, reactivation) is blocked with HTTP 402.
"""
from datetime import datetime
from datetime import timedelta
import redis
import requests
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import PlanType
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PORT
from onyx.server.settings.models import ApplicationStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import UserManager
# TenantRedis prefixes every key with "{tenant_id}:".
# Single-tenant deployments use "public" as the tenant id.
_LICENSE_REDIS_KEY = "public:license:metadata"
def _seed_license(r: redis.Redis, seats: int) -> None:
    """Write a LicenseMetadata entry into Redis with the given seat cap."""
    issued = datetime.utcnow()
    metadata = LicenseMetadata(
        tenant_id="public",
        organization_name="Test Org",
        seats=seats,
        # check_seat_availability recalculates used seats from the DB,
        # so the cached value written here does not matter.
        used_seats=0,
        plan_type=PlanType.ANNUAL,
        issued_at=issued,
        expires_at=issued + timedelta(days=365),
        status=ApplicationStatus.ACTIVE,
        source=LicenseSource.MANUAL_UPLOAD,
    )
    # ex=300: the entry expires on its own so a crashed test cannot leave
    # a stale license behind indefinitely.
    r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300)
def _clear_license(r: redis.Redis) -> None:
    """Remove any cached license metadata so no seat limit applies."""
    r.delete(_LICENSE_REDIS_KEY)
def _redis() -> redis.Redis:
    """Connect to the same Redis instance the API server reads licenses from."""
    return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER)
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
def test_registration_blocked_when_seats_full(reset: None) -> None:  # noqa: ARG001
    """POST /auth/register returns 402 when the seat limit is reached."""
    r = _redis()
    # First user is admin — occupies 1 seat
    UserManager.create(name="admin_user")
    # License allows exactly 1 seat → already full
    _seed_license(r, seats=1)
    try:
        response = requests.post(
            url=f"{API_SERVER_URL}/auth/register",
            json={
                "email": "blocked@example.com",
                "username": "blocked@example.com",
                "password": "TestPassword123!",
            },
            headers=GENERAL_HEADERS,
        )
        # 402 Payment Required signals the seat cap, not an auth failure.
        assert response.status_code == 402
    finally:
        # Always drop the seeded license so later tests are unaffected.
        _clear_license(r)
# ------------------------------------------------------------------
# Invitation
# ------------------------------------------------------------------
def test_invite_blocked_when_seats_full(reset: None) -> None:  # noqa: ARG001
    """PUT /manage/admin/users returns 402 when the seat limit is reached."""
    r = _redis()
    # Admin occupies the single licensed seat.
    admin_user = UserManager.create(name="admin_user")
    _seed_license(r, seats=1)
    try:
        # Inviting anyone beyond the cap must be rejected with 402.
        response = requests.put(
            url=f"{API_SERVER_URL}/manage/admin/users",
            json={"emails": ["newuser@example.com"]},
            headers=admin_user.headers,
        )
        assert response.status_code == 402
    finally:
        # Always drop the seeded license so later tests are unaffected.
        _clear_license(r)
# ------------------------------------------------------------------
# Reactivation
# ------------------------------------------------------------------
def test_reactivation_blocked_when_seats_full(reset: None) -> None:  # noqa: ARG001
    """PATCH /manage/admin/activate-user returns 402 when seats are full."""
    r = _redis()
    admin_user = UserManager.create(name="admin_user")
    basic_user = UserManager.create(name="basic_user")
    # Deactivate the basic user (frees a seat in the DB count)
    UserManager.set_status(
        basic_user, target_status=False, user_performing_action=admin_user
    )
    # Set license to 1 seat — only admin counts now
    _seed_license(r, seats=1)
    try:
        # Reactivating would push the count back over the cap → 402.
        response = requests.patch(
            url=f"{API_SERVER_URL}/manage/admin/activate-user",
            json={"user_email": basic_user.email},
            headers=admin_user.headers,
        )
        assert response.status_code == 402
    finally:
        # Always drop the seeded license so later tests are unaffected.
        _clear_license(r)
# ------------------------------------------------------------------
# No license → no enforcement
# ------------------------------------------------------------------
def test_registration_allowed_without_license(reset: None) -> None:  # noqa: ARG001
    """Without a license in Redis, registration is unrestricted."""
    r = _redis()
    # Make sure there is no cached license
    _clear_license(r)
    UserManager.create(name="admin_user")
    # Second user should register without issue
    second_user = UserManager.create(name="second_user")
    assert second_user is not None

View File

@@ -17,6 +17,7 @@ class TestOnyxWebCrawler:
content from public websites correctly.
"""
@pytest.mark.skip(reason="Temporarily disabled")
def test_fetches_public_url_successfully(self, admin_user: DATestUser) -> None:
"""Test that the crawler can fetch content from a public URL."""
response = requests.post(
@@ -40,6 +41,7 @@ class TestOnyxWebCrawler:
assert "This domain is for use in" in content
assert "documentation" in content or "illustrative" in content
@pytest.mark.skip(reason="Temporarily disabled")
def test_fetches_multiple_urls(self, admin_user: DATestUser) -> None:
"""Test that the crawler can fetch multiple URLs in one request."""
response = requests.post(

View File

@@ -101,6 +101,33 @@ class TestMakeBillingRequest:
assert exc_info.value.status_code == 400
assert "Bad request" in exc_info.value.message
@pytest.mark.asyncio
@patch("ee.onyx.server.billing.service._get_headers")
@patch("ee.onyx.server.billing.service._get_base_url")
async def test_follows_redirects(
self,
mock_base_url: MagicMock,
mock_headers: MagicMock,
) -> None:
"""AsyncClient must be created with follow_redirects=True.
The target server (cloud data plane for self-hosted, control
plane for cloud) may sit behind nginx that returns 308
(HTTP→HTTPS). httpx does not follow redirects by default,
so we must explicitly opt in.
"""
from ee.onyx.server.billing.service import _make_billing_request
mock_base_url.return_value = "http://api.example.com"
mock_headers.return_value = {"Authorization": "Bearer token"}
mock_response = make_mock_response({"ok": True})
mock_client = make_mock_http_client("get", response=mock_response)
with patch("httpx.AsyncClient", mock_client):
await _make_billing_request(method="GET", path="/test")
mock_client.assert_called_once_with(timeout=30.0, follow_redirects=True)
@pytest.mark.asyncio
@patch("ee.onyx.server.billing.service._get_headers")
@patch("ee.onyx.server.billing.service._get_base_url")

View File

@@ -51,7 +51,6 @@ class TestApplyLicenseStatusToSettings:
@pytest.mark.parametrize(
"license_status,expected_app_status,expected_ee_enabled",
[
(None, ApplicationStatus.ACTIVE, False),
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
],
@@ -84,6 +83,56 @@ class TestApplyLicenseStatusToSettings:
assert result.application_status == expected_app_status
assert result.ee_features_enabled is expected_ee_enabled
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.refresh_license_cache", return_value=None)
@patch("ee.onyx.server.settings.api.get_session_with_current_tenant")
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_no_license_with_ee_flag_gates_access(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
_mock_get_session: MagicMock,
_mock_refresh: MagicMock,
base_settings: Settings,
) -> None:
"""No license + ENTERPRISE_EDITION_ENABLED=true → GATED_ACCESS."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.GATED_ACCESS
assert result.ee_features_enabled is False
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.refresh_license_cache", return_value=None)
@patch("ee.onyx.server.settings.api.get_session_with_current_tenant")
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@patch("ee.onyx.server.settings.api.get_cached_license_metadata")
def test_no_license_without_ee_flag_allows_community(
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
_mock_get_session: MagicMock,
_mock_refresh: MagicMock,
base_settings: Settings,
) -> None:
"""No license + ENTERPRISE_EDITION_ENABLED=false → community mode (no gating)."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@patch("ee.onyx.server.settings.api.get_current_tenant_id")
@@ -105,9 +154,10 @@ class TestApplyLicenseStatusToSettings:
assert result.ee_features_enabled is False
class TestSettingsDefaultEEDisabled:
"""Verify the Settings model defaults ee_features_enabled to False."""
class TestSettingsDefaults:
    """Verify Settings model defaults for CE deployments."""

    def test_default_ee_features_disabled(self) -> None:
        """CE default: ee_features_enabled is False."""
        # A bare Settings() must not enable enterprise features.
        settings = Settings()
        assert settings.ee_features_enabled is False

View File

@@ -427,6 +427,37 @@ class TestForwardToControlPlane:
assert exc_info.value.status_code == 502
assert "Failed to connect to control plane" in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_follows_redirects(self) -> None:
"""Test that AsyncClient is created with follow_redirects=True.
The control plane may sit behind a reverse proxy that returns
308 (HTTP→HTTPS). httpx does not follow redirects by default,
so we must explicitly opt in.
"""
mock_response = MagicMock()
mock_response.json.return_value = {"ok": True}
mock_response.raise_for_status = MagicMock()
with (
patch(
"ee.onyx.server.tenants.proxy.generate_data_plane_token"
) as mock_token,
patch("ee.onyx.server.tenants.proxy.httpx.AsyncClient") as mock_client,
patch(
"ee.onyx.server.tenants.proxy.CONTROL_PLANE_API_BASE_URL",
"http://control.example.com",
),
):
mock_token.return_value = "cp_token"
mock_client.return_value.__aenter__.return_value.get = AsyncMock(
return_value=mock_response
)
await forward_to_control_plane("GET", "/test")
mock_client.assert_called_once_with(timeout=30.0, follow_redirects=True)
@pytest.mark.asyncio
async def test_unsupported_method(self) -> None:
"""Test that unsupported HTTP methods raise ValueError."""

View File

@@ -384,6 +384,29 @@ class TestWhitelistBehavior:
verify_email_is_invited("Allowed@Example.Com")
class TestSeatLimitEnforcement:
    """Seat limits block new user creation on self-hosted deployments."""

    def test_adding_user_fails_when_seats_full(self) -> None:
        from onyx.auth.users import enforce_seat_limit

        # EE hook reports "no seats available" → expect HTTP 402.
        seat_result = MagicMock(available=False, error_message="Seat limit reached")
        with patch(
            "onyx.auth.users.fetch_ee_implementation_or_noop",
            return_value=lambda *_a, **_kw: seat_result,
        ):
            with pytest.raises(HTTPException) as exc:
                enforce_seat_limit(MagicMock())
        assert exc.value.status_code == 402

    def test_seat_limit_only_enforced_for_self_hosted(self) -> None:
        from onyx.auth.users import enforce_seat_limit

        # Multi-tenant (cloud) deployments skip seat enforcement entirely.
        with patch("onyx.auth.users.MULTI_TENANT", True):
            enforce_seat_limit(MagicMock())  # should not raise
class TestCaseInsensitiveEmailMatching:
"""Test case-insensitive email matching for existing user checks."""

View File

@@ -2,6 +2,7 @@
import pytest
from onyx.chat.llm_loop import _should_keep_bedrock_tool_definitions
from onyx.chat.llm_loop import construct_message_history
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
@@ -10,6 +11,17 @@ from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import MessageType
from onyx.file_store.models import ChatFileType
from onyx.llm.constants import LlmProviderNames
class _StubConfig:
def __init__(self, model_provider: str) -> None:
self.model_provider = model_provider
class _StubLLM:
def __init__(self, model_provider: str) -> None:
self.config = _StubConfig(model_provider=model_provider)
def create_message(
@@ -568,3 +580,34 @@ class TestConstructMessageHistory:
assert '"contents"' in project_message.message
assert "Project file 0 content" in project_message.message
assert "Project file 1 content" in project_message.message
class TestBedrockToolConfigGuard:
    """Tool definitions are kept only for Bedrock AND only when the history
    actually contains tool calls."""

    def test_bedrock_with_tool_history_keeps_tool_definitions(self) -> None:
        llm = _StubLLM(LlmProviderNames.BEDROCK)
        # History includes an assistant tool call plus its tool response.
        history = [
            create_message("Question", MessageType.USER, 5),
            create_assistant_with_tool_call("tc_1", "search", 5),
            create_tool_response("tc_1", "Tool output", 5),
        ]
        assert _should_keep_bedrock_tool_definitions(llm, history) is True

    def test_bedrock_without_tool_history_does_not_keep_tool_definitions(self) -> None:
        llm = _StubLLM(LlmProviderNames.BEDROCK)
        # Plain user/assistant exchange — no tool calls anywhere.
        history = [
            create_message("Question", MessageType.USER, 5),
            create_message("Answer", MessageType.ASSISTANT, 5),
        ]
        assert _should_keep_bedrock_tool_definitions(llm, history) is False

    def test_non_bedrock_with_tool_history_does_not_keep_tool_definitions(self) -> None:
        # Same tool-bearing history, but a non-Bedrock provider.
        llm = _StubLLM(LlmProviderNames.OPENAI)
        history = [
            create_message("Question", MessageType.USER, 5),
            create_assistant_with_tool_call("tc_1", "search", 5),
            create_tool_response("tc_1", "Tool output", 5),
        ]
        assert _should_keep_bedrock_tool_definitions(llm, history) is False

View File

@@ -0,0 +1,95 @@
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentBase
from onyx.connectors.models import TextSection
def _minimal_doc_kwargs(metadata: dict) -> dict:
    """Return the smallest kwargs set that builds a valid Document,
    carrying the given *metadata* through unchanged."""
    section = TextSection(text="hello", link="http://example.com")
    kwargs: dict = {
        "id": "test-doc",
        "sections": [section],
        "source": DocumentSource.NOT_APPLICABLE,
        "semantic_identifier": "Test Doc",
    }
    kwargs["metadata"] = metadata
    return kwargs
# These tests pin the metadata-coercion behavior of Document/DocumentBase:
# scalar values become str, list elements become str, dicts are stringified.


def test_int_values_coerced_to_str() -> None:
    doc = Document(**_minimal_doc_kwargs({"count": 42}))
    assert doc.metadata == {"count": "42"}


def test_float_values_coerced_to_str() -> None:
    doc = Document(**_minimal_doc_kwargs({"score": 3.14}))
    assert doc.metadata == {"score": "3.14"}


def test_bool_values_coerced_to_str() -> None:
    # Python's str(True) form, i.e. "True" not "true".
    doc = Document(**_minimal_doc_kwargs({"active": True}))
    assert doc.metadata == {"active": "True"}


def test_list_of_ints_coerced_to_list_of_str() -> None:
    doc = Document(**_minimal_doc_kwargs({"ids": [1, 2, 3]}))
    assert doc.metadata == {"ids": ["1", "2", "3"]}


def test_list_of_mixed_types_coerced_to_list_of_str() -> None:
    doc = Document(**_minimal_doc_kwargs({"tags": ["a", 1, True, 2.5]}))
    assert doc.metadata == {"tags": ["a", "1", "True", "2.5"]}


def test_list_of_dicts_coerced_to_list_of_str() -> None:
    # Dict elements are stringified with repr-style quoting.
    raw = {"nested": [{"key": "val"}, {"key2": "val2"}]}
    doc = Document(**_minimal_doc_kwargs(raw))
    assert doc.metadata == {"nested": ["{'key': 'val'}", "{'key2': 'val2'}"]}


def test_dict_value_coerced_to_str() -> None:
    raw = {"info": {"inner_key": "inner_val"}}
    doc = Document(**_minimal_doc_kwargs(raw))
    assert doc.metadata == {"info": "{'inner_key': 'inner_val'}"}


def test_none_value_coerced_to_str() -> None:
    # None is not dropped; it becomes the literal string "None".
    doc = Document(**_minimal_doc_kwargs({"empty": None}))
    assert doc.metadata == {"empty": "None"}


def test_already_valid_str_values_unchanged() -> None:
    doc = Document(**_minimal_doc_kwargs({"key": "value"}))
    assert doc.metadata == {"key": "value"}


def test_already_valid_list_of_str_unchanged() -> None:
    doc = Document(**_minimal_doc_kwargs({"tags": ["a", "b", "c"]}))
    assert doc.metadata == {"tags": ["a", "b", "c"]}


def test_empty_metadata_unchanged() -> None:
    doc = Document(**_minimal_doc_kwargs({}))
    assert doc.metadata == {}


def test_mixed_metadata_values() -> None:
    # All coercion rules applied together in one metadata dict.
    raw = {
        "str_val": "hello",
        "int_val": 99,
        "list_val": [1, "two", 3.0],
        "dict_val": {"nested": True},
    }
    doc = Document(**_minimal_doc_kwargs(raw))
    assert doc.metadata == {
        "str_val": "hello",
        "int_val": "99",
        "list_val": ["1", "two", "3.0"],
        "dict_val": "{'nested': True}",
    }


def test_coercion_works_on_base_class() -> None:
    # DocumentBase lacks the "source" and "id" fields, so drop them.
    kwargs = _minimal_doc_kwargs({"count": 42})
    kwargs.pop("source")
    kwargs.pop("id")
    doc = DocumentBase(**kwargs)
    assert doc.metadata == {"count": "42"}

View File

@@ -0,0 +1,204 @@
"""Tests for Slack channel reference resolution and tag filtering
in handle_regular_answer.py."""
from unittest.mock import MagicMock
from slack_sdk.errors import SlackApiError
from onyx.context.search.models import Tag
from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN
from onyx.onyxbot.slack.handlers.handle_regular_answer import resolve_channel_references
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _mock_client_with_channels(
    channel_map: dict[str, str],
) -> MagicMock:
    """Return a mock WebClient where conversations_info resolves IDs to names."""
    client = MagicMock()

    def _conversations_info(channel: str) -> MagicMock:
        # Known channel id → response whose ["channel"] payload carries the
        # mapped name; unknown id → SlackApiError, mimicking the real API.
        if channel in channel_map:
            resp = MagicMock()
            resp.validate = MagicMock()
            # __getitem__ must be set on the mock's type-level attribute;
            # the lambda ignores the mock instance and serves the payload.
            resp.__getitem__ = lambda _self, key: {
                "channel": {
                    "name": channel_map[channel],
                    "is_im": False,
                    "is_mpim": False,
                }
            }[key]
            return resp
        raise SlackApiError("channel_not_found", response=MagicMock())

    client.conversations_info = _conversations_info
    return client


def _mock_logger() -> MagicMock:
    # Bare mock logger; tests assert on .warning call counts.
    return MagicMock()
# ---------------------------------------------------------------------------
# SLACK_CHANNEL_REF_PATTERN regex tests
# ---------------------------------------------------------------------------
class TestSlackChannelRefPattern:
    """Regex behavior: findall returns (channel_id, optional_name) tuples."""

    def test_matches_bare_channel_id(self) -> None:
        # Bare <#ID> form: name group is empty.
        matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y>")
        assert matches == [("C097NBWMY8Y", "")]

    def test_matches_channel_id_with_name(self) -> None:
        # <#ID|name> form: name group captures the channel name.
        matches = SLACK_CHANNEL_REF_PATTERN.findall("<#C097NBWMY8Y|eng-infra>")
        assert matches == [("C097NBWMY8Y", "eng-infra")]

    def test_matches_multiple_channels(self) -> None:
        msg = "compare <#C111AAA> and <#C222BBB|general>"
        matches = SLACK_CHANNEL_REF_PATTERN.findall(msg)
        assert len(matches) == 2
        assert ("C111AAA", "") in matches
        assert ("C222BBB", "general") in matches

    def test_no_match_on_plain_text(self) -> None:
        matches = SLACK_CHANNEL_REF_PATTERN.findall("no channels here")
        assert matches == []

    def test_no_match_on_user_mention(self) -> None:
        # User mentions (<@U...>) must not be mistaken for channel refs.
        matches = SLACK_CHANNEL_REF_PATTERN.findall("<@U12345>")
        assert matches == []
# ---------------------------------------------------------------------------
# resolve_channel_references tests
# ---------------------------------------------------------------------------
class TestResolveChannelReferences:
    """resolve_channel_references rewrites <#ID> markup to #name and returns
    one Channel tag per distinct resolved channel."""

    def test_resolves_bare_channel_id_via_api(self) -> None:
        # Bare <#ID> requires a conversations_info lookup to get the name.
        client = _mock_client_with_channels({"C097NBWMY8Y": "eng-infra"})
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="summary of <#C097NBWMY8Y> this week",
            client=client,
            logger=logger,
        )
        assert message == "summary of #eng-infra this week"
        assert len(tags) == 1
        assert tags[0] == Tag(tag_key="Channel", tag_value="eng-infra")

    def test_uses_name_from_pipe_format_without_api_call(self) -> None:
        client = MagicMock()
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="check <#C097NBWMY8Y|eng-infra> for updates",
            client=client,
            logger=logger,
        )
        assert message == "check #eng-infra for updates"
        assert tags == [Tag(tag_key="Channel", tag_value="eng-infra")]
        # Should NOT have called the API since name was in the markup
        client.conversations_info.assert_not_called()

    def test_multiple_channels(self) -> None:
        client = _mock_client_with_channels(
            {
                "C111AAA": "eng-infra",
                "C222BBB": "eng-general",
            }
        )
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="compare <#C111AAA> and <#C222BBB>",
            client=client,
            logger=logger,
        )
        assert "#eng-infra" in message
        assert "#eng-general" in message
        # No raw markup left behind once everything resolves.
        assert "<#" not in message
        assert len(tags) == 2
        tag_values = {t.tag_value for t in tags}
        assert tag_values == {"eng-infra", "eng-general"}

    def test_no_channel_references_returns_unchanged(self) -> None:
        client = MagicMock()
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="just a normal message with no channels",
            client=client,
            logger=logger,
        )
        assert message == "just a normal message with no channels"
        assert tags == []

    def test_api_failure_skips_channel_gracefully(self) -> None:
        # Client that fails for all channel lookups
        client = _mock_client_with_channels({})
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="check <#CBADID123>",
            client=client,
            logger=logger,
        )
        # Message should remain unchanged for the failed channel
        assert "<#CBADID123>" in message
        assert tags == []
        logger.warning.assert_called_once()

    def test_partial_failure_resolves_what_it_can(self) -> None:
        # Only one of two channels resolves
        client = _mock_client_with_channels({"C111AAA": "eng-infra"})
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="compare <#C111AAA> and <#CBADID123>",
            client=client,
            logger=logger,
        )
        assert "#eng-infra" in message
        assert "<#CBADID123>" in message  # failed one stays raw
        assert len(tags) == 1
        assert tags[0].tag_value == "eng-infra"

    def test_duplicate_channel_produces_single_tag(self) -> None:
        # Same channel mentioned twice → both occurrences rewritten,
        # but only one tag emitted.
        client = _mock_client_with_channels({"C111AAA": "eng-infra"})
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="summarize <#C111AAA> and compare with <#C111AAA>",
            client=client,
            logger=logger,
        )
        assert message == "summarize #eng-infra and compare with #eng-infra"
        assert len(tags) == 1
        assert tags[0].tag_value == "eng-infra"

    def test_mixed_pipe_and_bare_formats(self) -> None:
        # Pipe format resolves locally; bare format goes through the API.
        client = _mock_client_with_channels({"C222BBB": "random"})
        logger = _mock_logger()
        message, tags = resolve_channel_references(
            message="see <#C111AAA|eng-infra> and <#C222BBB>",
            client=client,
            logger=logger,
        )
        assert "#eng-infra" in message
        assert "#random" in message
        assert len(tags) == 2

View File

@@ -0,0 +1,205 @@
from onyx.onyxbot.slack.formatting import _convert_slack_links_to_markdown
from onyx.onyxbot.slack.formatting import _normalize_link_destinations
from onyx.onyxbot.slack.formatting import _sanitize_html
from onyx.onyxbot.slack.formatting import _transform_outside_code_blocks
from onyx.onyxbot.slack.formatting import format_slack_message
from onyx.onyxbot.slack.utils import remove_slack_text_interactions
from onyx.utils.text_processing import decode_escapes
def test_normalize_citation_link_wraps_url_with_parentheses() -> None:
message = (
"See [[1]](https://example.com/Access%20ID%20Card(s)%20Guide.pdf) for details."
)
normalized = _normalize_link_destinations(message)
assert (
"See [[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>) for details."
== normalized
)
def test_normalize_citation_link_keeps_existing_angle_brackets() -> None:
message = "[[1]](<https://example.com/Access%20ID%20Card(s)%20Guide.pdf>)"
normalized = _normalize_link_destinations(message)
assert message == normalized
def test_normalize_citation_link_handles_multiple_links() -> None:
message = (
"[[1]](https://example.com/(USA)%20Guide.pdf) "
"[[2]](https://example.com/Plan(s)%20Overview.pdf)"
)
normalized = _normalize_link_destinations(message)
assert "[[1]](<https://example.com/(USA)%20Guide.pdf>)" in normalized
assert "[[2]](<https://example.com/Plan(s)%20Overview.pdf>)" in normalized
def test_format_slack_message_keeps_parenthesized_citation_links_intact() -> None:
message = (
"Download [[1]](https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf)"
)
formatted = format_slack_message(message)
rendered = decode_escapes(remove_slack_text_interactions(formatted))
assert (
"<https://example.com/(USA)%20Access%20ID%20Card(s)%20Guide.pdf|[1]>"
in rendered
)
assert "|[1]>%20Access%20ID%20Card" not in rendered
def test_slack_style_links_converted_to_clickable_links() -> None:
message = "Visit <https://example.com/page|Example Page> for details."
formatted = format_slack_message(message)
assert "<https://example.com/page|Example Page>" in formatted
assert "&lt;" not in formatted
def test_slack_style_links_preserved_inside_code_blocks() -> None:
message = "```\n<https://example.com|click>\n```"
converted = _convert_slack_links_to_markdown(message)
assert "<https://example.com|click>" in converted
def test_html_tags_stripped_outside_code_blocks() -> None:
message = "Hello<br/>world ```<div>code</div>``` after"
sanitized = _transform_outside_code_blocks(message, _sanitize_html)
assert "<br" not in sanitized
assert "<div>code</div>" in sanitized
def test_format_slack_message_block_spacing() -> None:
message = "Paragraph one.\n\nParagraph two."
formatted = format_slack_message(message)
assert "Paragraph one.\n\nParagraph two." == formatted
def test_format_slack_message_code_block_no_trailing_blank_line() -> None:
message = "```python\nprint('hi')\n```"
formatted = format_slack_message(message)
assert formatted.endswith("print('hi')\n```")
def test_format_slack_message_ampersand_not_double_escaped() -> None:
message = 'She said "hello" & goodbye.'
formatted = format_slack_message(message)
assert "&amp;" in formatted
assert "&quot;" not in formatted
# -- Table rendering tests --
def test_table_renders_as_vertical_cards() -> None:
message = (
"| Feature | Status | Owner |\n"
"|---------|--------|-------|\n"
"| Auth | Done | Alice |\n"
"| Search | In Progress | Bob |\n"
)
formatted = format_slack_message(message)
assert "*Auth*\n • Status: Done\n • Owner: Alice" in formatted
assert "*Search*\n • Status: In Progress\n • Owner: Bob" in formatted
# Cards separated by blank line
assert "Owner: Alice\n\n*Search*" in formatted
# No raw pipe-and-dash table syntax
assert "---|" not in formatted
def test_table_single_column() -> None:
message = "| Name |\n|------|\n| Alice |\n| Bob |\n"
formatted = format_slack_message(message)
assert "*Alice*" in formatted
assert "*Bob*" in formatted
def test_table_embedded_in_text() -> None:
message = (
"Here are the results:\n\n"
"| Item | Count |\n"
"|------|-------|\n"
"| Apples | 5 |\n"
"\n"
"That's all."
)
formatted = format_slack_message(message)
assert "Here are the results:" in formatted
assert "*Apples*\n • Count: 5" in formatted
assert "That's all." in formatted
def test_table_with_formatted_cells() -> None:
message = (
"| Name | Link |\n"
"|------|------|\n"
"| **Alice** | [profile](https://example.com) |\n"
)
formatted = format_slack_message(message)
# Bold cell should not double-wrap: *Alice* not **Alice**
assert "*Alice*" in formatted
assert "**Alice**" not in formatted
assert "<https://example.com|profile>" in formatted
def test_table_with_alignment_specifiers() -> None:
    """Alignment colons in the separator row (:---, :--:, ---:) are recognized."""
    result = format_slack_message(
        "| Left | Center | Right |\n|:-----|:------:|------:|\n| a | b | c |\n"
    )
    assert "*a*\n • Center: b\n • Right: c" in result
def test_two_tables_in_same_message_use_independent_headers() -> None:
    """Each table in a message is rendered against its own header row."""
    first_table = "| A | B |\n|---|---|\n| 1 | 2 |\n"
    second_table = "| X | Y | Z |\n|---|---|---|\n| p | q | r |\n"
    result = format_slack_message(first_table + "\n" + second_table)
    # Row of the two-column table uses header B only.
    assert "*1*\n • B: 2" in result
    # Row of the three-column table uses headers Y and Z.
    assert "*p*\n • Y: q\n • Z: r" in result
def test_table_empty_first_column_no_bare_asterisks() -> None:
    """An empty title cell must not render as a pair of bare asterisks (`**`)."""
    result = format_slack_message(
        "| Name | Status |\n|------|--------|\n| | Done |\n"
    )
    assert "**" not in result
    assert " • Status: Done" in result

# --- New file (47 lines): free-trial bulk-invite limit tests ---
"""Test bulk invite limit for free trial tenants."""
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from onyx.server.manage.users import bulk_invite_users
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
def test_trial_tenant_cannot_exceed_invite_limit(*_mocks: object) -> None:
    """Trial tenants cannot invite more users than the configured limit.

    With the invite cap patched to 5 and no existing or already-invited
    users, requesting 6 invites must raise an HTTP 403 whose detail
    mentions the invite limit. ``*_mocks`` absorbs the MagicMocks that the
    ``@patch`` decorators inject (patches apply bottom-up).
    """
    # One more email than the patched NUM_FREE_TRIAL_USER_INVITES allows.
    emails = [f"user{i}@example.com" for i in range(6)]
    with pytest.raises(HTTPException) as exc_info:
        bulk_invite_users(emails=emails)
    assert exc_info.value.status_code == 403
    assert "invite limit" in exc_info.value.detail.lower()
@patch("onyx.server.manage.users.MULTI_TENANT", True)
@patch("onyx.server.manage.users.DEV_MODE", True)
@patch("onyx.server.manage.users.ENABLE_EMAIL_INVITES", False)
@patch("onyx.server.manage.users.is_tenant_on_trial_fn", return_value=True)
@patch("onyx.server.manage.users.get_current_tenant_id", return_value="test_tenant")
@patch("onyx.server.manage.users.get_invited_users", return_value=[])
@patch("onyx.server.manage.users.get_all_users", return_value=[])
@patch("onyx.server.manage.users.write_invited_users", return_value=3)
@patch("onyx.server.manage.users.NUM_FREE_TRIAL_USER_INVITES", 5)
@patch(
    "onyx.server.manage.users.fetch_ee_implementation_or_noop",
    return_value=lambda *_args: None,
)
def test_trial_tenant_can_invite_within_limit(*_mocks: object) -> None:
    """Trial tenants can invite users when under the limit.

    Three invites against a patched cap of 5 must succeed; the return value
    is whatever ``write_invited_users`` reports (patched to 3 here).
    ``DEV_MODE`` / ``ENABLE_EMAIL_INVITES`` are patched so no real emails
    are sent, and the EE hook is replaced with a no-op lambda. ``*_mocks``
    absorbs the MagicMocks injected by the ``@patch`` decorators.
    """
    emails = ["user1@example.com", "user2@example.com", "user3@example.com"]
    result = bulk_invite_users(emails=emails)
    assert result == 3

# NOTE: diff view truncated — additional changed files were not captured here.