Compare commits

...

277 Commits

Author SHA1 Message Date
rohoswagger
21aa89badc feat(craft): stop generation 2026-02-26 15:03:34 -08:00
Justin Tahara
be13aa1310 chore(llm): add Vertex AI nightly tests (#8813) 2026-02-26 22:38:05 +00:00
Nikolas Garza
45d38c4906 feat(metrics): add per-tenant Prometheus metrics (#8822) 2026-02-26 22:37:35 +00:00
Danelegend
8aab518532 fix: Admin page modal centering excludes sidebar (#8823) 2026-02-26 22:27:58 +00:00
Nikolas Garza
da6ce10e86 test(scim): add integration tests for SCIM token management (#8819) 2026-02-26 22:22:16 +00:00
Nikolas Garza
aaf8253520 fix(ee): show subscription text on expired access page for cloud users (#8804) 2026-02-26 22:15:44 +00:00
Jamison Lahman
7c7f81b164 chore(fe): add feature agent to editor page (#8814) 2026-02-26 22:12:20 +00:00
Justin Tahara
2d4a3c72e9 chore(llm): Nightly Bedrock Tests (#8812) 2026-02-26 22:10:31 +00:00
acaprau
7c51712018 fix(db ssl): Remove import-time side effect of creating SSL context if IAM enabled (#8811) 2026-02-26 21:37:13 +00:00
Evan Lohn
aa5614695d feat: sharepoint tenant avoid org get (#8802) 2026-02-26 21:28:56 +00:00
Jamison Lahman
8d7255d3c4 chore(fe): support featured agents w/o being public (#8809) 2026-02-26 21:16:23 +00:00
Evan Lohn
d403498f48 feat: context injection unification (#8687) 2026-02-26 21:11:19 +00:00
dependabot[bot]
9ef3095c17 chore(deps): bump pypdf from 6.6.2 to 6.7.3 (#8808)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 20:42:01 +00:00
Justin Tahara
a39e93a0cb chore(llm): LLM Integration Tests Generic Setup (#8803) 2026-02-26 19:59:19 +00:00
Jamison Lahman
46d73cdfee fix(docker): prefer user runtime docker socket (#8799) 2026-02-26 10:55:44 -08:00
Raunak Bhagat
1e04ce78e0 feat(opal): add Hoverable compound component (#8798) 2026-02-26 17:08:53 +00:00
Jamison Lahman
f9b81c1725 feat(agents): share agents with labels or featured (#8742)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-26 16:21:05 +00:00
SubashMohan
3bc1b89fee fix(memory): timeline UI alignment issues and highlighting issue (#8753) 2026-02-26 08:46:43 +00:00
Nikolas Garza
01743d99d4 fix(billing): handle manual license users without Stripe subscription (#8787) 2026-02-26 08:07:14 +00:00
acaprau
092c1db7e0 chore(opensearch): Allow programatic schema updates (#8794) 2026-02-26 07:49:56 +00:00
acaprau
40ac0d859a chore(opensearch): OpenSearchClient implements context manager, also closes on del (#8781)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-26 07:38:16 +00:00
SubashMohan
929e58361f fix: resolve OAuth token manager using masked secrets (#8673) 2026-02-26 07:06:51 +00:00
SubashMohan
6d472df7c5 fix(timeline): Fix double-collapse and improve tool status messages (#8751) 2026-02-26 07:05:48 +00:00
acaprau
cfa7acd904 chore(opensearch): MT cloud should verify index on document index init, and do cluster setup once at start (#8776) 2026-02-26 06:42:06 +00:00
Danelegend
5c5a6f943b chore: deprecate llm provider fields (#8783) 2026-02-26 05:27:28 +00:00
Evan Lohn
d04128b8b1 fix: sharepoint unquote (#8786) 2026-02-26 03:38:46 +00:00
Nikolas Garza
bbebdf8f78 feat(scim): Entra ID enterprise extension support [3/3] (#8747) 2026-02-26 02:32:04 +00:00
Nikolas Garza
161279a2d5 feat(scim): field round-tripping for IdP attribute preservation [2/3] (#8746) 2026-02-26 02:01:13 +00:00
Jamison Lahman
e5ebb45a20 chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 01:57:25 +00:00
Evan Lohn
320ba9cb1b refactor: filter by persona id during search (#8683) 2026-02-26 01:51:00 +00:00
Nikolas Garza
f2e8cb3114 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 01:47:44 +00:00
Nikolas Garza
43054a28ec feat(scim): SCIM 2.0 protocol compliance fixes [1/3] (#8745) 2026-02-26 01:33:08 +00:00
Justin Tahara
dc74aa7b1f chore(llm): Add OpenAI Integration Tests (#8711) 2026-02-26 00:58:28 +00:00
Raunak Bhagat
bd773191c2 feat(opal): add more icons (#8778) 2026-02-26 00:38:54 +00:00
Evan Lohn
66dbff41e6 refactor: extend sync mechanism to persona files (#8682) 2026-02-26 00:32:30 +00:00
roshan
1dcffe38bc fix: Invoke generate_agents_md.py in K8s to populate knowledge sources (#8768)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-26 00:04:10 +00:00
Evan Lohn
c35e883564 refactor: persona id in vector db by indexing (#8681) 2026-02-25 22:51:57 +00:00
Jamison Lahman
fefcd58481 chore(devtools): ods web to run web/package.json scripts (#8766)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-25 14:05:29 -08:00
Jamison Lahman
bdc89d9e3f chore(fe): opal button implements responsiveHideText (#8764) 2026-02-25 21:05:08 +00:00
Evan Lohn
f4d777b80d refactor: persona id in vector db (#8680) 2026-02-25 20:42:38 +00:00
acaprau
da4d57b5e3 chore(devtools): Make AGENTS.md reference contributing_guides/best_practices.md (#8760)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-25 20:27:12 +00:00
Evan Lohn
dcdcd067bd fix: drive 403 rate limits (#8762) 2026-02-25 20:12:36 +00:00
Evan Lohn
8b15a29723 feat: slab connector validation (#8758) 2026-02-25 20:00:42 +00:00
Danelegend
763853674f feat(ci): Add preview modal for data types (#8752) 2026-02-25 19:52:19 +00:00
Jamison Lahman
429b6f3465 fix(fe): modal aligning with detached element after navigation (#8676) 2026-02-25 19:33:07 +00:00
Danelegend
37d5be1b40 feat: python tool not added when no code interpretter server (#8749) 2026-02-25 19:17:42 +00:00
Jamison Lahman
8ab99dbb06 chore(fe): add hover style to AgentCard (#8689) 2026-02-25 19:08:00 +00:00
Jamison Lahman
52799e9c7a fix(fe): middle align human chat message text (#8756) 2026-02-25 19:00:01 +00:00
Jamison Lahman
aef009cc97 chore(fe): foldable buttons display text via tooltip when disabled (#8735) 2026-02-25 18:39:53 +00:00
Evan Lohn
18d1ea1770 fix: sharepoint driveItem perm sync (#8698) 2026-02-25 18:29:26 +00:00
Bo-Onyx
f336ad00f4 fix(user invitation): failed but no warning. (#8731)
Co-authored-by: Bo Yang <boyang@Bos-MacBook-Pro.local>
2026-02-25 17:23:39 +00:00
SubashMohan
0558e687d9 fix: persist onboarding dismissal in localStorage with user-specific keys (#8674) 2026-02-25 06:22:17 +00:00
roshan
784a99e24a updated demo data (#8748) 2026-02-24 19:59:46 -08:00
Justin Tahara
da1f5a11f4 chore(cherry-pick): Alerting on Failed Cherry-Picks (#8744) 2026-02-25 02:09:19 +00:00
Justin Tahara
5633805890 chore(devtools): Upgrade ods from 0.6.0 -> 0.6.1 (#8743) 2026-02-25 02:01:20 +00:00
Danelegend
0817b45ae1 feat: Get code interpreter config route (#8739) 2026-02-25 01:49:30 +00:00
Justin Tahara
af0e4bdebc fix(slack): Cleaning up URL Links (#8569) 2026-02-25 01:42:12 +00:00
Justin Tahara
4cd2320732 chore(cherry-pick): Add Github Label for PRs (#8736) 2026-02-25 00:46:12 +00:00
Danelegend
90a361f0e1 feat: code interpreter routes (#8670) 2026-02-24 16:27:10 -08:00
Justin Tahara
194efde97b chore(llm): Scaffolding for Nightly LLM Tests (#8704) 2026-02-25 00:06:24 +00:00
Danelegend
d922a42262 feat: code interpreter docker default deploy (#8672) 2026-02-24 23:51:19 +00:00
Danelegend
f00c3a486e feat: default deploy code interpreter - helm & bump version 0.3.0 (#8685) 2026-02-24 23:40:46 +00:00
Danelegend
192080c9e4 feat: default deploy code interpreter - restart_script (#8686) 2026-02-24 23:40:36 +00:00
Justin Tahara
c5787dc073 chore(image): Update test to be for Dall E 3 instead of 2 (#8732) 2026-02-24 22:53:31 +00:00
Justin Tahara
d424d6462c fix(sanitization): Centralizing DB Filters (#8730) 2026-02-24 22:28:25 +00:00
Jamison Lahman
ecea86deb6 chore(fe): only left input items flex (#8734) 2026-02-24 22:25:04 +00:00
Jamison Lahman
a5c1f50a8a chore(fe): update disabled "select" button color (#8733) 2026-02-24 22:03:52 +00:00
roshan
4a04cfd486 feat(craft): make output/ files downloadable from Artifacts tab (#8721)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-24 21:49:59 +00:00
Nikolas Garza
f22e9628db feat(scim): add additional entra id fields to ScimUserMapping (#8728) 2026-02-24 20:23:21 +00:00
Jamison Lahman
255ba10af6 chore(chat): consolidate chat message whitespacing style (#8696) 2026-02-24 20:02:28 +00:00
Justin Tahara
563202a080 feat(image): support Azure historical image context edits (#8726) 2026-02-24 19:21:30 +00:00
Evan Lohn
1062dc0743 fix: graph client env (#8727) 2026-02-24 18:46:49 +00:00
Justin Tahara
0826348568 feat(image): support OpenAI historical image context edits (#8725) 2026-02-24 18:45:56 +00:00
Justin Tahara
375079136d chore(cherry-pick): Assign merged-by user on beta cherry-pick PR (#8723) 2026-02-24 18:27:48 +00:00
Jamison Lahman
82aad5e253 fix(welcome): add back agent description (#8716) 2026-02-24 17:27:23 +00:00
Jamison Lahman
beb1c49c69 fix(fe): inline code-blocks respect header font-size (#8691) 2026-02-24 17:03:21 +00:00
Jamison Lahman
c4556515be fix(fe): rm non-admin-confirmation max-width (#8693) 2026-02-24 17:03:05 +00:00
SubashMohan
a4387f230b fix(popover): prevent viewport overflow with dynamic max-height and collision padding (#8675) 2026-02-24 10:27:36 +00:00
Evan Lohn
d91e452658 chore: version bumps for client libs (#8720) 2026-02-24 08:13:37 +00:00
Danelegend
dd274f8667 feat: code interpreter supports streaming (#8663) 2026-02-24 06:07:36 +00:00
roshan
2c82f0da16 fix(craft): delete S3 snapshot files when deleting a craft (#8718)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 05:58:29 +00:00
Raunak Bhagat
26101636f2 refactor: add new ContentAction component (#8695) 2026-02-24 05:13:18 +00:00
roshan
5e2c0c6cf4 fix(nrf): hide search toggle when search mode is unavailable (#8717)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:43:19 -08:00
roshan
33b64db498 fix(extensions): fix base url for chrome extension to (#8714) 2026-02-23 20:18:05 -08:00
roshan
b925cc1a56 feat(chrome-extension): add tab reading to side panel (#8571)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-24 01:17:57 +00:00
Danelegend
bac4b7c945 fix: preview markdown formatting (#8667) 2026-02-24 01:13:52 +00:00
Evan Lohn
6f6ef1e657 chore: coerce doc metadata (#8703) 2026-02-24 01:12:11 +00:00
Danelegend
885c69f460 feat: Improve csv preview modal (#8702) 2026-02-24 01:00:20 +00:00
Danelegend
4b837303ff feat(code-interpreter): Seed code interpreter server row (#8701) 2026-02-24 00:59:49 +00:00
Justin Tahara
d856a9befb fix(projects): Guardrails for Project User Files (#8644) 2026-02-24 00:21:57 +00:00
Justin Tahara
adade353c5 fix(api): Improving the API handling of threads (#8573) 2026-02-24 00:04:21 +00:00
Nikolas Garza
3cb6ec2f85 fix: patch prometheus metrics in daily test fixture (#8699) 2026-02-24 00:02:56 +00:00
Wenxi
691eebf00a fix: remove user info requirement for craft onboarding modal (#8697) 2026-02-23 23:52:17 +00:00
Danelegend
905b6633e6 chore: preview modal (#8665) 2026-02-23 23:40:55 +00:00
Justin Tahara
fd088196ff fix(search): Improve Speed (#8430) 2026-02-23 22:45:18 +00:00
Jamison Lahman
cafbf5b8be chore(playwright): warn user if setup takes longer than usual (#8690) 2026-02-23 22:23:58 +00:00
roshan
1235181559 fix(ui): Clean up NRF settings button styling (#8678)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 21:25:43 +00:00
Justin Tahara
caa2e45632 fix(db): Multitenant Schema migration update (#8679) 2026-02-23 21:25:26 +00:00
Justin Tahara
9c62e03120 chore(ods): Automated Cherry-pick backport (#8642) 2026-02-23 21:15:09 +00:00
Nikolas Garza
0937305064 feat(scim): Okta compatibility + provider abstraction (#8568) 2026-02-23 21:09:18 +00:00
Wenxi
e4c06570e3 fix: domain rules for signup on cloud (#8671) 2026-02-23 20:27:37 +00:00
roshan
78fc7c86d7 fix: Handle unauthenticated state gracefully on NRF page (#8491)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 19:26:38 +00:00
Raunak Bhagat
84d3aea847 refactor: migrate Web Search page to SettingsLayouts + Content (#8662) 2026-02-23 13:38:37 +00:00
Danelegend
00a404d3cd feat: Add code interpreter server db model (#8669) 2026-02-23 05:09:59 +00:00
Wenxi
787cf90d96 chore: set trial api usage to 0 and show ui (#8664) 2026-02-23 01:41:23 +00:00
SubashMohan
15fe47adc5 fix: improve connector status display, agent search tool detection, and prompt formatting (#8579) 2026-02-22 07:22:06 +00:00
SubashMohan
29958f1a52 perf(open-url): parallelize URL fetching with split connect/read timeouts (#8580) 2026-02-22 06:45:26 +00:00
SubashMohan
ac7f9838bc feat: add mixed content handler for chat and image generation packets (#8494) 2026-02-22 06:12:48 +00:00
Raunak Bhagat
d0fa4b3319 fix: limit connectors section to 3 items in Chat Preferences (#8661) 2026-02-22 01:01:42 +00:00
Raunak Bhagat
3fb4fb422e feat: refreshed admin Chat Preferences page (#8488)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 04:52:04 +00:00
acaprau
ba5da22ea1 chore(documentation): Add comment in contributing_guides/best_practices.md about accountability in TODOs (#8659) 2026-02-21 04:47:48 +00:00
Evan Lohn
9909049047 fix: improve eager loading behavior (#8576) 2026-02-21 04:02:34 +00:00
Nikolas Garza
c516aa3e3c fix(search): eliminate DB pool exhaustion from SearchTool parallel queries (#8651) 2026-02-21 03:56:00 +00:00
Evan Lohn
5cc6220417 feat: checkpointing mid drive (#8606) 2026-02-21 03:47:28 +00:00
acaprau
15da1e0a88 feat(opensearch): Add OpenSearch in docker compose (#8611) 2026-02-21 02:57:31 +00:00
Jamison Lahman
e9ff00890b chore(fe): fix Pinned icon hover background when card is hovered (#8656) 2026-02-21 02:37:18 +00:00
Wenxi
67747a9d93 fix: unify provider retrieval into swr hook (#8609) 2026-02-21 02:15:28 +00:00
Jamison Lahman
edfc51b439 fix(fe): Truncated tooltip is disabled when not needed to be shown (#8629) 2026-02-21 02:07:18 +00:00
Evan Lohn
ac4fba947e feat: sharepoint scalability5 (#8631) 2026-02-21 01:56:11 +00:00
Nikolas Garza
c142b2db02 chore: special sauce (#8646) 2026-02-21 01:30:36 +00:00
Jamison Lahman
fb7e7e4395 fix(fe): keep focus on input on empty (#8627) 2026-02-21 00:13:26 +00:00
Justin Tahara
113f23398e fix(celery): Guardrail for User File Processing (#8633) 2026-02-20 22:43:01 +00:00
Danelegend
5a8716026a feat: Python tool call replay packets (#8649) 2026-02-20 22:28:06 +00:00
Jamison Lahman
3389140bfd chore(deps): @radix-ui/react-tooltip: v1.1.3->v1.2.8 (#8647) 2026-02-20 22:27:49 +00:00
Justin Tahara
13109e7b81 chore(llm): Removal of Retired Models + Cleanup (#8645) 2026-02-20 21:47:53 +00:00
Jessica Singh
56ad457168 chore(auth): tests cleanup (#8559) 2026-02-20 20:51:29 +00:00
Nikolas Garza
a81aea2afc feat(scim): add scim_username column to ScimUserMapping (#8635) 2026-02-20 12:44:16 -08:00
Jamison Lahman
7cb5c9c4a6 chore(fe): replace BlinkingDot with a BlinkingBar (#8640) 2026-02-20 20:02:19 +00:00
Evan Lohn
3520c58a22 feat: sharepoint scalability4 (#8551) 2026-02-20 19:35:36 +00:00
Nikolas Garza
bd9d1bfa27 chore(helm): update code-interpreter chart repo URL to python-sandbox (#8625) 2026-02-20 11:39:08 -08:00
Evan Lohn
14416cc3db fix: search tool enabled when nothing selected (#8637) 2026-02-20 19:22:28 +00:00
Jamison Lahman
d7fce14d26 chore(devtools): upgrade ods: 0.5.7->0.6.0 (#8628) 2026-02-20 10:18:27 -08:00
Wenxi
39a8d8ed05 fix: boolean form field click on text, open url tool checkbox in default assistant, and simple tooltip rendering (#8480) 2026-02-20 17:38:29 +00:00
Jamison Lahman
82f735a434 chore(fe): rm settings page top padding (#8621) 2026-02-20 17:37:19 +00:00
Jamison Lahman
aadb58518b fix(fe): popover width can fit trigger element (#8624) 2026-02-20 07:29:21 +00:00
Jamison Lahman
0755499e0f chore(fe): whitelabel logo layout followup (#8626) 2026-02-19 23:27:37 -08:00
Danelegend
27aaf977a2 feat(code interpreter): improve produced code artifacts (#8507)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-20 07:10:09 +00:00
Evan Lohn
9f707f195e feat: sharepoint scalability3 (#8537) 2026-02-20 05:34:50 +00:00
Justin Tahara
3e35570f70 feat(vertex ai): Image Gen Historical Context (#8603) 2026-02-20 05:31:41 +00:00
Jamison Lahman
53b1bf3b2c chore(fe): fix nested button hydration error (#8599) 2026-02-20 05:14:24 +00:00
Jamison Lahman
5a3fa6b648 chore(fe): fix whitelabel logo moving on sidebar close (#8577) 2026-02-20 04:30:01 +00:00
Evan Lohn
fc6a37850b feat: delta sync sharepoint (#8532)
Co-authored-by: CE11-Kishan <CE11-Kishan@users.noreply.github.com>
2026-02-20 03:26:54 +00:00
Raunak Bhagat
aa6fec3d58 Fix/pat UI (#8617) 2026-02-19 19:23:26 -08:00
Raunak Bhagat
efa6005e36 feat(opal): Add widthVariant to Interactive.Container (#8610) 2026-02-20 03:14:30 +00:00
Nikolas Garza
921bfc72f4 fix(helm): route /openapi.json to api_server in nginx config (#8612) 2026-02-19 19:05:14 -08:00
Justin Tahara
812603152d feat(web): FE Changes for Brave Web Search 3/3 (#8597) 2026-02-20 02:43:30 +00:00
Raunak Bhagat
6779d8fbd7 feat(opal): Add Content layout component (#8534)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 02:35:27 +00:00
Justin Tahara
2c9826e4a9 feat(web): Logical Hardening for Brave Web Search 2/3 (#8595) 2026-02-20 02:09:49 +00:00
Justin Tahara
5b54687077 feat(web): Initial Framework for Brave Web Search 1/3 (#8594) 2026-02-20 01:38:19 +00:00
Raunak Bhagat
0f7e2ee674 feat(opal): Add Tag component and resync colors with Figma (#8533)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 01:14:52 +00:00
Danelegend
ea466648d9 feat: file preview from llm (#8604) 2026-02-20 00:47:03 +00:00
Evan Lohn
a402911ee6 feat: sharepoint scalability 1 (#8531) 2026-02-20 00:41:48 +00:00
Wenxi
7ae9ba807d fix(manage-users): exclude slack users from /users list (#8602) 2026-02-19 23:41:25 +00:00
Nikolas Garza
1f79223c42 feat(ods): add --continue flag and cp alias to cherry-pick (#8601) 2026-02-19 22:31:06 +00:00
Danelegend
c0c2247d5a feat: Modal Header no icon (#8596) 2026-02-19 21:59:18 +00:00
Danelegend
2989ceda41 feat: add file formatting reminder (#8524) 2026-02-19 21:51:06 +00:00
Wenxi
c825f5eca6 feat: whitelist invite setting, allow users to register and invite, new accoount blocked page (#8527) 2026-02-19 20:16:50 +00:00
Jamison Lahman
a8965def79 chore(playwright): fix chat scrolling non-determinism (#8584) 2026-02-19 19:26:09 +00:00
Jamison Lahman
59e1ad51ba chore(fe): fix drop-down overflow in API Key modal (#8574) 2026-02-19 19:15:10 +00:00
Jamison Lahman
0e70a8f826 chore(fe): remove close button from image gen tooltip (#8585) 2026-02-19 18:31:02 +00:00
SubashMohan
0891737dfd fix: update SourceTag component to use variant prop for sizing (#8582) 2026-02-19 17:27:04 +00:00
victoria reese
5a20112670 fix: define fallback only on custom metrics (#8566) 2026-02-19 09:19:44 -08:00
SubashMohan
584f2e2638 fix(ui): Fix admin UI layout, sticky headers, LLM filtering, and project view issues (#8496) 2026-02-19 14:02:22 +00:00
Yuhong Sun
aa24b16ec1 chore: Opensearch tuning (#8518) 2026-02-19 05:43:45 +00:00
acaprau
50aa9d7df6 feat(opensearch): Display percentage progress in the migration page (#8575) 2026-02-19 05:33:29 +00:00
acaprau
bfda586054 chore(opensearch): Make the migration visit from 4 independent Vespa slices concurrently (#8570) 2026-02-19 04:26:59 +00:00
roshan
e04392fbb1 feat(craft): shareable webapp URLs with two-tier access control (#8510)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-19 04:04:21 +00:00
Wenxi
e46c6c5175 fix: init all celery tasks for autodiscovery (#8539) 2026-02-19 03:45:40 +00:00
Nikolas Garza
f59792b4ac feat(metrics): add SQLAlchemy connection pool Prometheus metrics (#8543) 2026-02-19 02:32:42 +00:00
Justin Tahara
973b9456e9 fix(vertex ai): Sort Model Names (#8572) 2026-02-19 02:24:47 +00:00
Justin Tahara
aa8d126513 fix(ollama): Ollama Chat fixes (#8522) 2026-02-19 01:37:07 +00:00
Jamison Lahman
a6da5add49 chore(fe): header text content wraps when responsive (#8565) 2026-02-19 01:25:12 +00:00
Jamison Lahman
3356f90437 chore(fe): Button handles empty string as text, use in header (#8563) 2026-02-19 01:24:10 +00:00
Danelegend
27c254ecf9 fix: /llm/provider route returns all providers (#8545) 2026-02-19 00:13:27 +00:00
Jamison Lahman
09678b3c8e chore(fe): whitelabeling header nits (#8561) 2026-02-18 23:39:54 +00:00
Justin Tahara
ecdb962e24 chore(llm): Cleaning up arg parsing (#8555) 2026-02-18 23:06:19 +00:00
Justin Tahara
63b9b91565 chore(llm): Extract _close_reasoning (#8550) 2026-02-18 22:59:07 +00:00
Justin Tahara
14770e6e90 chore(llm): Extract citation flush helper (#8554) 2026-02-18 22:31:51 +00:00
Justin Tahara
14807d986a chore(llm): Using bool for has_reasoned (#8549) 2026-02-18 22:18:09 +00:00
Justin Tahara
290eb98020 chore(llm): Extract _make_placement (#8552) 2026-02-18 21:58:58 +00:00
Nikolas Garza
fe6fa3d034 feat(scim): add SCIM Group CRUD endpoints (#8456) 2026-02-18 21:52:56 +00:00
Nikolas Garza
2a60a02e0e feat(scim): add admin SCIM token management API (#8538) 2026-02-18 21:41:09 +00:00
Justin Tahara
3bcd666e90 chore(llm): Cleaning up _extract_tool_call_kickoffs (#8548) 2026-02-18 21:39:36 +00:00
Jamison Lahman
684013732c chore(fe): update human message size (#8547) 2026-02-18 21:27:02 +00:00
Nikolas Garza
367dcb8f8b feat(scim): add SCIM User CRUD endpoints (#8455) 2026-02-18 21:26:52 +00:00
Justin Tahara
59dfed0bc8 chore(llm): Cleaning up Docstring (#8546) 2026-02-18 21:25:32 +00:00
Jamison Lahman
7a719b54bb chore(playwright): worker-aware users for test isolation (#8544) 2026-02-18 21:13:07 +00:00
Jamison Lahman
25ef5ff010 chore(fe): update SourceTag tag size (#8540) 2026-02-18 20:49:11 +00:00
Danelegend
53f9f042a1 fix: model config not populating flow during sync (#8542) 2026-02-18 20:08:59 +00:00
Jamison Lahman
3469f0c979 chore(gha): rm nightly license scan workflow (#8541) 2026-02-18 19:55:23 +00:00
Jamison Lahman
f688efbcd6 chore(gha): update helm/chart-testing-action version (#8536) 2026-02-18 11:21:17 -08:00
Jamison Lahman
250658a8b2 chore(playwright): screenshots wait for animations (#8530) 2026-02-18 18:33:33 +00:00
SubashMohan
5150ffc3e0 feat(chat): new chat sharing UI (#8471) 2026-02-18 15:15:34 +00:00
Jamison Lahman
858c1dbe4a chore(playwright): chat rendering tests (#8526) 2026-02-18 05:41:54 +00:00
Evan Lohn
a8e7353227 fix: shared drive node names (#8520) 2026-02-18 05:34:00 +00:00
Nikolas Garza
343cda35cb feat(observability): add production Prometheus instrumentation module (#8503) 2026-02-18 05:09:32 +00:00
Wenxi
1cbe47d85e chore: increase firecrawl test timeout to 60s for e2e test (#8529) 2026-02-18 05:02:45 +00:00
Jamison Lahman
221658132a chore(playwright): skip visual regression report on cancelled (#8528)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-17 20:56:22 -08:00
Jamison Lahman
fe8fb9eb75 fix(fe): center-align modals relative to chat container (#8517) 2026-02-18 04:27:38 +00:00
Wenxi
f7925584b8 fix: create release notifications on cloud (#8525) 2026-02-18 03:47:23 +00:00
Nikolas Garza
00b0e15ed7 feat(scim): add SCIM 2.0 service discovery endpoints (#8454) 2026-02-18 02:07:49 +00:00
Jamison Lahman
c2968e3bfe chore(playwright): update skill re: user isolation best-practices (#8521) 2026-02-17 17:52:03 -08:00
Justin Tahara
978f0a9d35 fix(ollama): Cleaning up DeepSeek (#8519) 2026-02-18 01:44:34 +00:00
Danelegend
410340fe37 chore: add linguist-language (#8515)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-18 00:57:27 +00:00
Wenxi
a0545c7eb3 fix: open_url broken on non-normalized urls and enable web crawl tests (#8508) 2026-02-18 00:41:21 +00:00
Jamison Lahman
aa46a8bba2 chore(gha): only run zizmor when .github/ changes (#8516) 2026-02-17 16:25:13 -08:00
Jamison Lahman
cd5aaa0302 chore(playwright): prefer absolute imports (#8511) 2026-02-17 16:22:51 -08:00
roshan
db33efaeaa fix: Enable search UI in Chrome extension side panel (#8486)
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-17 23:59:05 +00:00
Justin Tahara
c28d37dff8 chore(playwright): Cleanup MCP Tests (#8512) 2026-02-17 23:57:22 +00:00
Jamison Lahman
696e72bcbb chore(playwright): organize tests into directories (#8509) 2026-02-17 15:37:35 -08:00
Nikolas Garza
cfdac8083a feat(scim): add SCIM bearer token authentication (#8423) 2026-02-17 23:28:00 +00:00
Jamison Lahman
781aab67fa chore(cursor): playwright SKILL.md (#8506) 2026-02-17 15:12:01 -08:00
Justin Tahara
b14d357d55 chore(mcp): Adding more Playwright Tests (#8452) 2026-02-17 22:55:35 +00:00
Nikolas Garza
60bc1ce8a1 feat(scim): add SCIM database CRUD operations (DAL) (#8424) 2026-02-17 22:43:54 +00:00
roshan
c89e82ee58 feat(craft): ACP session persistence across sandbox sleep/wake (#8466)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:32:49 +00:00
Evan Lohn
529ab8179f chore: undeprecate doc sets (#8505) 2026-02-17 21:56:49 +00:00
Nikolas Garza
19716874b2 fix(ee): small ux fixes for licensing (#8498) 2026-02-17 21:55:06 +00:00
Justin Tahara
1ac3b8515d chore(playwright): Cleanup for LLM Tests (#8504) 2026-02-17 21:38:24 +00:00
Justin Tahara
0d0c8580ca fix(helm): Log Level Passthrough for Celery (#8495) 2026-02-17 21:15:46 +00:00
roshan
96a38dcc06 fix(craft): ephemeral ACP clients to prevent multi-replica session corruption (#8465)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 21:10:21 +00:00
Jamison Lahman
6fe72e5524 chore(playwright): use user for llm runtime tests (#8502) 2026-02-17 20:21:28 +00:00
Jamison Lahman
1f4ee4d550 chore(playwright): hide toast elements in screenshots (#8501) 2026-02-17 19:42:54 +00:00
Danelegend
534db103fc fix: Edit messages w/ files (#8461) 2026-02-17 18:27:40 +00:00
Justin Tahara
f6f2e2d7c0 fix(desktop): Link clicking within App (#8493) 2026-02-17 18:17:16 +00:00
Raunak Bhagat
b2836c5a13 fix: Interactive.Base href rendering and container self-stretch (#8492)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 18:16:40 +00:00
Justin Tahara
47177c536d fix(vertex): Additional fix for Opus 4.6 (#8497) 2026-02-17 18:12:19 +00:00
victoria reese
3920a1554e Chore/mcp server host (#8472) 2026-02-17 07:19:29 -08:00
SubashMohan
8b12bf34ed fix(frontend): simplify onboarding logic and fix UI issues (#8470) 2026-02-17 08:42:30 +00:00
roshan
dd5b7edc6b fix(ui): remove data-state from Switch (#8394) 2026-02-17 05:22:38 +00:00
Danelegend
7977017aab fix: Code interpreter takes files (#8489)
Co-authored-by: dat <dat@purplefish.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Dat Bac Do <dat.b.do@gmail.com>
2026-02-16 18:53:20 -08:00
Evan Lohn
efd8388239 feat: slim confluence hiernodes (#8490) 2026-02-17 02:37:10 +00:00
Evan Lohn
ac6668309f fix: custom MCP auth Headers in API [mirror of #8345] (#8487)
Co-authored-by: Clément Rigo <rigoclement@mydnic.be>
2026-02-17 01:21:35 +00:00
Evan Lohn
a8bf63315c chore: hiernode passive deletes (#8484) 2026-02-17 00:30:18 +00:00
Evan Lohn
d905eced87 feat: pruning picks up hierarchynodes (#8481) 2026-02-16 23:56:20 +00:00
Evan Lohn
1e6f8fa6bf fix: allow public oauth clients via DCR (#8482) 2026-02-16 23:13:27 +00:00
Raunak Bhagat
faeac13693 feat: agent view only modal (#7989)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 18:23:44 +00:00
SubashMohan
ddb14ec762 fix(ui): fix few common ui bugs (#8425) 2026-02-16 11:22:43 +00:00
Justin Tahara
f31f589860 chore(llm): Adding more FE Unit Tests (#8457) 2026-02-16 03:02:19 +00:00
Evan Lohn
63b9a869af fix: CheckpointedConnector pruning only processes first checkpoint step (mirror of #8464) (#8468)
Co-authored-by: Yves Grolet <yves@grolet.com>
2026-02-16 00:45:43 +00:00
Yuhong Sun
6aea36b573 chore: Context summarization update (#8467) 2026-02-15 23:39:47 +00:00
Wenxi
3d8e8d0846 refactor: connector config refresh elements/cleanup (#8428) 2026-02-15 20:12:51 +00:00
Yuhong Sun
dea5be2185 chore: License update (No change, just touchup) (#8460) 2026-02-14 02:44:38 +00:00
Wenxi
d083973d4f chore: disable auto craft animation with feature flag (#8459) 2026-02-14 02:29:37 +00:00
Wenxi
df956888bf fix: bake public recaptcha key in cloud image (#8458) 2026-02-14 02:12:43 +00:00
dependabot[bot]
7c6062e7d5 chore(deps): bump qs from 6.14.1 to 6.14.2 in /web (#8451)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-14 02:04:30 +00:00
Yuhong Sun
89d2759021 chore: Remove end of lived backend routes (#8453) 2026-02-14 01:57:06 +00:00
Justin Tahara
d9feaf43a7 chore(playwright): Adding new LLM Runtime tests (#8447) 2026-02-14 01:38:23 +00:00
Nikolas Garza
5bfffefa2f feat(scim): add SCIM filter expression parser with unit tests (#8421) 2026-02-14 01:17:48 +00:00
Nikolas Garza
4d0b7e14d4 feat(scim): add SCIM PATCH operation handler with unit tests (#8422) 2026-02-14 01:12:46 +00:00
Jamison Lahman
36c55d9e59 chore(gha): de-duplicate integration test logic (#8450) 2026-02-14 00:31:31 +00:00
Wenxi
9f652108f9 fix: don't pass captcha token to db (#8449) 2026-02-14 00:20:36 +00:00
victoria reese
d4e4c6b40e feat: add setting to configure mcp host (#8439) 2026-02-13 23:49:18 +00:00
Jamison Lahman
9c8deb5d0c chore(playwright): mask non-deterministic email element (#8448) 2026-02-13 23:37:24 +00:00
Danelegend
58f57c43aa feat(contextual-llm): Populate and set w/ llm flow (#8398) 2026-02-13 23:32:26 +00:00
Evan Lohn
62106df753 fix: sharepoint cred refresh2 (#8445) 2026-02-13 23:05:15 +00:00
Jamison Lahman
45b3a5e945 chore(playwright): include option to hide element in screenshots (#8446) 2026-02-13 22:45:46 +00:00
Jamison Lahman
e19a6b6789 chore(playwright): create new user tests (#8429) 2026-02-13 22:17:18 +00:00
Jamison Lahman
2de7df4839 chore(playwright): login page screenshots (#8427) 2026-02-13 22:01:32 +00:00
victoria reese
bd054bbad9 fix: remove default idleReplicaCount (#8434) 2026-02-13 13:37:19 -08:00
Justin Tahara
313e709d41 fix(celery): Respecting Limits for Celery Heavy Tasks (#8407) 2026-02-13 21:27:04 +00:00
Nikolas Garza
aeb1d6edac feat(scim): add SCIM 2.0 Pydantic schemas (#8420) 2026-02-13 21:21:05 +00:00
Wenxi
49a35f8aaa fix: remove user file indexing from launch, add init imports for all celery tasks, bump sandbox memory limits (#8443) 2026-02-13 21:15:30 +00:00
Danelegend
049e8ef0e2 feat(llm): Populate env w/ custom config (#8328) 2026-02-13 21:11:49 +00:00
Jamison Lahman
3b61b495a3 chore(playwright): tag appearance_theme tests exclusive (#8441) 2026-02-13 21:07:57 +00:00
Wenxi
5c5c9f0e1d feat(airtable): index all and heirarchy for craft (#8414) 2026-02-13 21:03:53 +00:00
Nikolas Garza
f20d5c33b7 feat(scim): add SCIM database models and migration (#8419) 2026-02-13 20:54:56 +00:00
Jamison Lahman
e898407f7b chore(tests): skip yet another test_web_search_api test (#8442) 2026-02-13 12:50:04 -08:00
Jamison Lahman
f802ff09a7 chore(tests): skip additional web_search test (#8440) 2026-02-13 12:29:36 -08:00
Jamison Lahman
69ad712e09 chore(tests): temporarily disable exa tests (#8431) 2026-02-13 11:06:25 -08:00
Jamison Lahman
98b69c0f2c chore(playwright): welcome_page tests & per-element screenshots (#8426) 2026-02-13 10:07:27 -08:00
Raunak Bhagat
1e5c87896f refactor(web): migrate from usePopup/setPopup to global toast system (#8411)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 17:21:14 +00:00
Raunak Bhagat
b6cc97a8c3 fix(web): icon button and timeline header UI fixes (#8416) 2026-02-13 17:20:37 +00:00
Yuhong Sun
032fbf1058 chore: reminder prompt to be moveable (#8417) 2026-02-13 07:39:12 +00:00
SubashMohan
fc32a9f92a fix(memory): memory tool UI and prompt injection issues (#8377) 2026-02-13 04:29:51 +00:00
Jamison Lahman
9be13bbf63 chore(playwright): make screenshots deterministic (#8412) 2026-02-12 19:53:11 -08:00
Yuhong Sun
9e7176eb82 chore: Tiny intro message change (#8415) 2026-02-12 19:44:34 -08:00
Yuhong Sun
c7faf8ce52 chore: Project instructions would get ignored (#8409) 2026-02-13 02:51:13 +00:00
885 changed files with 48928 additions and 16933 deletions

1
.claude/skills Symbolic link
View File

@@ -0,0 +1 @@
../.cursor/skills

View File

@@ -0,0 +1,248 @@
---
name: playwright-e2e-tests
description: Write and maintain Playwright end-to-end tests for the Onyx application. Use when creating new E2E tests, debugging test failures, adding test coverage, or when the user mentions Playwright, E2E tests, or browser testing.
---
# Playwright E2E Tests
## Project Layout
- **Tests**: `web/tests/e2e/` — organized by feature (`auth/`, `admin/`, `chat/`, `assistants/`, `connectors/`, `mcp/`)
- **Config**: `web/playwright.config.ts`
- **Utilities**: `web/tests/e2e/utils/`
- **Constants**: `web/tests/e2e/constants.ts`
- **Global setup**: `web/tests/e2e/global-setup.ts`
- **Output**: `web/output/playwright/`
## Imports
Always use absolute imports with the `@tests/e2e/` prefix — never relative paths (`../`, `../../`). The alias is defined in `web/tsconfig.json` and resolves to `web/tests/`.
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { TEST_ADMIN_CREDENTIALS } from "@tests/e2e/constants";
```
All new files should be `.ts`, not `.js`.
## Running Tests
```bash
# Run a specific test file
npx playwright test web/tests/e2e/chat/default_assistant.spec.ts
# Run a specific project
npx playwright test --project admin
npx playwright test --project exclusive
```
## Test Projects
| Project | Description | Parallelism |
|---------|-------------|-------------|
| `admin` | Standard tests (excludes `@exclusive`) | Parallel |
| `exclusive` | Serial, slower tests (tagged `@exclusive`) | 1 worker |
All tests use `admin_auth.json` storage state by default (pre-authenticated admin session).
## Authentication
Global setup (`global-setup.ts`) runs automatically before all tests and handles:
- Server readiness check (polls health endpoint, 60s timeout)
- Provisioning test users: admin, admin2, and a **pool of worker users** (`worker0@example.com` through `worker7@example.com`) (idempotent)
- API login + saving storage states: `admin_auth.json`, `admin2_auth.json`, and `worker{N}_auth.json` for each worker user
- Setting display name to `"worker"` for each worker user
- Promoting admin2 to admin role
- Ensuring a public LLM provider exists
Both test projects set `storageState: "admin_auth.json"`, so **every test starts pre-authenticated as admin with no login code needed**.
When a test needs a different user, use API-based login — never drive the login UI:
```typescript
import { loginAs } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAs(page, "admin2");
// Log in as the worker-specific user (preferred for test isolation):
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
```
## Test Structure
Tests start pre-authenticated as admin — navigate and test directly:
```typescript
import { test, expect } from "@playwright/test";
test.describe("Feature Name", () => {
test("should describe expected behavior clearly", async ({ page }) => {
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Already authenticated as admin — go straight to testing
});
});
```
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as a **worker-specific user** and clean up resources in `afterAll`. Global setup provisions a pool of worker users (`worker0@example.com` through `worker7@example.com`). `loginAsWorkerUser` maps `testInfo.workerIndex` to a pool slot via modulo, so retry workers (which get incrementing indices beyond the pool size) safely reuse existing users. This ensures parallel workers never share user state, keeps usernames deterministic for screenshots, and avoids cross-contamination:
```typescript
import { test } from "@playwright/test";
import { loginAsWorkerUser } from "@tests/e2e/utils/auth";
test.beforeEach(async ({ page }, testInfo) => {
await page.context().clearCookies();
await loginAsWorkerUser(page, testInfo.workerIndex);
});
```
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to the worker user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.
**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
## Key Utilities
### `OnyxApiClient` (`@tests/e2e/utils/onyxApiClient`)
Backend API client for test setup/teardown. Key methods:
- **Connectors**: `createFileConnector()`, `deleteCCPair()`, `pauseConnector()`
- **LLM Providers**: `ensurePublicProvider()`, `createRestrictedProvider()`, `setProviderAsDefault()`
- **Assistants**: `createAssistant()`, `deleteAssistant()`, `findAssistantByName()`
- **User Groups**: `createUserGroup()`, `deleteUserGroup()`, `setUserRole()`
- **Tools**: `createWebSearchProvider()`, `createImageGenerationConfig()`
- **Chat**: `createChatSession()`, `deleteChatSession()`
### `chatActions` (`@tests/e2e/utils/chatActions`)
- `sendMessage(page, message)` — sends a message and waits for AI response
- `startNewChat(page)` — clicks new-chat button and waits for intro
- `verifyDefaultAssistantIsChosen(page)` — checks Onyx logo is visible
- `verifyAssistantIsChosen(page, name)` — checks assistant name display
- `switchModel(page, modelName)` — switches LLM model via popover
### `visualRegression` (`@tests/e2e/utils/visualRegression`)
- `expectScreenshot(page, { name, mask?, hide?, fullPage? })`
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
### `theme` (`@tests/e2e/utils/theme`)
- `THEMES``["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation
When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:
```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
for (const theme of THEMES) {
test.describe(`Feature (${theme} mode)`, () => {
test.beforeEach(async ({ page }) => {
await setThemeBeforeNavigation(page, theme);
});
test("renders correctly", async ({ page }) => {
await page.goto("/app");
await expectScreenshot(page, { name: `feature-${theme}` });
});
});
}
```
### `tools` (`@tests/e2e/utils/tools`)
- `TOOL_IDS` — centralized `data-testid` selectors for tool options
- `openActionManagement(page)` — opens the tool management popover
## Locator Strategy
Use locators in this priority order:
1. **`data-testid` / `aria-label`** — preferred for Onyx components
```typescript
page.getByTestId("AppSidebar/new-session")
page.getByLabel("admin-page-title")
```
2. **Role-based** — for standard HTML elements
```typescript
page.getByRole("button", { name: "Create" })
page.getByRole("dialog")
```
3. **Text/Label** — for visible text content
```typescript
page.getByText("Custom Assistant")
page.getByLabel("Email")
```
4. **CSS selectors** — last resort, only when above won't work
```typescript
page.locator('input[name="name"]')
page.locator("#onyx-chat-input-textarea")
```
**Never use** `page.locator` with complex CSS/XPath when a built-in locator works.
## Assertions
Use web-first assertions — they auto-retry until the condition is met:
```typescript
// Visibility
await expect(page.getByTestId("onyx-logo")).toBeVisible({ timeout: 5000 });
// Text content
await expect(page.getByTestId("assistant-name-display")).toHaveText("My Assistant");
// Count
await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(2, { timeout: 30000 });
// URL
await expect(page).toHaveURL(/chatId=/);
// Element state
await expect(toggle).toBeChecked();
await expect(button).toBeEnabled();
```
**Never use** `assert` statements or hardcoded `page.waitForTimeout()`.
## Waiting Strategy
```typescript
// Wait for load state after navigation
await page.goto("/app");
await page.waitForLoadState("networkidle");
// Wait for specific element
await page.getByTestId("chat-intro").waitFor({ state: "visible", timeout: 10000 });
// Wait for URL change
await page.waitForFunction(() => window.location.href.includes("chatId="), null, { timeout: 10000 });
// Wait for network response
await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.status() === 200);
```
## Best Practices
1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as the worker-specific user via `loginAsWorkerUser(page, testInfo.workerIndex)` (not admin) and clean up resources in `afterAll`. Each parallel worker gets its own user, preventing cross-contamination. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
10. **Minimal comments** — only comment to clarify non-obvious intent; never restate what the next line of code does

View File

@@ -0,0 +1,73 @@
name: "Build Backend Image"
description: "Builds and pushes the backend Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
docker-no-cache:
description: "Set to 'true' to disable docker build cache"
required: false
default: "false"
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
no-cache: ${{ inputs.docker-no-cache == 'true' }}

View File

@@ -0,0 +1,75 @@
name: "Build Integration Image"
description: "Builds and pushes the integration test image with docker bake"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Build and push integration test image with Docker Bake
shell: bash
env:
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
TAG: nightly-llm-it-${{ inputs.run-id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ inputs.github-sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration

View File

@@ -0,0 +1,68 @@
name: "Build Model Server Image"
description: "Builds and pushes the model server Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max

View File

@@ -0,0 +1,118 @@
name: "Run Nightly Provider Chat Test"
description: "Starts required compose services and runs nightly provider integration test"
inputs:
provider:
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
required: true
models:
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: false
default: ""
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
api-base:
description: "Optional NIGHTLY_LLM_API_BASE"
required: false
default: ""
custom-config-json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
required: false
default: ""
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
run-id:
description: "GitHub run ID used in image tags"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Create .env file for Docker Compose
shell: bash
env:
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
run: |
cat <<EOF2 > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
INTEGRATION_TESTS_MODE=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AWS_REGION_NAME=us-west-2
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
EOF2
- name: Start Docker containers
shell: bash
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server
- name: Run nightly provider integration test
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
env:
MODELS: ${{ inputs.models }}
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
with:
timeout_minutes: 20
max_attempts: 2
retry_wait_seconds: 10
command: |
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e TEST_WEB_HOSTNAME=test-runner \
-e AWS_REGION_NAME=us-west-2 \
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
-e NIGHTLY_LLM_MODELS="${MODELS}" \
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Optional] Override Linear Check

View File

@@ -640,6 +640,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
@@ -721,6 +722,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Build chart dependencies

View File

@@ -0,0 +1,49 @@
name: Nightly LLM Provider Chat Tests
concurrency:
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: provider-chat-test
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -1,151 +0,0 @@
# Scan for problematic software licenses
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
name: 'Nightly - Scan licenses'
on:
# schedule:
# - cron: '0 14 * * *' # Runs every day at 6 AM PST / 7 AM PDT / 2 PM UTC
workflow_dispatch: # Allows manual triggering
permissions:
actions: read
contents: read
jobs:
scan-licenses:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-licenses"]
timeout-minutes: 45
permissions:
actions: read
contents: read
security-events: write
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Set up Python
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- name: Get explicit and transitive dependencies
run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
pip freeze > requirements-all.txt
- name: Check python
id: license_check_report
uses: pilosus/action-pip-license-checker@e909b0226ff49d3235c99c4585bc617f49fff16a # ratchet:pilosus/action-pip-license-checker@v3
with:
requirements: 'requirements-all.txt'
fail: 'Copyleft'
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
env:
REPORT: ${{ steps.license_check_report.outputs.report }}
run: echo "$REPORT"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}-scan-trivy"]
timeout-minutes: 45
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@b6643a29fecd7f34b3597bc6acb0a98b03d33ff8 # ratchet:aquasecurity/trivy-action@0.33.1
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0

View File

@@ -0,0 +1,161 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,9 +45,6 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -118,9 +115,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -129,7 +126,6 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -41,8 +41,7 @@ jobs:
version: v3.19.0
- name: Set up chart-testing
# NOTE: This is Jamison's patch from https://github.com/helm/chart-testing-action/pull/194
uses: helm/chart-testing-action@8958a6ac472cbd8ee9a8fbb6f1acbc1b0e966e44 # zizmor: ignore[impostor-commit]
uses: helm/chart-testing-action@b5eebdd9998021f29756c53432f48dab66394810
with:
uv_version: "0.9.9"
@@ -92,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Install Redis operator

View File

@@ -20,6 +20,7 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -46,6 +47,7 @@ jobs:
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
editions: ${{ steps.set-editions.outputs.editions }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
@@ -72,6 +74,16 @@ jobs:
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
- name: Determine editions to test
id: set-editions
run: |
# On PRs, only run EE tests. On merge_group and tags, run both EE and MIT.
if [ "${{ github.event_name }}" = "pull_request" ]; then
echo 'editions=["ee"]' >> $GITHUB_OUTPUT
else
echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT
fi
build-backend-image:
runs-on:
[
@@ -267,7 +279,7 @@ jobs:
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-job-{1}', github.run_id, strategy['job-index']) }}
- ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
@@ -275,6 +287,7 @@ jobs:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
@@ -298,12 +311,11 @@ jobs:
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
EDITION: ${{ matrix.edition }}
run: |
# Base config shared by both editions
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -312,11 +324,20 @@ jobs:
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
# EE-only config
if [ "$EDITION" = "ee" ]; then
cat <<EOF >> deployment/docker_compose/.env
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
EOF
fi
- name: Start Docker containers
run: |
@@ -379,14 +400,14 @@ jobs:
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
- name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
@@ -403,6 +424,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -423,6 +445,7 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -444,7 +467,7 @@ jobs:
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
@@ -681,6 +704,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \

View File

@@ -1,443 +0,0 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
types: [checks_requested]
push:
tags:
- "v*.*.*"
permissions:
contents: read
env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
outputs:
test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
steps:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Discover test directories
id: set-matrix
run: |
# Find all leaf-level directories in both test directories
tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" -exec basename {} \; | sort)
connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" -exec basename {} \; | sort)
# Create JSON array with directory info
all_dirs=""
for dir in $tests_dirs; do
all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"},"
done
for dir in $connector_dirs; do
all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"},"
done
# Remove trailing comma and wrap in array
all_dirs="[${all_dirs%,}]"
echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max
no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }}
cache-from: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Format branch name for cache
id: format-branch
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REF_NAME: ${{ github.ref_name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
# needed for pulling openapitools/openapi-generator-cli
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push integration test image with Docker Bake
env:
INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }}
TAG: integration-test-${{ github.run_id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration
integration-tests-mit:
needs:
[
discover-test-dirs,
build-backend-image,
build-model-server-image,
build-integration-image,
]
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- ${{ format('run-id={0}-integration-tests-mit-job-{1}', github.run_id, strategy['job-index']) }}
- extras=ecr-cache
timeout-minutes: 45
strategy:
fail-fast: false
matrix:
test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
# needed for pulling Vespa, Redis, Postgres, and Minio images
# otherwise, we hit the "Unauthenticated users" limit
# https://docs.docker.com/docker-hub/usage/
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
# NOTE: don't need web server for integration tests
- name: Create .env file for Docker Compose
env:
ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
RUN_ID: ${{ github.run_id }}
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID}
INTEGRATION_TESTS_MODE=true
MCP_SERVER_ENABLED=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
EOF
- name: Start Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server \
indexing_model_server \
background \
-d
id: start_docker
- name: Wait for services to be ready
run: |
echo "Starting wait-for-service script..."
wait_for_service() {
local url=$1
local label=$2
local timeout=${3:-300} # default 5 minutes
local start_time
start_time=$(date +%s)
while true; do
local current_time
current_time=$(date +%s)
local elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. ${label} did not become ready in $timeout seconds."
exit 1
fi
local response
response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error")
if [ "$response" = "200" ]; then
echo "${label} is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..."
else
echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
}
wait_for_service "http://localhost:8080/health" "API server"
echo "Finished waiting for services."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Integration Tests for ${{ matrix.test-dir.name }}
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
with:
timeout_minutes: 20
max_attempts: 3
retry_wait_seconds: 10
command: |
echo "Running integration tests for ${{ matrix.test-dir.path }}..."
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-${{ matrix.test-dir.name }}
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
required:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 45
needs: [integration-tests-mit]
if: ${{ always() }}
steps:
- name: Check job status
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }}
run: exit 1

View File

@@ -22,6 +22,9 @@ env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }}
GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }}
GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }}
# for federated slack tests
SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }}
@@ -300,6 +303,7 @@ jobs:
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
INTEGRATION_TESTS_MODE=true
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}
REQUIRE_EMAIL_VERIFICATION=false
@@ -589,7 +593,10 @@ jobs:
# Post a single combined visual regression comment after all matrix jobs finish
visual-regression-comment:
needs: [playwright-tests]
if: always() && github.event_name == 'pull_request'
if: >-
always() &&
github.event_name == 'pull_request' &&
needs.playwright-tests.result != 'cancelled'
runs-on: ubuntu-slim
timeout-minutes: 5
permissions:

View File

@@ -89,6 +89,10 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

View File

@@ -0,0 +1,217 @@
name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
openai_models:
description: "Comma-separated models for openai"
required: false
default: ""
type: string
anthropic_models:
description: "Comma-separated models for anthropic"
required: false
default: ""
type: string
bedrock_models:
description: "Comma-separated models for bedrock"
required: false
default: ""
type: string
vertex_ai_models:
description: "Comma-separated models for vertex_ai"
required: false
default: ""
type: string
strict:
description: "Default NIGHTLY_LLM_STRICT passed to tests"
required: false
default: true
type: boolean
secrets:
openai_api_key:
required: false
anthropic_api_key:
required: false
bedrock_api_key:
required: false
vertex_ai_custom_config_json:
required: false
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
required: true
permissions:
contents: read
jobs:
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build backend image
uses: ./.github/actions/build-backend-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build model server image
uses: ./.github/actions/build-model-server-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build integration image
uses: ./.github/actions/build-integration-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
provider-chat-test:
needs:
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
strategy:
fail-fast: false
matrix:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_secret: openai_api_key
custom_config_secret: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_secret: anthropic_api_key
custom_config_secret: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_secret: bedrock_api_key
custom_config_secret: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_secret: ""
custom_config_secret: vertex_ai_custom_config_json
required: false
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose down -v

View File

@@ -5,6 +5,8 @@ on:
branches: ["main"]
pull_request:
branches: ["**"]
paths:
- ".github/**"
permissions: {}
@@ -21,29 +23,18 @@ jobs:
with:
persist-credentials: false
- name: Detect changes
id: filter
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
with:
filters: |
zizmor:
- '.github/**'
- name: Install the latest version of uv
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Run zizmor
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
if: steps.filter.outputs.zizmor == 'true' || github.ref_name == 'main'
uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5
with:
sarif_file: results.sarif

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@
.zed
.cursor
!/.cursor/mcp.json
!/.cursor/skills/
# macos
.DS_store

4
.vscode/launch.json vendored
View File

@@ -275,7 +275,7 @@
"--loglevel=INFO",
"--hostname=background@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,user_files_indexing,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
],
"presentation": {
"group": "2"
@@ -419,7 +419,7 @@
"--loglevel=INFO",
"--hostname=docfetching@%n",
"-Q",
"connector_doc_fetching,user_files_indexing"
"connector_doc_fetching"
],
"presentation": {
"group": "2"

View File

@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
To run them:
@@ -616,3 +616,9 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -2,7 +2,10 @@ Copyright (c) 2023-present DanswerAI, Inc.
Portions of this software are licensed as follows:
- All content that resides under "ee" directories of this repository, if that directory exists, is licensed under the license defined in "backend/ee/LICENSE". Specifically all content under "backend/ee" and "web/src/app/ee" is licensed under the license defined in "backend/ee/LICENSE".
- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
- backend/ee/LICENSE
- web/src/app/ee/LICENSE
- web/src/ee/LICENSE
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

View File

@@ -21,15 +21,14 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, NamedTuple
from typing import NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -0,0 +1,29 @@
"""code interpreter seed
Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Seed the single instance of code_interpreter_server
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
op.execute(
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
)
def downgrade() -> None:
op.execute(sa.text("DELETE FROM code_interpreter_server"))

View File

@@ -0,0 +1,28 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -0,0 +1,71 @@
"""Migrate to contextual rag model
Revision ID: 19c0ccb01687
Revises: 9c54986124c6
Create Date: 2026-02-12 11:21:41.798037
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "19c0ccb01687"
down_revision = "9c54986124c6"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Widen the column to fit 'CONTEXTUAL_RAG' (15 chars); was varchar(10)
# when the table was created with only CHAT/VISION values.
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=20),
existing_type=sa.String(length=10),
existing_nullable=False,
)
# For every search_settings row that has contextual rag configured,
# create an llm_model_flow entry. is_default is TRUE if the row
# belongs to the PRESENT search settings, FALSE otherwise.
op.execute(
"""
INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default)
SELECT DISTINCT
'CONTEXTUAL_RAG',
mc.id,
(ss.status = 'PRESENT')
FROM search_settings ss
JOIN llm_provider lp
ON lp.name = ss.contextual_rag_llm_provider
JOIN model_configuration mc
ON mc.llm_provider_id = lp.id
AND mc.name = ss.contextual_rag_llm_name
WHERE ss.enable_contextual_rag = TRUE
AND ss.contextual_rag_llm_name IS NOT NULL
AND ss.contextual_rag_llm_provider IS NOT NULL
ON CONFLICT (llm_model_flow_type, model_configuration_id)
DO UPDATE SET is_default = EXCLUDED.is_default
WHERE EXCLUDED.is_default = TRUE
"""
)
def downgrade() -> None:
op.execute(
"""
DELETE FROM llm_model_flow
WHERE llm_model_flow_type = 'CONTEXTUAL_RAG'
"""
)
op.alter_column(
"llm_model_flow",
"llm_model_flow_type",
type_=sa.String(length=10),
existing_type=sa.String(length=20),
existing_nullable=False,
)

View File

@@ -0,0 +1,32 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -0,0 +1,48 @@
"""add enterprise and name fields to scim_user_mapping
Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("department", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("manager", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("given_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("family_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("scim_emails_json", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_emails_json")
op.drop_column("scim_user_mapping", "family_name")
op.drop_column("scim_user_mapping", "given_name")
op.drop_column("scim_user_mapping", "manager")
op.drop_column("scim_user_mapping", "department")

View File

@@ -0,0 +1,31 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -0,0 +1,33 @@
"""add needs_persona_sync to user_file
Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user_file",
sa.Column(
"needs_persona_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("user_file", "needs_persona_sync")

View File

@@ -0,0 +1,124 @@
"""add_scim_tables
Revision ID: 9c54986124c6
Revises: b51c6844d1df
Create Date: 2026-02-12 20:29:47.448614
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c54986124c6"
down_revision = "b51c6844d1df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"scim_token",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("hashed_token", sa.String(length=64), nullable=False),
sa.Column("token_display", sa.String(), nullable=False),
sa.Column(
"created_by_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("hashed_token"),
)
op.create_table(
"scim_group_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_group_id"], ["user_group.id"], ondelete="CASCADE"
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_group_id"),
)
op.create_index(
op.f("ix_scim_group_mapping_external_id"),
"scim_group_mapping",
["external_id"],
unique=True,
)
op.create_table(
"scim_user_mapping",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_id", sa.String(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=False,
),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
onupdate=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id"),
)
op.create_index(
op.f("ix_scim_user_mapping_external_id"),
"scim_user_mapping",
["external_id"],
unique=True,
)
def downgrade() -> None:
op.drop_index(
op.f("ix_scim_user_mapping_external_id"),
table_name="scim_user_mapping",
)
op.drop_table("scim_user_mapping")
op.drop_index(
op.f("ix_scim_group_mapping_external_id"),
table_name="scim_group_mapping",
)
op.drop_table("scim_group_mapping")
op.drop_table("scim_token")

View File

@@ -0,0 +1,70 @@
"""llm provider deprecate fields
Revision ID: c0c937d5c9e5
Revises: 8ffcc2bcfc11
Create Date: 2026-02-25 17:35:46.125102
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0c937d5c9e5"
down_revision = "8ffcc2bcfc11"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Make default_model_name nullable (was NOT NULL)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=True,
)
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
op.drop_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
type_="unique",
)
# Remove server_default from is_default_vision_provider (was server_default=false())
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=None,
)
def downgrade() -> None:
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
op.execute(
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=False,
)
# Restore unique constraint on is_default_provider
op.create_unique_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
["is_default_provider"],
)
# Restore server_default for is_default_vision_provider
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=sa.false(),
)

View File

@@ -0,0 +1,31 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -1,20 +1,20 @@
The DanswerAI Enterprise license (the Enterprise License)
The Onyx Enterprise License (the "Enterprise License")
Copyright (c) 2023-present DanswerAI, Inc.
With regard to the Onyx Software:
This software and associated documentation files (the "Software") may only be
used in production, if you (and any entity that you represent) have agreed to,
and are in compliance with, the DanswerAI Subscription Terms of Service, available
at https://onyx.app/terms (the Enterprise Terms), or other
and are in compliance with, the Onyx Subscription Terms of Service, available
at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other
agreement governing the use of the Software, as agreed by you and DanswerAI,
and otherwise have a valid Onyx Enterprise license for the
and otherwise have a valid Onyx Enterprise License for the
correct number of user seats. Subject to the foregoing sentence, you are free to
modify this Software and publish patches to the Software. You agree that DanswerAI
and/or its licensors (as applicable) retain all right, title and interest in and
to all such modifications and/or patches, and all such modifications and/or
patches may only be used, copied, modified, displayed, distributed, or otherwise
exploited with a valid Onyx Enterprise license for the correct
exploited with a valid Onyx Enterprise License for the correct
number of user seats. Notwithstanding the foregoing, you may copy and modify
the Software for development and testing purposes, without requiring a
subscription. You agree that DanswerAI and/or its licensors (as applicable) retain

View File

@@ -536,7 +536,9 @@ def connector_permission_sync_generator_task(
)
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
callback = PermissionSyncCallback(
redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT
)
# pass in the capability to fetch all existing docs for the cc_pair
# this is can be used to determine documents that are "missing" and thus
@@ -576,6 +578,13 @@ def connector_permission_sync_generator_task(
tasks_generated = 0
docs_with_errors = 0
for doc_external_access in document_external_accesses:
if callback.should_stop():
raise RuntimeError(
f"Permission sync task timed out or stop signal detected: "
f"cc_pair={cc_pair_id} "
f"tasks_generated={tasks_generated}"
)
result = redis_connector.permissions.update_db(
lock=lock,
new_permissions=[doc_external_access],
@@ -932,6 +941,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
@@ -944,11 +954,26 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"PermissionSyncCallback - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002

View File

@@ -466,6 +466,7 @@ def connector_external_group_sync_generator_task(
def _perform_external_group_sync(
cc_pair_id: int,
tenant_id: str,
timeout_seconds: int = JOB_TIMEOUT,
) -> None:
# Create attempt record at the start
with get_session_with_current_tenant() as db_session:
@@ -518,9 +519,23 @@ def _perform_external_group_sync(
seen_users: set[str] = set() # Track unique users across all groups
total_groups_processed = 0
total_group_memberships_synced = 0
start_time = time.monotonic()
try:
external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
for external_user_group in external_user_group_generator:
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
elapsed = time.monotonic() - start_time
if elapsed > timeout_seconds:
raise RuntimeError(
f"External group sync task timed out: "
f"cc_pair={cc_pair_id} "
f"elapsed={elapsed:.0f}s "
f"timeout={timeout_seconds}s "
f"groups_processed={total_groups_processed}"
)
external_user_group_batch.append(external_user_group)
# Track progress

View File

@@ -263,9 +263,15 @@ def refresh_license_cache(
try:
payload = verify_license_signature(license_record.license_data)
# Derive source from payload: manual licenses lack stripe_customer_id
source: LicenseSource = (
LicenseSource.AUTO_FETCH
if payload.stripe_customer_id
else LicenseSource.MANUAL_UPLOAD
)
return update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
source=source,
tenant_id=tenant_id,
)
except ValueError as e:

709
backend/ee/onyx/db/scim.py Normal file
View File

@@ -0,0 +1,709 @@
"""SCIM Data Access Layer.
All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).
Usage from FastAPI::
def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@router.post("/tokens")
def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
token = dal.create_token(name=..., hashed_token=..., ...)
dal.commit()
return token
Usage from background tasks::
with ScimDAL.from_tenant("tenant_abc") as dal:
mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
dal.commit()
"""
from __future__ import annotations
from uuid import UUID
from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ScimDAL(DAL):
"""Data Access Layer for SCIM provisioning operations.
Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
when you want to persist changes. This follows the existing ``_no_commit``
convention and lets callers batch multiple operations into one transaction.
"""
# ------------------------------------------------------------------
# Token operations
# ------------------------------------------------------------------
def create_token(
self,
name: str,
hashed_token: str,
token_display: str,
created_by_id: UUID,
) -> ScimToken:
"""Create a new SCIM bearer token.
Only one token is active at a time — this method automatically revokes
all existing active tokens before creating the new one.
"""
# Revoke any currently active tokens
active_tokens = list(
self._session.scalars(
select(ScimToken).where(ScimToken.is_active.is_(True))
).all()
)
for t in active_tokens:
t.is_active = False
token = ScimToken(
name=name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=created_by_id,
)
self._session.add(token)
self._session.flush()
return token
def get_active_token(self) -> ScimToken | None:
"""Return the single currently active token, or None."""
return self._session.scalar(
select(ScimToken).where(ScimToken.is_active.is_(True))
)
def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
"""Look up a token by its SHA-256 hash."""
return self._session.scalar(
select(ScimToken).where(ScimToken.hashed_token == hashed_token)
)
def revoke_token(self, token_id: int) -> None:
"""Deactivate a token by ID.
Raises:
ValueError: If the token does not exist.
"""
token = self._session.get(ScimToken, token_id)
if not token:
raise ValueError(f"SCIM token with id {token_id} not found")
token.is_active = False
def update_token_last_used(self, token_id: int) -> None:
"""Update the last_used_at timestamp for a token."""
token = self._session.get(ScimToken, token_id)
if token:
token.last_used_at = func.now() # type: ignore[assignment]
# ------------------------------------------------------------------
# User mapping operations
# ------------------------------------------------------------------
def create_user_mapping(
self,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
department=f.department,
manager=f.manager,
given_name=f.given_name,
family_name=f.family_name,
scim_emails_json=f.scim_emails_json,
)
self._session.add(mapping)
self._session.flush()
return mapping
def get_user_mapping_by_external_id(
self, external_id: str
) -> ScimUserMapping | None:
"""Look up a user mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
)
def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
"""Look up a user mapping by the Onyx user ID."""
return self._session.scalar(
select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
)
def list_user_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimUserMapping], int]:
"""List user mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimUserMapping)) or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimUserMapping)
.order_by(ScimUserMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def update_user_mapping_external_id(
self,
mapping_id: int,
external_id: str,
) -> ScimUserMapping:
"""Update the external ID on a user mapping.
Raises:
ValueError: If the mapping does not exist.
"""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
mapping.external_id = external_id
return mapping
def delete_user_mapping(self, mapping_id: int) -> None:
"""Delete a user mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimUserMapping, mapping_id)
if not mapping:
logger.warning("SCIM user mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# User query operations
# ------------------------------------------------------------------
def get_user(self, user_id: UUID) -> User | None:
"""Fetch a user by ID."""
return self._session.scalar(
select(User).where(User.id == user_id) # type: ignore[arg-type]
)
def get_user_by_email(self, email: str) -> User | None:
"""Fetch a user by email (case-insensitive)."""
return self._session.scalar(
select(User).where(func.lower(User.email) == func.lower(email))
)
def add_user(self, user: User) -> None:
"""Add a new user to the session and flush to assign an ID."""
self._session.add(user)
self._session.flush()
def update_user(
self,
user: User,
*,
email: str | None = None,
is_active: bool | None = None,
personal_name: str | None = None,
) -> None:
"""Update user attributes. Only sets fields that are provided."""
if email is not None:
user.email = email
if is_active is not None:
user.is_active = is_active
if personal_name is not None:
user.personal_name = personal_name
def deactivate_user(self, user: User) -> None:
"""Mark a user as inactive."""
user.is_active = False
def list_users(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, mapping) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(User).where(
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
)
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "username":
# arg-type: fastapi-users types User.email as str, not a column expression
# assignment: union return type widens but query is still Select[tuple[User]]
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
elif attr == "active":
query = query.where(
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
)
elif attr == "externalid":
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
# Count total matching rows first, then paginate. SCIM uses 1-based
# indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
)
.unique()
.all()
)
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
# ------------------------------------------------------------------
# Group mapping operations
# ------------------------------------------------------------------
def create_group_mapping(
self,
external_id: str,
user_group_id: int,
) -> ScimGroupMapping:
"""Create a mapping between a SCIM externalId and an Onyx user group."""
mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
self._session.add(mapping)
self._session.flush()
return mapping
def get_group_mapping_by_external_id(
self, external_id: str
) -> ScimGroupMapping | None:
"""Look up a group mapping by the IdP's external identifier."""
return self._session.scalar(
select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
)
def get_group_mapping_by_group_id(
self, user_group_id: int
) -> ScimGroupMapping | None:
"""Look up a group mapping by the Onyx user group ID."""
return self._session.scalar(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id == user_group_id
)
)
def list_group_mappings(
self,
start_index: int = 1,
count: int = 100,
) -> tuple[list[ScimGroupMapping], int]:
"""List group mappings with SCIM-style pagination.
Args:
start_index: 1-based start index (SCIM convention).
count: Maximum number of results to return.
Returns:
A tuple of (mappings, total_count).
"""
total = (
self._session.scalar(select(func.count()).select_from(ScimGroupMapping))
or 0
)
offset = max(start_index - 1, 0)
mappings = list(
self._session.scalars(
select(ScimGroupMapping)
.order_by(ScimGroupMapping.id)
.offset(offset)
.limit(count)
).all()
)
return mappings, total
def delete_group_mapping(self, mapping_id: int) -> None:
"""Delete a group mapping by ID. No-op if already deleted."""
mapping = self._session.get(ScimGroupMapping, mapping_id)
if not mapping:
logger.warning("SCIM group mapping %d not found during delete", mapping_id)
return
self._session.delete(mapping)
# ------------------------------------------------------------------
# Group query operations
# ------------------------------------------------------------------
def get_group(self, group_id: int) -> UserGroup | None:
"""Fetch a group by ID, returning None if deleted or missing."""
group = self._session.get(UserGroup, group_id)
if group and group.is_up_for_deletion:
return None
return group
def get_group_by_name(self, name: str) -> UserGroup | None:
"""Fetch a group by exact name."""
return self._session.scalar(select(UserGroup).where(UserGroup.name == name))
def add_group(self, group: UserGroup) -> None:
"""Add a new group to the session and flush to assign an ID."""
self._session.add(group)
self._session.flush()
def update_group(
self,
group: UserGroup,
*,
name: str | None = None,
) -> None:
"""Update group attributes and set the modification timestamp."""
if name is not None:
group.name = name
group.time_last_modified_by_user = func.now()
def delete_group(self, group: UserGroup) -> None:
"""Delete a group from the session."""
self._session.delete(group)
def list_groups(
self,
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[UserGroup, str | None]], int]:
"""Query groups with optional SCIM filter and pagination.
Returns:
A tuple of (list of (group, external_id) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
"""
query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False))
if scim_filter:
attr = scim_filter.attribute.lower()
if attr == "displayname":
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
elif attr == "externalid":
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
if not mapping:
return [], 0
query = query.where(UserGroup.id == mapping.user_group_id)
else:
raise ValueError(
f"Unsupported filter attribute: {scim_filter.attribute}"
)
total = (
self._session.scalar(select(func.count()).select_from(query.subquery()))
or 0
)
offset = max(start_index - 1, 0)
groups = list(
self._session.scalars(
query.order_by(UserGroup.id).offset(offset).limit(count)
).all()
)
ext_id_map = self._get_group_external_ids([g.id for g in groups])
return [(g, ext_id_map.get(g.id)) for g in groups], total
def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]:
"""Get group members as (user_id, email) pairs."""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
).all()
user_ids = [r.user_id for r in rels if r.user_id]
if not user_ids:
return []
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users_by_id = {u.id: u for u in users}
return [
(
r.user_id,
users_by_id[r.user_id].email if r.user_id in users_by_id else None,
)
for r in rels
if r.user_id
]
def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]:
"""Return the subset of UUIDs that don't exist as users.
Returns an empty list if all IDs are valid.
"""
if not uuids:
return []
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]
def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Add user-group relationships, ignoring duplicates."""
if not user_ids:
return
self._session.execute(
pg_insert(User__UserGroup)
.values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids])
.on_conflict_do_nothing(
index_elements=[
User__UserGroup.user_group_id,
User__UserGroup.user_id,
]
)
)
def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Replace all members of a group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id)
)
self.upsert_group_members(group_id, user_ids)
def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None:
"""Remove specific members from a group."""
if not user_ids:
return
self._session.execute(
sa_delete(User__UserGroup).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.in_(user_ids),
)
)
def delete_group_with_members(self, group: UserGroup) -> None:
"""Remove all member relationships and delete the group."""
self._session.execute(
sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id)
)
self._session.delete(group)
def sync_group_external_id(
self, group_id: int, new_external_id: str | None
) -> None:
"""Create, update, or delete the external ID mapping for a group."""
mapping = self.get_group_mapping_by_group_id(group_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
else:
self.create_group_mapping(
external_id=new_external_id, user_group_id=group_id
)
elif mapping:
self.delete_group_mapping(mapping.id)
def _get_group_external_ids(self, group_ids: list[int]) -> dict[int, str]:
"""Batch-fetch external IDs for a list of group IDs."""
if not group_ids:
return {}
mappings = self._session.scalars(
select(ScimGroupMapping).where(
ScimGroupMapping.user_group_id.in_(group_ids)
)
).all()
return {m.user_group_id: m.external_id for m in mappings}
# ---------------------------------------------------------------------------
# Module-level helpers (used by DAL methods above)
# ---------------------------------------------------------------------------
def _apply_scim_string_op(
query: Select[tuple[User]] | Select[tuple[UserGroup]],
column: SQLColumnExpression[str],
scim_filter: ScimFilter,
) -> Select[tuple[User]] | Select[tuple[UserGroup]]:
"""Apply a SCIM string filter operator using SQLAlchemy column operators.
Handles eq (case-insensitive exact), co (contains), and sw (starts with).
SQLAlchemy's operators handle LIKE-pattern escaping internally.
"""
val = scim_filter.value
if scim_filter.operator == ScimFilterOperator.EQUAL:
return query.where(func.lower(column) == val.lower())
elif scim_filter.operator == ScimFilterOperator.CONTAINS:
return query.where(column.icontains(val, autoescape=True))
elif scim_filter.operator == ScimFilterOperator.STARTS_WITH:
return query.where(column.istartswith(val, autoescape=True))
else:
raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}")

View File

@@ -9,6 +9,7 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -18,11 +19,15 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -195,8 +200,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session, only_up_to_date: bool = True
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -209,6 +266,8 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -216,11 +275,16 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
return db_session.scalars(stmt).all()
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID, only_curator_groups: bool = False
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -230,7 +294,9 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
return db_session.scalars(stmt).all()
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
def construct_document_id_select_by_usergroup(

View File

@@ -6,6 +6,7 @@ from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
@@ -46,19 +47,27 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
# Process each site
sp_domain_suffix = connector.sharepoint_domain_suffix
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Yield each group
for group in external_groups:

View File

@@ -1,9 +1,12 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -14,7 +17,10 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -33,6 +39,70 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -527,8 +597,12 @@ def get_external_access_from_sharepoint(
)
elif site_page:
site_url = site_page.get("webUrl")
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
# Keep percent-encoding intact so the path matches the encoding
# used by the Office365 library's SPResPath.create_relative(),
# which compares against urlparse(context.base_url).path.
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
# the site prefix in the constructed URL.
server_relative_url = urlparse(site_url).path
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
@@ -572,8 +646,65 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext, graph_client: GraphClient
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -629,57 +760,22 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
)
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
return external_user_groups
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
return external_user_groups

View File

@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -162,6 +163,11 @@ def get_application() -> FastAPI:
# Tenant management
include_router_with_global_prefix_prepended(application, tenants_router)
# SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
# they use their own SCIM bearer token auth).
# Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
application.include_router(scim_router)
# Ensure all routes have auth enabled or are explicitly marked as public
check_ee_router_auth(application)

View File

@@ -5,6 +5,11 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS
EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
# SCIM 2.0 service discovery — unauthenticated so IdPs can probe
# before bearer token configuration is complete
("/scim/v2/ServiceProviderConfig", {"GET"}),
("/scim/v2/ResourceTypes", {"GET"}),
("/scim/v2/Schemas", {"GET"}),
# needs to be accessible prior to user login
("/enterprise-settings", {"GET"}),
("/enterprise-settings/logo", {"GET"}),

View File

@@ -13,6 +13,7 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import get_logo_filename
@@ -22,6 +23,10 @@ from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
from ee.onyx.server.enterprise_settings.store import store_settings
from ee.onyx.server.enterprise_settings.store import upload_logo
from ee.onyx.server.scim.auth import generate_scim_token
from ee.onyx.server.scim.models import ScimTokenCreate
from ee.onyx.server.scim.models import ScimTokenCreatedResponse
from ee.onyx.server.scim.models import ScimTokenResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user_with_expired_token
from onyx.auth.users import get_user_manager
@@ -198,3 +203,63 @@ def upload_custom_analytics_script(
@basic_router.get("/custom-analytics-script")
def fetch_custom_analytics_script() -> str | None:
return load_analytics_script()
# ---------------------------------------------------------------------------
# SCIM token management
# ---------------------------------------------------------------------------
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
@admin_router.get("/scim/token")
def get_active_scim_token(
_: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenResponse:
"""Return the currently active SCIM token's metadata, or 404 if none."""
token = dal.get_active_token()
if not token:
raise HTTPException(status_code=404, detail="No active SCIM token")
return ScimTokenResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
)
@admin_router.post("/scim/token", status_code=201)
def create_scim_token(
body: ScimTokenCreate,
user: User = Depends(current_admin_user),
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimTokenCreatedResponse:
"""Create a new SCIM bearer token.
Only one token is active at a time — creating a new token automatically
revokes all previous tokens. The raw token value is returned exactly once
in the response; it cannot be retrieved again.
"""
raw_token, hashed_token, token_display = generate_scim_token()
token = dal.create_token(
name=body.name,
hashed_token=hashed_token,
token_display=token_display,
created_by_id=user.id,
)
dal.commit()
return ScimTokenCreatedResponse(
id=token.id,
name=token.name,
token_display=token.token_display,
is_active=token.is_active,
created_at=token.created_at,
last_used_at=token.last_used_at,
raw_token=raw_token,
)

View File

@@ -27,12 +27,14 @@ class SearchFlowClassificationResponse(BaseModel):
is_search_flow: bool
# NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
class SendSearchQueryRequest(BaseModel):
search_query: str
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
num_hits: int = 30
include_content: bool = False
stream: bool = False

View File

@@ -67,6 +67,8 @@ def search_flow_classification(
return SearchFlowClassificationResponse(is_search_flow=is_search_flow)
# NOTE: This endpoint is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an
# experienced team member. It is very important to 1. avoid bloat and 2. that this remains backwards compatible across versions.
@router.post(
"/send-search-message",
response_model=None,

View File

View File

@@ -0,0 +1,957 @@
"""SCIM 2.0 API endpoints (RFC 7644).
This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.
Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""
from __future__ import annotations
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import Query
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.base import serialize_emails
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
media_type = "application/scim+json"
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
"""Resolve the SCIM provider for the current request.
Currently returns OktaProvider for all requests. When multi-provider
support is added (ENG-3652), this will resolve based on token metadata
or tenant configuration — no endpoint changes required.
"""
return get_default_provider()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
"""Advertise supported SCIM features (RFC 7643 §5)."""
return SERVICE_PROVIDER_CONFIG
@scim_router.get("/ResourceTypes")
def get_resource_types() -> ScimJSONResponse:
"""List available SCIM resource types (RFC 7643 §6).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(resources),
"Resources": [
r.model_dump(exclude_none=True, by_alias=True) for r in resources
],
}
)
@scim_router.get("/Schemas")
def get_schemas() -> ScimJSONResponse:
"""Return SCIM schema definitions (RFC 7643 §7).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(schemas),
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
}
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
body = ScimError(status=str(status), detail=detail)
return ScimJSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _parse_excluded_attributes(raw: str | None) -> set[str]:
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
Returns a set of lowercased attribute names to omit from responses.
"""
if not raw:
return set()
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
def _apply_exclusions(
resource: ScimUserResource | ScimGroupResource,
excluded: set[str],
) -> dict:
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
to reduce response payload size. We strip those fields after serialization
so the rest of the pipeline doesn't need to know about them.
"""
data = resource.model_dump(exclude_none=True, by_alias=True)
for attr in excluded:
# Match case-insensitively against the camelCase field names
keys_to_remove = [k for k in data if k.lower() == attr]
for k in keys_to_remove:
del data[k]
return data
def _check_seat_availability(dal: ScimDAL) -> str | None:
"""Return an error message if seat limit is reached, else None."""
check_fn = fetch_ee_implementation_or_noop(
"onyx.db.license", "check_seat_availability", None
)
if check_fn is None:
return None
result = check_fn(dal.session, seats_needed=1)
if not result.available:
return result.error_message or "Seat limit reached"
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
except ValueError:
return _scim_error_response(404, f"User {user_id} not found")
user = dal.get_user(uid)
if not user:
return _scim_error_response(404, f"User {user_id} not found")
return user
def _scim_name_to_str(name: ScimName | None) -> str | None:
"""Extract a display name string from a SCIM name object.
Returns None if no name is provided, so the caller can decide
whether to update the user's personal_name.
"""
if not name:
return None
# If the client explicitly provides ``formatted``, prefer it — the client
# knows what display string it wants. Otherwise build from components.
if name.formatted:
return name.formatted
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or None
def _scim_resource_response(
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
status_code: int = 200,
) -> ScimJSONResponse:
"""Serialize a SCIM resource as ``application/scim+json``."""
content = resource.model_dump(exclude_none=True, by_alias=True)
return ScimJSONResponse(
status_code=status_code,
content=content,
)
def _build_list_response(
resources: list[ScimUserResource | ScimGroupResource],
total: int,
start_index: int,
count: int,
excluded: set[str] | None = None,
) -> ScimListResponse | ScimJSONResponse:
"""Build a SCIM list response, optionally applying attribute exclusions.
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
the ``excludedAttributes`` query parameter.
"""
if excluded:
envelope = ScimListResponse(
totalResults=total,
startIndex=start_index,
itemsPerPage=count,
)
data = envelope.model_dump(exclude_none=True)
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
return ScimJSONResponse(content=data)
return _scim_resource_response(
ScimListResponse(
totalResults=total,
startIndex=start_index,
itemsPerPage=count,
Resources=resources,
)
)
def _extract_enterprise_fields(
resource: ScimUserResource,
) -> tuple[str | None, str | None]:
"""Extract department and manager from enterprise extension."""
ext = resource.enterprise_extension
if not ext:
return None, None
department = ext.department
manager = ext.manager.value if ext.manager else None
return department, manager
def _mapping_to_fields(
mapping: ScimUserMapping | None,
) -> ScimMappingFields | None:
"""Extract round-trip fields from a SCIM user mapping."""
if not mapping:
return None
return ScimMappingFields(
department=mapping.department,
manager=mapping.manager,
given_name=mapping.given_name,
family_name=mapping.family_name,
scim_emails_json=mapping.scim_emails_json,
)
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
"""Build mapping fields from an incoming SCIM user resource."""
department, manager = _extract_enterprise_fields(resource)
return ScimMappingFields(
department=department,
manager=manager,
given_name=resource.name.givenName if resource.name else None,
family_name=resource.name.familyName if resource.name else None,
scim_emails_json=serialize_emails(resource.emails),
)
# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
resources: list[ScimUserResource | ScimGroupResource] = [
provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
for user, mapping in users_with_mappings
]
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
resource = provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
# hitting a 409 conflict — so we require it up front.
if not user_resource.externalId:
return _scim_error_response(400, "externalId is required")
# Enforce seat limit
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Check for existing user
if dal.get_user_by_email(email):
return _scim_error_response(409, f"User with email {email} already exists")
# Create user with a random password (SCIM users authenticate via IdP)
personal_name = _scim_name_to_str(user_resource.name)
user = User(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
)
try:
dal.add_user(user)
except IntegrityError:
dal.rollback()
return _scim_error_response(409, f"User with email {email} already exists")
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
return result
user = result
# Handle activation (need seat check) / deactivation
if user_resource.active and not user.is_active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
user,
email=user_resource.userName.strip(),
is_active=user_resource.active,
personal_name=personal_name,
)
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.sync_user_external_id(
user.id,
new_external_id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
fields=fields,
)
)
@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
``PATCH {"active": false}`` rather than DELETE.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current_fields = _mapping_to_fields(mapping)
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
fields=current_fields,
)
try:
patched, ent_data = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
# Track the scim_username — if userName was patched, update it
new_scim_username = patched.userName.strip() if patched.userName else None
# If displayName was explicitly patched (different from the original), use
# it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None
if patched.displayName and patched.displayName != current.displayName:
personal_name = patched.displayName
else:
personal_name = _scim_name_to_str(patched.name)
dal.update_user(
user,
email=(
patched.userName.strip()
if patched.userName.strip().lower() != user.email.lower()
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=personal_name,
)
# Build updated fields by merging PATCH enterprise data with current values
cf = current_fields or ScimMappingFields()
fields = ScimMappingFields(
department=ent_data.get("department", cf.department),
manager=ent_data.get("manager", cf.manager),
given_name=patched.name.givenName if patched.name else cf.given_name,
family_name=patched.name.familyName if patched.name else cf.family_name,
scim_emails_json=(
serialize_emails(patched.emails)
if patched.emails is not None
else cf.scim_emails_json
),
)
dal.sync_user_external_id(
user.id,
patched.externalId,
scim_username=new_scim_username,
fields=fields,
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
fields=fields,
)
)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
A second DELETE returns 404 per RFC 7644 §3.6.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, ScimJSONResponse):
return result
user = result
# If no SCIM mapping exists, the user was already deleted from
# SCIM's perspective — return 404 per RFC 7644 §3.6.
mapping = dal.get_user_mapping_by_user_id(user.id)
if not mapping:
return _scim_error_response(404, f"User {user_id} not found")
dal.deactivate_user(user)
dal.delete_user_mapping(mapping.id)
dal.commit()
return Response(status_code=204)
# ---------------------------------------------------------------------------
# Group helpers
# ---------------------------------------------------------------------------
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
except ValueError:
return _scim_error_response(404, f"Group {group_id} not found")
group = dal.get_group(gid)
if not group:
return _scim_error_response(404, f"Group {group_id} not found")
return group
def _parse_member_uuids(
members: list[ScimGroupMember],
) -> tuple[list[UUID], str | None]:
"""Parse member value strings to UUIDs.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids: list[UUID] = []
for m in members:
try:
uuids.append(UUID(m.value))
except ValueError:
return [], f"Invalid member ID: {m.value}"
return uuids, None
def _validate_and_parse_members(
members: list[ScimGroupMember], dal: ScimDAL
) -> tuple[list[UUID], str | None]:
"""Parse and validate member UUIDs exist in the database.
Returns (uuid_list, error_message). error_message is None on success.
"""
uuids, err = _parse_member_uuids(members)
if err:
return [], err
if uuids:
missing = dal.validate_member_ids(uuids)
if missing:
return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}"
return uuids, None
# ---------------------------------------------------------------------------
# Group CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
except ValueError as e:
return _scim_error_response(400, str(e))
try:
groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
resource = provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
db_group = UserGroup(
name=group_resource.displayName,
is_up_to_date=True,
time_last_modified_by_user=func.now(),
)
try:
dal.add_group(db_group)
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
)
dal.upsert_group_members(db_group.id, member_uuids)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
dal.commit()
members = dal.get_group_members(db_group.id)
return _scim_resource_response(
provider.build_group_resource(db_group, members, external_id),
status_code=201,
)
@scim_router.put("/Groups/{group_id}", response_model=None)
def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
return result
group = result
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _scim_resource_response(
provider.build_group_resource(group, members, group_resource.externalId)
)
@scim_router.patch("/Groups/{group_id}", response_model=None)
def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | ScimJSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = provider.build_group_resource(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
dal.update_group(group, name=new_name)
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
missing = dal.validate_member_ids(add_uuids)
if missing:
return _scim_error_response(
400,
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
members = dal.get_group_members(group.id)
return _scim_resource_response(
provider.build_group_resource(group, members, patched.externalId)
)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | ScimJSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, ScimJSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
dal.commit()
return Response(status_code=204)
def _is_valid_uuid(value: str) -> bool:
"""Check if a string is a valid UUID."""
try:
UUID(value)
return True
except ValueError:
return False

View File

@@ -0,0 +1,104 @@
"""SCIM bearer token authentication.
SCIM endpoints are authenticated via bearer tokens that admins create in the
Onyx UI. This module provides:
- ``verify_scim_token``: FastAPI dependency that extracts, hashes, and
validates the token from the Authorization header.
- ``generate_scim_token``: Creates a new cryptographically random token
and returns the raw value, its SHA-256 hash, and a display suffix.
Token format: ``onyx_scim_<random>`` where ``<random>`` is 48 bytes of
URL-safe base64 from ``secrets.token_urlsafe``.
The hash is stored in the ``scim_token`` table; the raw value is shown to
the admin exactly once at creation time.
"""
import hashlib
import secrets
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48
def _hash_scim_token(token: str) -> str:
"""SHA-256 hash a SCIM token. No salt needed — tokens are random."""
return hashlib.sha256(token.encode("utf-8")).hexdigest()
def generate_scim_token() -> tuple[str, str, str]:
"""Generate a new SCIM bearer token.
Returns:
A tuple of ``(raw_token, hashed_token, token_display)`` where
``token_display`` is a masked version showing only the last 4 chars.
"""
raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH)
hashed_token = _hash_scim_token(raw_token)
token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:]
return raw_token, hashed_token, token_display
def _get_hashed_scim_token_from_request(request: Request) -> str | None:
"""Extract and hash a SCIM token from the request Authorization header."""
return get_hashed_bearer_token_from_request(
request,
valid_prefixes=[SCIM_TOKEN_PREFIX],
hash_fn=_hash_scim_token,
)
def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
return ScimDAL(db_session)
def verify_scim_token(
request: Request,
dal: ScimDAL = Depends(_get_scim_dal),
) -> ScimToken:
"""FastAPI dependency that authenticates SCIM requests.
Extracts the bearer token from the Authorization header, hashes it,
looks it up in the database, and verifies it is active.
Note:
This dependency does NOT update ``last_used_at`` — the endpoint
should do that via ``ScimDAL.update_token_last_used()`` so the
timestamp write is part of the endpoint's transaction.
Raises:
HTTPException(401): If the token is missing, invalid, or inactive.
"""
hashed = _get_hashed_scim_token_from_request(request)
if not hashed:
raise HTTPException(
status_code=401,
detail="Missing or invalid SCIM bearer token",
)
token = dal.get_token_by_hash(hashed)
if not token:
raise HTTPException(
status_code=401,
detail="Invalid SCIM bearer token",
)
if not token.is_active:
raise HTTPException(
status_code=401,
detail="SCIM token has been revoked",
)
return token

View File

@@ -0,0 +1,96 @@
"""SCIM filter expression parser (RFC 7644 §3.4.2.2).
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up
resources before deciding whether to create or update them. For example, when
an admin assigns a user to the Onyx app, the IdP first checks whether that
user already exists::
GET /scim/v2/Users?filter=userName eq "john@example.com"
If zero results come back the IdP creates the user (``POST``); if a match is
found it links to the existing record and uses ``PUT``/``PATCH`` going forward.
The same pattern applies to groups (``displayName eq "Engineering"``).
This module parses the subset of the SCIM filter grammar that identity
providers actually send in practice:
attribute SP operator SP value
Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with).
Compound filters (``and`` / ``or``) are not supported; if an IdP sends one
the parser returns ``None`` and the caller falls back to an unfiltered list.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from enum import Enum
class ScimFilterOperator(str, Enum):
"""Supported SCIM filter operators."""
EQUAL = "eq"
CONTAINS = "co"
STARTS_WITH = "sw"
@dataclass(frozen=True, slots=True)
class ScimFilter:
"""Parsed SCIM filter expression."""
attribute: str
operator: ScimFilterOperator
value: str
# Matches: attribute operator "value" (with or without quotes around value)
# Groups: (attribute) (operator) ("quoted value" | unquoted_value)
_FILTER_RE = re.compile(
r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator
r'(?:"([^"]*)"' # quoted value
r"|'([^']*)')" # or single-quoted value
r"$",
re.IGNORECASE,
)
def parse_scim_filter(filter_string: str | None) -> ScimFilter | None:
"""Parse a simple SCIM filter expression.
Args:
filter_string: Raw filter query parameter value, e.g.
``'userName eq "john@example.com"'``
Returns:
A ``ScimFilter`` if the expression is valid and uses a supported
operator, or ``None`` if the input is empty / missing.
Raises:
ValueError: If the filter string is present but malformed or uses
an unsupported operator.
"""
if not filter_string or not filter_string.strip():
return None
match = _FILTER_RE.match(filter_string.strip())
if not match:
raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}")
return _build_filter(match, filter_string)
def _build_filter(match: re.Match[str], raw: str) -> ScimFilter:
"""Extract fields from a regex match and construct a ScimFilter."""
attribute = match.group(1)
op_str = match.group(2).lower()
# Value is in group 3 (double-quoted) or group 4 (single-quoted)
value = match.group(3) if match.group(3) is not None else match.group(4)
if value is None:
raise ValueError(f"Unsupported or malformed SCIM filter: {raw}")
operator = ScimFilterOperator(op_str)
return ScimFilter(attribute=attribute, operator=operator, value=value)

View File

@@ -0,0 +1,376 @@
"""Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644).
SCIM protocol schemas follow the wire format defined in:
- Core Schema: https://datatracker.ietf.org/doc/html/rfc7643
- Protocol: https://datatracker.ietf.org/doc/html/rfc7644
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
# ---------------------------------------------------------------------------
# SCIM Schema URIs (RFC 7643 §8)
# Every SCIM JSON payload includes a "schemas" array identifying its type.
# IdPs like Okta/Azure AD use these URIs to determine how to parse responses.
# ---------------------------------------------------------------------------
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse"
SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp"
SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error"
SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_ENTERPRISE_USER_SCHEMA = (
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
# ---------------------------------------------------------------------------
# SCIM Protocol Schemas
# ---------------------------------------------------------------------------
class ScimName(BaseModel):
"""User name components (RFC 7643 §4.1.1)."""
givenName: str | None = None
familyName: str | None = None
formatted: str | None = None
class ScimEmail(BaseModel):
"""Email sub-attribute (RFC 7643 §4.1.2)."""
value: str
type: str | None = None
primary: bool = False
class ScimMeta(BaseModel):
"""Resource metadata (RFC 7643 §3.1)."""
resourceType: str | None = None
created: datetime | None = None
lastModified: datetime | None = None
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimManagerRef(BaseModel):
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
value: str | None = None
class ScimEnterpriseExtension(BaseModel):
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
department: str | None = None
manager: ScimManagerRef | None = None
@dataclass
class ScimMappingFields:
"""Stored SCIM mapping fields that need to round-trip through the IdP.
Entra ID sends structured name components, email metadata, and enterprise
extension attributes that must be returned verbatim in subsequent GET
responses. These fields are persisted on ScimUserMapping and threaded
through the DAL, provider, and endpoint layers.
"""
department: str | None = None
manager: str | None = None
given_name: str | None = None
family_name: str | None = None
scim_emails_json: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
This is the JSON shape that IdPs send when creating/updating a user via
SCIM, and the shape we return in GET responses. Field names use camelCase
to match the SCIM wire format (not Python convention).
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
enterprise_extension: ScimEnterpriseExtension | None = Field(
default=None,
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
)
class ScimGroupMember(BaseModel):
"""Group member reference (RFC 7643 §4.2).
Represents a user within a SCIM group. The IdP sends these when adding
or removing users from groups. ``value`` is the Onyx user ID.
"""
value: str # User ID of the group member
display: str | None = None
class ScimGroupResource(BaseModel):
"""SCIM Group resource representation (RFC 7643 §4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA])
id: str | None = None
externalId: str | None = None
displayName: str
members: list[ScimGroupMember] = Field(default_factory=list)
meta: ScimMeta | None = None
class ScimListResponse(BaseModel):
"""Paginated list response (RFC 7644 §3.4.2)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA])
totalResults: int
startIndex: int = 1
itemsPerPage: int = 100
Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list)
class ScimPatchOperationType(str, Enum):
"""Supported PATCH operations (RFC 7644 §3.5.2)."""
ADD = "add"
REPLACE = "replace"
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: ScimPatchValue = None
@field_validator("op", mode="before")
@classmethod
def normalize_operation(cls, v: object) -> object:
"""Normalize op to lowercase for case-insensitive matching.
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
instead of ``"replace"``. This is safe for all providers since the
enum values are lowercase. If a future provider requires other
pre-processing quirks, move patch deserialization into the provider
subclass instead of adding more special cases here.
"""
return v.lower() if isinstance(v, str) else v
class ScimPatchRequest(BaseModel):
"""PATCH request body (RFC 7644 §3.5.2).
IdPs use PATCH to make incremental changes — e.g. deactivating a user
(replace active=false) or adding/removing group members — instead of
replacing the entire resource with PUT.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA])
Operations: list[ScimPatchOperation]
class ScimError(BaseModel):
"""SCIM error response (RFC 7644 §3.12)."""
schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA])
status: str
detail: str | None = None
scimType: str | None = None
# ---------------------------------------------------------------------------
# Service Provider Configuration (RFC 7643 §5)
# ---------------------------------------------------------------------------
class ScimSupported(BaseModel):
"""Generic supported/not-supported flag used in ServiceProviderConfig."""
supported: bool
class ScimFilterConfig(BaseModel):
"""Filter configuration within ServiceProviderConfig (RFC 7643 §5)."""
supported: bool
maxResults: int = 100
class ScimServiceProviderConfig(BaseModel):
"""SCIM ServiceProviderConfig resource (RFC 7643 §5).
Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during
initial setup to discover which SCIM features our server supports
(e.g. PATCH yes, bulk no, filtering yes).
"""
schemas: list[str] = Field(
default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA]
)
patch: ScimSupported = ScimSupported(supported=True)
bulk: ScimSupported = ScimSupported(supported=False)
filter: ScimFilterConfig = ScimFilterConfig(supported=True)
changePassword: ScimSupported = ScimSupported(supported=False)
sort: ScimSupported = ScimSupported(supported=False)
etag: ScimSupported = ScimSupported(supported=False)
authenticationSchemes: list[dict[str, str]] = Field(
default_factory=lambda: [
{
"type": "oauthbearertoken",
"name": "OAuth Bearer Token",
"description": "Authentication scheme using a SCIM bearer token",
}
]
)
class ScimSchemaAttribute(BaseModel):
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
name: str
type: str
multiValued: bool = False
required: bool = False
description: str = ""
caseExact: bool = False
mutability: str = "readWrite"
returned: str = "default"
uniqueness: str = "none"
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
class ScimSchemaDefinition(BaseModel):
"""SCIM Schema definition (RFC 7643 §7).
Served at GET /scim/v2/Schemas. Describes the attributes available
on each resource type so IdPs know which fields they can provision.
"""
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
id: str
name: str
description: str
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
class ScimSchemaExtension(BaseModel):
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
schema_: str = Field(alias="schema")
required: bool
class ScimResourceType(BaseModel):
"""SCIM ResourceType resource (RFC 7643 §6).
Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource
types are available (Users, Groups) and their respective endpoints.
"""
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
id: str
name: str
endpoint: str
description: str | None = None
schema_: str = Field(alias="schema")
schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Admin API Schemas (Onyx-internal, for SCIM token management)
# These are NOT part of the SCIM protocol. They power the Onyx admin UI
# where admins create/revoke the bearer tokens that IdPs use to authenticate.
# ---------------------------------------------------------------------------
class ScimTokenCreate(BaseModel):
"""Request to create a new SCIM bearer token."""
name: str
class ScimTokenResponse(BaseModel):
"""SCIM token metadata returned in list/get responses."""
id: int
name: str
token_display: str
is_active: bool
created_at: datetime
last_used_at: datetime | None = None
class ScimTokenCreatedResponse(ScimTokenResponse):
"""Response returned when a new SCIM token is created.
Includes the raw token value which is only available at creation time.
"""
raw_token: str

View File

@@ -0,0 +1,461 @@
"""SCIM PATCH operation handler (RFC 7644 §3.5.2).
Identity providers use PATCH to make incremental changes to SCIM resources
instead of replacing the entire resource with PUT. Common operations include:
- Deactivating a user: ``replace`` ``active`` with ``false``
- Adding group members: ``add`` to ``members``
- Removing group members: ``remove`` from ``members[value eq "..."]``
This module applies PATCH operations to Pydantic SCIM resource objects and
returns the modified result. It does NOT touch the database — the caller is
responsible for persisting changes.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
logger = logging.getLogger(__name__)
# Lowercased enterprise extension URN for case-insensitive matching
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
# Pattern for email filter paths, e.g.:
# emails[primary eq true].value (Okta)
# emails[type eq "work"].value (Azure AD / Entra ID)
_EMAIL_FILTER_RE = re.compile(
r"^emails\[.+\]\.value$",
re.IGNORECASE,
)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"active": ("active", "data"),
"username": ("userName", "data"),
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
}
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
"displayname": ("displayName", "data"),
}
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"displayname": ("displayName", "data"),
"externalid": ("externalId", "data"),
}
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
def __init__(self, detail: str, status: int = 400) -> None:
self.detail = detail
self.status = status
super().__init__(detail)
@dataclass
class _UserPatchCtx:
"""Bundles the mutable state for user PATCH operations."""
data: dict[str, Any]
name_data: dict[str, Any]
ent_data: dict[str, str | None] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimUserResource, dict[str, str | None]]:
"""Apply SCIM PATCH operations to a user resource.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified user resource, enterprise extension data dict).
The enterprise dict has keys ``"department"`` and ``"manager"``
with values set only when a PATCH operation touched them.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
for op in operations:
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
_apply_user_replace(op, ctx, ignored_paths)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_user_remove(op, ctx, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
ctx.data["name"] = ctx.name_data
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
def _apply_user_replace(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a resource dict of top-level attributes to set.
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, ctx, ignored_paths)
def _apply_user_remove(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a remove operation to user data — clears the target field."""
path = (op.path or "").lower()
if not path:
raise ScimPatchError("Remove operation requires a path")
if path in ignored_paths:
return
entry = _USER_REMOVE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = None
return
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
def _set_user_field(
path: str,
value: ScimPatchValue,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
*,
strict: bool = True,
) -> None:
"""Set a single field on user data by SCIM path.
Args:
strict: When ``False`` (path-less replace), unknown attributes are
silently skipped. When ``True`` (explicit path), they raise.
"""
if path in ignored_paths:
return
# Simple field writes handled by the dispatch table
entry = _USER_REPLACE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = value
return
# displayName sets both the top-level field and the name.formatted sub-field
if path == "displayname":
ctx.data["displayName"] = value
ctx.name_data["formatted"] = value
elif path == "name":
if isinstance(value, dict):
for k, v in value.items():
ctx.name_data[k] = v
elif path == "emails":
if isinstance(value, list):
ctx.data["emails"] = value
elif _EMAIL_FILTER_RE.match(path):
_update_primary_email(ctx.data, value)
elif path.startswith(_ENTERPRISE_URN_LOWER):
_set_enterprise_field(path, value, ctx.ent_data)
elif not strict:
return
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
"""Update the primary email entry via an email filter path."""
emails: list[dict] = data.get("emails") or []
for email_entry in emails:
if email_entry.get("primary"):
email_entry["value"] = value
break
else:
emails.append({"value": value, "type": "work", "primary": True})
data["emails"] = emails
def _to_dict(value: ScimPatchValue) -> dict | None:
"""Coerce a SCIM patch value to a plain dict if possible.
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
``extra="allow"``), so we also dump those back to a dict.
"""
if isinstance(value, dict):
return value
if isinstance(value, ScimPatchResourceValue):
return value.model_dump(exclude_unset=True)
return None
def _set_enterprise_field(
path: str,
value: ScimPatchValue,
ent_data: dict[str, str | None],
) -> None:
"""Handle enterprise extension URN paths or value dicts."""
# Full URN as key with dict value (path-less PATCH)
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
if path == _ENTERPRISE_URN_LOWER:
d = _to_dict(value)
if d is not None:
if "department" in d:
ent_data["department"] = d["department"]
if "manager" in d:
mgr = d["manager"]
if isinstance(mgr, dict):
ent_data["manager"] = mgr.get("value")
return
# Dotted URN path, e.g. "urn:...:user:department"
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
if suffix == "department":
ent_data["department"] = str(value) if value is not None else None
elif suffix == "manager":
d = _to_dict(value)
if d is not None:
ent_data["manager"] = d.get("value")
elif isinstance(value, str):
ent_data["manager"] = value
else:
# Unknown enterprise attributes are silently ignored rather than
# rejected — IdPs may send attributes we don't model yet.
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
current_members: list[dict] = list(data.get("members") or [])
added_ids: list[str] = []
removed_ids: list[str] = []
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_group_remove(op, current_members, removed_ids)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on Group resource"
)
data["members"] = current_members
group = ScimGroupResource.model_validate(data)
return group, added_ids, removed_ids
def _apply_group_replace(
op: ScimPatchOperation,
data: dict,
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
return
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
def _replace_members(
value: list[dict],
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
removed_ids.extend(old_ids - new_ids)
added_ids.extend(new_ids - old_ids)
current_members[:] = value
def _set_group_field(
path: str,
value: ScimPatchValue,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path in ignored_paths:
return
entry = _GROUP_REPLACE_PATHS.get(path)
if entry:
key, _ = entry
data[key] = value
return
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(
op: ScimPatchOperation,
members: list[dict],
added_ids: list[str],
) -> None:
"""Add members to a group."""
path = (op.path or "").lower()
if path and path != "members":
raise ScimPatchError(f"Unsupported add path '{op.path}' for Group")
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in member_dicts:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)
added_ids.append(member_id)
existing_ids.add(member_id)
def _apply_group_remove(
op: ScimPatchOperation,
members: list[dict],
removed_ids: list[str],
) -> None:
"""Remove members from a group."""
if not op.path:
raise ScimPatchError("Remove operation requires a path")
match = _MEMBER_FILTER_RE.match(op.path)
if not match:
raise ScimPatchError(
f"Unsupported remove path '{op.path}'. "
'Expected: members[value eq "user-id"]'
)
target_id = match.group(1)
original_len = len(members)
members[:] = [m for m in members if m.get("value") != target_id]
if len(members) < original_len:
removed_ids.append(target_id)

View File

@@ -0,0 +1,210 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
import json
import logging
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from pydantic import ValidationError
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
logger = logging.getLogger(__name__)
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
{
"id",
"schemas",
"meta",
}
)
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
@property
def user_schemas(self) -> list[str]:
"""Schema URIs to include in User resource responses.
Override in subclasses to advertise additional schemas (e.g. the
enterprise extension for Entra ID).
"""
return [SCIM_USER_SCHEMA]
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
fields: Stored mapping fields that the IdP expects round-tripped.
"""
f = fields or ScimMappingFields()
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
username = scim_username or user.email
# Build enterprise extension when at least one value is present.
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
enterprise_ext: ScimEnterpriseExtension | None = None
schemas = list(self.user_schemas)
if f.department is not None or f.manager is not None:
manager_ref = (
ScimManagerRef(value=f.manager) if f.manager is not None else None
)
enterprise_ext = ScimEnterpriseExtension(
department=f.department,
manager=manager_ref,
)
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
name = self.build_scim_name(user, f)
emails = _deserialize_emails(f.scim_emails_json, username)
resource = ScimUserResource(
schemas=schemas,
id=str(user.id),
externalId=external_id,
userName=username,
name=name,
displayName=user.personal_name,
emails=emails,
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
resource.enterprise_extension = enterprise_ext
return resource
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def build_scim_name(
self,
user: User,
fields: ScimMappingFields,
) -> ScimName | None:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name,
familyName=fields.family_name,
formatted=user.personal_name,
)
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
"""Deserialize stored email entries or build a default work email."""
if stored_json:
try:
entries = json.loads(stored_json)
if isinstance(entries, list) and entries:
return [ScimEmail(**e) for e in entries]
except (json.JSONDecodeError, TypeError, ValidationError):
logger.warning(
"Corrupt scim_emails_json, falling back to default: %s", stored_json
)
return [ScimEmail(value=username, type="work", primary=True)]
def serialize_emails(emails: list[ScimEmail]) -> str | None:
"""Serialize SCIM email entries to JSON for storage."""
if not emails:
return None
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -0,0 +1,36 @@
"""Entra ID (Azure AD) SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
class EntraProvider(ScimProvider):
"""Entra ID (Azure AD) SCIM provider.
Entra behavioral notes:
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
— handled by ``ScimPatchOperation.normalize_op`` validator.
- Sends the enterprise extension URN as a key in path-less PATCH value
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
store department/manager values.
- Expects the enterprise extension schema in ``schemas`` arrays and
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
"""
@property
def name(self) -> str:
return "entra"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return _ENTRA_IGNORED_PATCH_PATHS
@property
def user_schemas(self) -> list[str]:
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]

View File

@@ -0,0 +1,26 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return COMMON_IGNORED_PATCH_PATHS

View File

@@ -0,0 +1,173 @@
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaAttribute
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "User",
"name": "User",
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
"schemaExtensions": [
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
],
}
)
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
{
"id": "Group",
"name": "Group",
"endpoint": "/scim/v2/Groups",
"description": "SCIM Group resource",
"schema": SCIM_GROUP_SCHEMA,
}
)
USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_USER_SCHEMA,
name="User",
description="SCIM core User schema",
attributes=[
ScimSchemaAttribute(
name="userName",
type="string",
required=True,
uniqueness="server",
description="Unique identifier for the user, typically an email address.",
),
ScimSchemaAttribute(
name="name",
type="complex",
description="The components of the user's name.",
subAttributes=[
ScimSchemaAttribute(
name="givenName",
type="string",
description="The user's first name.",
),
ScimSchemaAttribute(
name="familyName",
type="string",
description="The user's last name.",
),
ScimSchemaAttribute(
name="formatted",
type="string",
description="The full name, including all middle names and titles.",
),
],
),
ScimSchemaAttribute(
name="emails",
type="complex",
multiValued=True,
description="Email addresses for the user.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Email address value.",
),
ScimSchemaAttribute(
name="type",
type="string",
description="Label for this email (e.g. 'work').",
),
ScimSchemaAttribute(
name="primary",
type="boolean",
description="Whether this is the primary email.",
),
],
),
ScimSchemaAttribute(
name="active",
type="boolean",
description="Whether the user account is active.",
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_ENTERPRISE_USER_SCHEMA,
name="EnterpriseUser",
description="Enterprise User extension (RFC 7643 §4.3)",
attributes=[
ScimSchemaAttribute(
name="department",
type="string",
description="Department.",
),
ScimSchemaAttribute(
name="manager",
type="complex",
description="The user's manager.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Manager user ID.",
),
],
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",
description="SCIM core Group schema",
attributes=[
ScimSchemaAttribute(
name="displayName",
type="string",
required=True,
description="Human-readable name for the group.",
),
ScimSchemaAttribute(
name="members",
type="complex",
multiValued=True,
description="Members of the group.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="User ID of the group member.",
),
ScimSchemaAttribute(
name="display",
type="string",
mutability="readOnly",
description="Display name of the group member.",
),
],
),
ScimSchemaAttribute(
name="externalId",
type="string",
description="Identifier from the provisioning client (IdP).",
caseExact=True,
),
],
)

View File

@@ -1,10 +1,13 @@
"""EE Settings API - provides license-aware settings override."""
from redis.exceptions import RedisError
from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
@@ -41,6 +44,14 @@ def check_ee_features_enabled() -> bool:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss — warm from DB so cold-start doesn't block EE features
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(f"Failed to load license from DB: {db_error}")
if metadata and metadata.status != _BLOCKING_STATUS:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
return True
@@ -82,6 +93,18 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
tenant_id = get_current_tenant_id()
try:
metadata = get_cached_license_metadata(tenant_id)
if not metadata:
# Cache miss (e.g. after TTL expiry). Fall back to DB so
# the /settings request doesn't falsely return GATED_ACCESS
# while the cache is cold.
try:
with get_session_with_current_tenant() as db_session:
metadata = refresh_license_cache(db_session, tenant_id)
except SQLAlchemyError as db_error:
logger.warning(
f"Failed to load license from DB for settings: {db_error}"
)
if metadata:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
@@ -90,7 +113,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True
else:
# No license found.
# No license found in cache or DB.
if ENTERPRISE_EDITION_ENABLED:
# Legacy EE flag is set → prior EE usage (e.g. permission
# syncing) means indexed data may need protection.

View File

@@ -37,12 +37,15 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -58,16 +58,27 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -115,15 +126,26 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
"redirect_uri": redirect_uri,
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"redirect_uri": redirect_uri,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -141,8 +163,13 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": oauth_config.client_id,
"client_id": OAuthTokenManager._unwrap_sensitive_str(
oauth_config.client_id
),
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
@@ -161,6 +188,12 @@ class OAuthTokenManager:
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False)
return value
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],

View File

@@ -1,7 +1,9 @@
import uuid
from enum import Enum
from typing import Any
from fastapi_users import schemas
from typing_extensions import override
class UserRole(str, Enum):
@@ -41,8 +43,21 @@ class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
captcha_token: str | None = None
@override
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
return d
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole

View File

@@ -121,6 +121,7 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -137,6 +138,8 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -208,22 +211,34 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
whitelist = get_invited_users()
if not whitelist:
if not workspace_invite_only_enabled():
return
whitelist = get_invited_users()
if not email:
raise PermissionError("Email must be specified")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
for email_whitelist in whitelist:
try:
@@ -240,7 +255,13 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise PermissionError("User not on allowed user whitelist")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -256,13 +277,32 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
domain = email.split("@")[-1].lower()
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
)
# Check domain whitelist if configured
@@ -1650,7 +1690,10 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
next_url = request.query_params.get("next", "/")

View File

@@ -1,25 +1,30 @@
from collections.abc import Generator
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
import httpx
from pydantic import BaseModel
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.connector_runner import CheckpointOutputWrapper
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
@@ -29,63 +34,129 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def document_batch_to_ids(
doc_batch: (
Iterator[list[Document | HierarchyNode]]
| Iterator[list[SlimDocument | HierarchyNode]]
),
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {
doc.raw_node_id if isinstance(doc, HierarchyNode) else doc.id
for doc in doc_list
}
class SlimConnectorExtractionResult(BaseModel):
"""Result of extracting document IDs and hierarchy nodes from a connector."""
doc_ids: set[str]
hierarchy_nodes: list[HierarchyNode]
def _checkpointed_batched_items(
connector: CheckpointedConnector[CT],
start: float,
end: float,
) -> Generator[list[Document | HierarchyNode | ConnectorFailure], None, None]:
"""Loop through all checkpoint steps and yield batched items.
Some checkpointed connectors (e.g. IMAP) are multi-step: the first
checkpoint call may only initialize internal state without yielding
any documents. This function loops until checkpoint.has_more is False
to ensure all items are collected across every step.
"""
checkpoint = connector.build_dummy_checkpoint()
while True:
checkpoint_output = connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper()
batch: list[Document | HierarchyNode | ConnectorFailure] = []
for document, hierarchy_node, failure, next_checkpoint in wrapper(
checkpoint_output
):
if document is not None:
batch.append(document)
elif hierarchy_node is not None:
batch.append(hierarchy_node)
elif failure is not None:
batch.append(failure)
if next_checkpoint is not None:
checkpoint = next_checkpoint
if batch:
yield batch
if not checkpoint.has_more:
break
def _get_failure_id(failure: ConnectorFailure) -> str | None:
"""Extract the document/entity ID from a ConnectorFailure."""
if failure.failed_document:
return failure.failed_document.document_id
if failure.failed_entity:
return failure.failed_entity.entity_id
return None
def _extract_from_batch(
doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
) -> tuple[set[str], list[HierarchyNode]]:
"""Separate a batch into document IDs and hierarchy nodes.
ConnectorFailure items have their failed document/entity IDs added to the
ID set so that failed-to-retrieve documents are not accidentally pruned.
"""
ids: set[str] = set()
hierarchy_nodes: list[HierarchyNode] = []
for item in doc_list:
if isinstance(item, HierarchyNode):
hierarchy_nodes.append(item)
ids.add(item.raw_node_id)
elif isinstance(item, ConnectorFailure):
failed_id = _get_failure_id(item)
if failed_id:
ids.add(failed_id)
logger.warning(
f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
)
else:
ids.add(item.id)
return ids, hierarchy_nodes
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
) -> SlimConnectorExtractionResult:
"""
If the given connector is neither a SlimConnector nor a SlimConnectorWithPermSync, just pull
all docs using the load_from_state and grab out the IDs.
Extract document IDs and hierarchy nodes from a runnable connector.
Hierarchy nodes yielded alongside documents/slim docs are collected and
returned in the result. ConnectorFailure items have their IDs preserved
so that failed-to-retrieve documents are not accidentally pruned.
Optionally, a callback can be passed to handle the length of each document batch.
"""
all_connector_doc_ids: set[str] = set()
all_hierarchy_nodes: list[HierarchyNode] = []
# Sequence (covariant) lets all the specific list[...] iterator types unify here
raw_batch_generator: (
Iterator[Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure]]
| None
) = None
doc_batch_id_generator = None
if isinstance(runnable_connector, SlimConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs()
)
raw_batch_generator = runnable_connector.retrieve_all_slim_docs()
elif isinstance(runnable_connector, SlimConnectorWithPermSync):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.retrieve_all_slim_docs_perm_sync()
)
raw_batch_generator = runnable_connector.retrieve_all_slim_docs_perm_sync()
# If the connector isn't slim, fall back to running it normally to get ids
elif isinstance(runnable_connector, LoadConnector):
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
raw_batch_generator = runnable_connector.load_from_state()
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
raw_batch_generator = runnable_connector.poll_source(start=start, end=end)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
raw_batch_generator = _checkpointed_batched_items(
runnable_connector, start, end
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
@@ -99,19 +170,24 @@ def extract_ids_from_runnable_connector(
else lambda x: x
)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
# process raw batches to extract both IDs and hierarchy nodes
for doc_list in raw_batch_generator:
if callback and callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
batch_ids, batch_nodes = _extract_from_batch(doc_list)
all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
all_hierarchy_nodes.extend(batch_nodes)
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
return all_connector_doc_ids
return SlimConnectorExtractionResult(
doc_ids=all_connector_doc_ids,
hierarchy_nodes=all_hierarchy_nodes,
)
def celery_is_listening_to_queue(worker: Any, name: str) -> bool:

View File

@@ -37,6 +37,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
timeout_seconds: int | None = None,
):
super().__init__()
self.parent_pid = parent_pid
@@ -51,11 +52,29 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
self.start_monotonic = time.monotonic()
self.timeout_seconds = timeout_seconds
def should_stop(self) -> bool:
# Check if the associated indexing attempt has been cancelled
# TODO: Pass index_attempt_id to the callback and check cancellation using the db
return bool(self.redis_connector.stop.fenced)
if bool(self.redis_connector.stop.fenced):
return True
# Check if the task has exceeded its timeout
# NOTE: Celery's soft_time_limit does not work with thread pools,
# so we must enforce timeouts internally.
if self.timeout_seconds is not None:
elapsed = time.monotonic() - self.start_monotonic
if elapsed > self.timeout_seconds:
logger.warning(
f"IndexingCallback Docprocessing - task timeout exceeded: "
f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s "
f"cc_pair={self.redis_connector.cc_pair_id}"
)
return True
return False
def progress(self, tag: str, amount: int) -> None: # noqa: ARG002
"""Amount isn't used yet."""

View File

@@ -146,14 +146,26 @@ def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
"""Collect metrics about queue lengths for different Celery queues"""
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"docprocessing_queue_length": "docprocessing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
"celery_queue_length": OnyxCeleryQueues.PRIMARY,
"docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
"docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
"sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
"deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
"pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
"hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
"llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
"checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
"index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
"csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
"user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
"user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
"user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
"monitoring_queue_length": OnyxCeleryQueues.MONITORING,
"sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
"opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
}
for name, queue in queue_mappings.items():
@@ -881,7 +893,7 @@ def monitor_celery_queues_helper(
"""A task to monitor all celery queue lengths."""
r_celery = task.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r_celery)
n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery)
n_docfetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
)
@@ -908,6 +920,26 @@ def monitor_celery_queues_helper(
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
n_hierarchy_fetching = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery
)
n_llm_model_update = celery_get_queue_length(
OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery
)
n_checkpoint_cleanup = celery_get_queue_length(
OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery
)
n_index_attempt_cleanup = celery_get_queue_length(
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery
)
n_csv_generation = celery_get_queue_length(
OnyxCeleryQueues.CSV_GENERATION, r_celery
)
n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery)
n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery)
n_opensearch_migration = celery_get_queue_length(
OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery
)
n_docfetching_prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery
@@ -931,6 +963,14 @@ def monitor_celery_queues_helper(
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
f"hierarchy_fetching={n_hierarchy_fetching} "
f"llm_model_update={n_llm_model_update} "
f"checkpoint_cleanup={n_checkpoint_cleanup} "
f"index_attempt_cleanup={n_index_attempt_cleanup} "
f"csv_generation={n_csv_generation} "
f"monitoring={n_monitoring} "
f"sandbox={n_sandbox} "
f"opensearch_migration={n_opensearch_migration} "
)

View File

@@ -41,3 +41,14 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,6 +8,12 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -42,12 +48,19 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
# shared_task allows this task to be shared across celery app instances.
@@ -76,11 +89,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token stored in the
tracked via a continuation token map stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -133,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,
@@ -153,15 +174,28 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token,
continuation_token_map,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if continuation_token is None and total_chunks_migrated > 0:
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -170,19 +204,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token: {continuation_token}"
f"Continuation token map: {continuation_token_map}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token = (
raw_vespa_chunks, next_continuation_token_map = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token=continuation_token,
continuation_token_map=continuation_token_map,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token: {next_continuation_token}"
f"seconds. Next continuation token map: {next_continuation_token_map}"
)
opensearch_document_chunks, errored_chunks = (
@@ -212,14 +246,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token=next_continuation_token,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -37,6 +38,36 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.
@@ -247,6 +278,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -296,6 +328,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -43,9 +43,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
@@ -53,6 +55,9 @@ from onyx.db.tag import delete_orphan_tags__no_commit
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
@@ -523,12 +528,47 @@ def connector_pruning_generator_task(
redis_connector,
lock,
r,
timeout_seconds=JOB_TIMEOUT,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
# Extract docs and hierarchy nodes from the source
extraction_result = extract_ids_from_runnable_connector(
runnable_connector, callback
)
all_connector_doc_ids = extraction_result.doc_ids
# Process hierarchy nodes (same as docfetching):
# upsert to Postgres and cache in Redis
if extraction_result.hierarchy_nodes:
is_connector_public = cc_pair.access_type == AccessType.PUBLIC
redis_client = get_redis_client(tenant_id=tenant_id)
ensure_source_node_exists(
redis_client, db_session, cc_pair.connector.source
)
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=extraction_result.hierarchy_nodes,
source=cc_pair.connector.source,
commit=True,
is_connector_public=is_connector_public,
)
cache_entries = [
HierarchyNodeCacheEntry.from_db_model(node)
for node in upserted_nodes
]
cache_hierarchy_nodes_batch(
redis_client=redis_client,
source=cc_pair.connector.source,
entries=cache_entries,
)
task_logger.info(
f"Pruning: persisted and cached {len(extraction_result.hierarchy_nodes)} "
f"hierarchy nodes for cc_pair={cc_pair_id}"
)
# a list of docs in our local index
all_indexed_document_ids = {

View File

@@ -5,14 +5,18 @@ from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -21,12 +25,16 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -57,14 +65,73 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
def enqueue_user_file_project_sync_task(
*,
celery_app: Celery,
redis_client: Redis,
user_file_id: str | UUID,
tenant_id: str,
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
"""Enqueue a project-sync task if no matching queued task already exists."""
queued_key = _user_file_project_sync_queued_key(user_file_id)
# NX+EX gives us atomic dedupe and a self-healing TTL.
queued_guard_set = redis_client.set(
queued_key,
1,
nx=True,
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
if not queued_guard_set:
return False
try:
celery_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=priority,
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
except Exception:
# Roll back the queued guard if task publish fails.
redis_client.delete(queued_key)
raise
return True
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
@@ -120,7 +187,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -135,7 +219,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -148,12 +246,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -161,7 +282,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -304,6 +426,12 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
@@ -557,8 +685,8 @@ def process_single_user_file_delete(
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
"""Scan for user files needing project sync and enqueue per-file tasks."""
task_logger.info("Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -570,13 +698,25 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
queue_depth = get_user_file_project_sync_queue_depth(self.app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
task_logger.warning(
f"Queue depth {queue_depth} exceeds "
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
sa.and_(
UserFile.needs_project_sync.is_(True),
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.status == UserFileStatus.COMPLETED,
)
)
@@ -586,19 +726,23 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
if not enqueue_user_file_project_sync_task(
celery_app=self.app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
priority=OnyxCeleryPriority.HIGH,
)
):
skipped_guard += 1
continue
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"Enqueued {enqueued} "
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
)
return None
@@ -617,8 +761,10 @@ def process_single_user_file_project_sync(
)
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
@@ -630,7 +776,11 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
user_file = db_session.execute(
select(UserFile)
.where(UserFile.id == _as_uuid(user_file_id))
.options(selectinload(UserFile.assistants))
).scalar_one_or_none()
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
@@ -658,13 +808,17 @@ def process_single_user_file_project_sync(
]
project_ids = [project.id for project in user_file.projects]
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
user_fields=VespaDocumentUserFields(
user_projects=project_ids,
personas=persona_ids,
),
)
task_logger.info(
@@ -672,6 +826,7 @@ def process_single_user_file_project_sync(
)
user_file.needs_project_sync = False
user_file.needs_persona_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)

View File

@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -156,36 +158,7 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
cleaned_batch.append(sanitize_document_for_postgres(doc))
return cleaned_batch
@@ -602,10 +575,13 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch,
nodes=hierarchy_node_batch_cleaned,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -624,7 +600,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -9,10 +9,8 @@ Summaries are stored as `ChatMessage` records with two key fields:
- `parent_message_id` → last message when compression triggered (places summary in the tree)
- `last_summarized_message_id` → pointer to an older message up the chain (the cutoff). Messages after this are kept verbatim.
**Why store summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it only applies to the specific branch where compression occurred.
### Timestamp-Based Ordering
Messages are filtered by `time_sent` (not ID) so the logic remains intact if IDs are changed to UUIDs in the future.
**Why store summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it—context that doesn't exist in other branches. By creating the summary as a new message attached to the branch tip, it only applies to the specific branch where compression occurred. It's only back-pointed to by the
branch which it applies to. All of this is necessary because we keep the last few messages verbatim and also to support branching logic.
### Progressive Summarization
Subsequent compressions incorporate the existing summary text + new messages, preventing information loss in very long conversations.
@@ -26,10 +24,11 @@ Context window breakdown:
- `max_context_tokens` — LLM's total context window
- `reserved_tokens` — space for system prompt, tools, files, etc.
- Available for chat history = `max_context_tokens - reserved_tokens`
Note: If there is a lot of reserved tokens, chat compression may happen fairly frequently which is costly, slow, and leads to a bad user experience. Possible area of future improvement.
Configurable ratios:
- `COMPRESSION_TRIGGER_RATIO` (default 0.75) — compress when chat history exceeds this ratio of available space
- `RECENT_MESSAGES_RATIO` (default 0.25) — portion of chat history to keep verbatim when compressing
- `RECENT_MESSAGES_RATIO` (default 0.2) — portion of chat history to keep verbatim when compressing
## Flow

View File

@@ -3,7 +3,6 @@ import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
@@ -163,13 +162,11 @@ class ChatStateContainer:
def run_chat_loop_with_state_containers(
func: Callable[..., None],
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
*args: Any,
**kwargs: Any,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
*args: Additional positional arguments for func
**kwargs: Additional keyword arguments for func
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
arg1, arg2, kwarg1=value1
)
for packet in packets:
# Process packets
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
def run_with_exception_capture() -> None:
try:
# Ensure state_container is passed explicitly, removing it from kwargs if present
kwargs_with_state = {**kwargs, "state_container": state_container}
func(emitter, *args, **kwargs_with_state)
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(

View File

@@ -1,36 +1,29 @@
import json
import re
from collections.abc import Callable
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from onyx.auth.users import is_user_admin
from onyx.chat.models import ChatHistoryResult
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import ToolCallSimple
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.chat import create_chat_session
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_or_create_root_message
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import is_kg_config_settings_enabled_valid
from onyx.db.llm import fetch_existing_doc_sets
from onyx.db.llm import fetch_existing_tools
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import SearchDoc as DbSearchDoc
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import check_project_ownership
from onyx.file_processing.extract_file_text import extract_file_text
@@ -47,15 +40,13 @@ from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -278,70 +269,6 @@ def extract_headers(
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
# Use the first prompt from the override config for embedded prompt fields
first_prompt = persona_config.prompts[0]
persona.system_prompt = first_prompt.system_prompt
persona.task_prompt = first_prompt.task_prompt
persona.datetime_aware = first_prompt.datetime_aware
persona.tools = []
if persona_config.custom_tools_openapi:
from onyx.chat.emitter import get_default_emitter
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
tool_id=0, # dummy tool id
openapi_schema=schema,
emitter=get_default_emitter(),
),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona
def process_kg_commands(
message: str, persona_name: str, tenant_id: str, db_session: Session # noqa: ARG001
) -> None:
@@ -497,10 +424,44 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
project_image_files: list[ChatLoadedFile],
context_image_files: list[ChatLoadedFile],
additional_context: str | None,
token_counter: Callable[[str], int],
tool_id_to_name_map: dict[int, str],
@@ -580,11 +541,11 @@ def convert_chat_history(
)
# Add the user message with image files attached
# If this is the last USER message, also include project_image_files
# Note: project image file tokens are NOT counted in the token count
# If this is the last USER message, also include context_image_files
# Note: context image file tokens are NOT counted in the token count
if idx == last_user_message_idx:
if project_image_files:
image_files.extend(project_image_files)
if context_image_files:
image_files.extend(context_image_files)
if additional_context:
simple_messages.append(
@@ -657,10 +618,24 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
@@ -688,28 +663,34 @@ def convert_chat_history(
def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None:
"""Get the custom agent prompt from persona or project instructions.
"""Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt,
it does not count as a custom agent prompt (logic exists later also to drop it in this case).
Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt.
Priority: persona.system_prompt > chat_session.project.instructions > None
Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions
# NOTE: Logic elsewhere allows saving empty strings for potentially other purposes but for constructing the prompts
# we never want to return an empty string for a prompt so it's translated into an explicit None.
Args:
persona: The Persona object
chat_session: The ChatSession object
Returns:
The custom agent prompt string, or None if neither persona nor project has one
The prompt to use for the custom Agent part of the prompt.
"""
# Not considered a custom agent if it's the default behavior persona
if persona.id == DEFAULT_PERSONA_ID:
return None
# If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt.
if persona.id != DEFAULT_PERSONA_ID:
# Logic exists later also to drop it in this case but this is strictly correct anyhow.
if persona.replace_base_system_prompt:
return None
return persona.system_prompt or None
if persona.system_prompt:
return persona.system_prompt
elif chat_session.project and chat_session.project.instructions:
# If in a project and using the default Agent, respect the project instructions.
if chat_session.project and chat_session.project.instructions:
return chat_session.project.instructions
else:
return None
return None
def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool:

View File

@@ -17,20 +17,26 @@ from onyx.configs.chat_configs import COMPRESSION_TRIGGER_RATIO
from onyx.configs.constants import MessageType
from onyx.db.models import ChatMessage
from onyx.llm.interfaces import LLM
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import SystemMessage
from onyx.llm.models import UserMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_PROMPT
from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK
from onyx.prompts.compression_prompts import PROGRESSIVE_USER_REMINDER
from onyx.prompts.compression_prompts import SUMMARIZATION_CUTOFF_MARKER
from onyx.prompts.compression_prompts import SUMMARIZATION_PROMPT
from onyx.prompts.compression_prompts import USER_FINAL_REMINDER
from onyx.prompts.compression_prompts import USER_REMINDER
from onyx.tracing.framework.create import ensure_trace
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Ratio of available context to allocate for recent messages after compression
RECENT_MESSAGES_RATIO = 0.25
RECENT_MESSAGES_RATIO = 0.2
class CompressionResult(BaseModel):
@@ -187,6 +193,11 @@ def get_messages_to_summarize(
recent_messages.insert(0, msg)
tokens_used += msg_tokens
# Ensure cutoff is right before a user message by moving any leading
# non-user messages from recent_messages to older_messages
while recent_messages and recent_messages[0].message_type != MessageType.USER:
recent_messages.pop(0)
# Everything else gets summarized
recent_ids = {m.id for m in recent_messages}
older_messages = [m for m in messages if m.id not in recent_ids]
@@ -196,31 +207,47 @@ def get_messages_to_summarize(
)
def format_messages_for_summary(
def _build_llm_messages_for_summarization(
messages: list[ChatMessage],
tool_id_to_name: dict[int, str],
) -> str:
"""Format messages into a string for the summarization prompt.
) -> list[UserMessage | AssistantMessage]:
"""Convert ChatMessage objects to LLM message format for summarization.
Tool call messages are formatted compactly to save tokens.
This is intentionally different from translate_history_to_llm_format in llm_step.py:
- Compacts tool calls to "[Used tools: tool1, tool2]" to save tokens in summaries
- Skips TOOL_CALL_RESPONSE messages entirely (tool usage captured in assistant message)
- No image/multimodal handling (summaries are text-only)
- No caching or LLMConfig-specific behavior needed
"""
formatted = []
result: list[UserMessage | AssistantMessage] = []
for msg in messages:
# Format assistant messages with tool calls compactly
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
tool_names = [
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
]
formatted.append(f"[assistant used tools: {', '.join(tool_names)}]")
# Skip empty messages
if not msg.message:
continue
# Handle assistant messages with tool calls compactly
if msg.message_type == MessageType.ASSISTANT:
if msg.tool_calls:
tool_names = [
tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls
]
result.append(
AssistantMessage(content=f"[Used tools: {', '.join(tool_names)}]")
)
else:
result.append(AssistantMessage(content=msg.message))
continue
# Skip tool call response messages - tool calls are captured above via assistant messages
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
continue
role = msg.message_type.value
formatted.append(f"[{role}]: {msg.message}")
return "\n\n".join(formatted)
# Handle user messages
if msg.message_type == MessageType.USER:
result.append(UserMessage(content=msg.message))
return result
def generate_summary(
@@ -236,6 +263,9 @@ def generate_summary(
The cutoff marker tells the LLM to summarize only older messages,
while using recent messages as context to inform what's important.
Messages are sent as separate UserMessage/AssistantMessage objects rather
than being concatenated into a single message.
Args:
older_messages: Messages to compress into summary (before cutoff)
recent_messages: Messages kept verbatim (after cutoff, for context only)
@@ -246,37 +276,54 @@ def generate_summary(
Returns:
Summary text
"""
older_messages_str = format_messages_for_summary(older_messages, tool_id_to_name)
recent_messages_str = format_messages_for_summary(recent_messages, tool_id_to_name)
# Build user prompt with cutoff marker
# Build system prompt
system_content = SUMMARIZATION_PROMPT
if existing_summary:
# Progressive summarization: include existing summary
user_prompt = PROGRESSIVE_SUMMARY_PROMPT.format(
existing_summary=existing_summary
# Progressive summarization: append existing summary to system prompt
system_content += PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK.format(
previous_summary=existing_summary
)
user_prompt += f"\n\n{older_messages_str}"
final_reminder = PROGRESSIVE_USER_REMINDER
else:
# Initial summarization
user_prompt = older_messages_str
final_reminder = USER_FINAL_REMINDER
final_reminder = USER_REMINDER
# Add cutoff marker and recent messages as context
user_prompt += f"\n\n{SUMMARIZATION_CUTOFF_MARKER}"
if recent_messages_str:
user_prompt += f"\n\n{recent_messages_str}"
# Convert messages to LLM format (using compression-specific conversion)
older_llm_messages = _build_llm_messages_for_summarization(
older_messages, tool_id_to_name
)
recent_llm_messages = _build_llm_messages_for_summarization(
recent_messages, tool_id_to_name
)
# Build message list with separate messages
input_messages: list[ChatCompletionMessage] = [
SystemMessage(content=system_content),
]
# Add older messages (to be summarized)
input_messages.extend(older_llm_messages)
# Add cutoff marker as a user message
input_messages.append(UserMessage(content=SUMMARIZATION_CUTOFF_MARKER))
# Add recent messages (for context only)
input_messages.extend(recent_llm_messages)
# Add final reminder
user_prompt += f"\n\n{final_reminder}"
input_messages.append(UserMessage(content=final_reminder))
response = llm.invoke(
[
SystemMessage(content=SUMMARIZATION_PROMPT),
UserMessage(content=user_prompt),
]
)
return response.choice.message.content or ""
with llm_generation_span(
llm=llm,
flow="chat_history_summarization",
input_messages=input_messages,
) as span_generation:
response = llm.invoke(input_messages)
record_llm_response(span_generation, response)
content = response.choice.message.content
if not (content and content.strip()):
raise ValueError("LLM returned empty summary")
return content.strip()
def compress_chat_history(
@@ -292,6 +339,19 @@ def compress_chat_history(
The summary message's parent_message_id points to the last message in
chat_history, making it branch-aware via the tree structure.
Note: This takes the entire chat history as input, splits it into older
messages (to summarize) and recent messages (kept verbatim within the
token budget), generates a summary of the older part, and persists the
new summary message with its parent set to the last message in history.
Past summary is taken into context (progressive summarization): we find
at most one existing summary for this branch. If present, only messages
after that summary's last_summarized_message_id are considered; the
existing summary text is passed into the LLM so the new summary
incorporates it instead of summarizing from scratch.
For more details, see the COMPRESSION.md file.
Args:
db_session: Database session
chat_history: Branch-aware list of messages
@@ -305,74 +365,84 @@ def compress_chat_history(
if not chat_history:
return CompressionResult(summary_created=False, messages_summarized=0)
chat_session_id = chat_history[0].chat_session_id
logger.info(
f"Starting compression for session {chat_history[0].chat_session_id}, "
f"Starting compression for session {chat_session_id}, "
f"history_len={len(chat_history)}, tokens_for_recent={compression_params.tokens_for_recent}"
)
try:
# Find existing summary for this branch
existing_summary = find_summary_for_branch(db_session, chat_history)
with ensure_trace(
"chat_history_compression",
group_id=str(chat_session_id),
metadata={
"tenant_id": get_current_tenant_id(),
"chat_session_id": str(chat_session_id),
},
):
try:
# Find existing summary for this branch
existing_summary = find_summary_for_branch(db_session, chat_history)
# Get messages to summarize
summary_content = get_messages_to_summarize(
chat_history,
existing_summary,
tokens_for_recent=compression_params.tokens_for_recent,
)
# Get messages to summarize
summary_content = get_messages_to_summarize(
chat_history,
existing_summary,
tokens_for_recent=compression_params.tokens_for_recent,
)
if not summary_content.older_messages:
logger.debug("No messages to summarize, skipping compression")
return CompressionResult(summary_created=False, messages_summarized=0)
if not summary_content.older_messages:
logger.debug("No messages to summarize, skipping compression")
return CompressionResult(summary_created=False, messages_summarized=0)
# Generate summary (incorporate existing summary if present)
existing_summary_text = existing_summary.message if existing_summary else None
summary_text = generate_summary(
older_messages=summary_content.older_messages,
recent_messages=summary_content.recent_messages,
llm=llm,
tool_id_to_name=tool_id_to_name,
existing_summary=existing_summary_text,
)
# Generate summary (incorporate existing summary if present)
existing_summary_text = (
existing_summary.message if existing_summary else None
)
summary_text = generate_summary(
older_messages=summary_content.older_messages,
recent_messages=summary_content.recent_messages,
llm=llm,
tool_id_to_name=tool_id_to_name,
existing_summary=existing_summary_text,
)
# Calculate token count for the summary
tokenizer = get_tokenizer(None, None)
summary_token_count = len(tokenizer.encode(summary_text))
logger.debug(
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
)
# Calculate token count for the summary
tokenizer = get_tokenizer(None, None)
summary_token_count = len(tokenizer.encode(summary_text))
logger.debug(
f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..."
)
# Create new summary as a ChatMessage
# Parent is the last message in history - this makes the summary branch-aware
summary_message = ChatMessage(
chat_session_id=chat_history[0].chat_session_id,
message_type=MessageType.ASSISTANT,
message=summary_text,
token_count=summary_token_count,
parent_message_id=chat_history[-1].id,
last_summarized_message_id=summary_content.older_messages[-1].id,
)
db_session.add(summary_message)
db_session.commit()
# Create new summary as a ChatMessage
# Parent is the last message in history - this makes the summary branch-aware
summary_message = ChatMessage(
chat_session_id=chat_session_id,
message_type=MessageType.ASSISTANT,
message=summary_text,
token_count=summary_token_count,
parent_message_id=chat_history[-1].id,
last_summarized_message_id=summary_content.older_messages[-1].id,
)
db_session.add(summary_message)
db_session.commit()
logger.info(
f"Compressed {len(summary_content.older_messages)} messages into summary "
f"(session_id={chat_history[0].chat_session_id}, "
f"summary_tokens={summary_token_count})"
)
logger.info(
f"Compressed {len(summary_content.older_messages)} messages into summary "
f"(session_id={chat_session_id}, "
f"summary_tokens={summary_token_count})"
)
return CompressionResult(
summary_created=True,
messages_summarized=len(summary_content.older_messages),
)
return CompressionResult(
summary_created=True,
messages_summarized=len(summary_content.older_messages),
)
except Exception as e:
logger.exception(
f"Compression failed for session {chat_history[0].chat_session_id}: {e}"
)
db_session.rollback()
return CompressionResult(
summary_created=False,
messages_summarized=0,
error=str(e),
)
except Exception as e:
logger.exception(f"Compression failed for session {chat_session_id}: {e}")
db_session.rollback()
return CompressionResult(
summary_created=False,
messages_summarized=0,
error=str(e),
)

View File

@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.chat.prompt_utils import build_reminder_message
from onyx.chat.prompt_utils import build_system_prompt
@@ -30,6 +30,7 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
@@ -38,7 +39,6 @@ from onyx.llm.constants import LlmProviderNames
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
@@ -49,6 +49,7 @@ from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ChatFile
from onyx.tools.models import MemoryToolResponseSnapshot
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolCallKickoff
@@ -57,6 +58,7 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -68,6 +70,18 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
"""Detect XML-style marshaled tool calls emitted as plain text."""
if not text:
return False
lowered = text.lower()
return (
"<function_calls" in lowered
and "<invoke" in lowered
and "<parameter" in lowered
)
def _should_keep_bedrock_tool_definitions(
llm: object, simple_chat_history: list[ChatMessageSimple]
) -> bool:
@@ -122,38 +136,56 @@ def _try_fallback_tool_extraction(
reasoning_but_no_answer_or_tools = (
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
)
xml_tool_call_text_detected = no_tool_calls and (
_looks_like_xml_tool_call_payload(llm_step_result.answer)
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
)
should_try_fallback = (
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
) or reasoning_but_no_answer_or_tools
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
or reasoning_but_no_answer_or_tools
or xml_tool_call_text_detected
)
if not should_try_fallback:
return llm_step_result, False
# Try to extract from answer first, then fall back to reasoning
extracted_tool_calls: list[ToolCallKickoff] = []
if llm_step_result.answer:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if (
not extracted_tool_calls
and llm_step_result.raw_answer
and llm_step_result.raw_answer != llm_step_result.answer
):
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.raw_answer,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if not extracted_tool_calls and llm_step_result.reasoning:
extracted_tool_calls = extract_tool_calls_from_response_text(
response_text=llm_step_result.reasoning,
tool_definitions=tool_defs,
placement=Placement(turn_index=turn_index),
)
if extracted_tool_calls:
logger.info(
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
"as fallback"
)
return (
LlmStepResult(
reasoning=llm_step_result.reasoning,
answer=llm_step_result.answer,
tool_calls=extracted_tool_calls,
raw_answer=llm_step_result.raw_answer,
),
True,
)
@@ -171,17 +203,17 @@ def _try_fallback_tool_extraction(
MAX_LLM_CYCLES = 6
def _build_project_file_citation_mapping(
project_file_metadata: list[ProjectFileMetadata],
def _build_context_file_citation_mapping(
file_metadata: list[ContextFileMetadata],
starting_citation_num: int = 1,
) -> CitationMapping:
"""Build citation mapping for project files.
"""Build citation mapping for context files.
Converts project file metadata into SearchDoc objects that can be cited.
Converts context file metadata into SearchDoc objects that can be cited.
Citation numbers start from the provided starting number.
Args:
project_file_metadata: List of project file metadata
file_metadata: List of context file metadata
starting_citation_num: Starting citation number (default: 1)
Returns:
@@ -189,8 +221,7 @@ def _build_project_file_citation_mapping(
"""
citation_mapping: CitationMapping = {}
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
# Create a SearchDoc for each project file
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
search_doc = SearchDoc(
document_id=file_meta.file_id,
chunk_ind=0,
@@ -210,29 +241,28 @@ def _build_project_file_citation_mapping(
def _build_project_message(
project_files: ExtractedProjectFiles | None,
context_files: ExtractedContextFiles | None,
token_counter: Callable[[str], int] | None,
) -> list[ChatMessageSimple]:
"""Build messages for project / tool-backed files.
"""Build messages for context-injected / tool-backed files.
Returns up to two messages:
1. The full-text project files message (if project_file_texts is populated).
1. The full-text files message (if file_texts is populated).
2. A lightweight metadata message for files the LLM should access via the
FileReaderTool (e.g. oversized chat-attached files or project files that
don't fit in context).
FileReaderTool (e.g. oversized files that don't fit in context).
"""
if not project_files:
if not context_files:
return []
messages: list[ChatMessageSimple] = []
if project_files.project_file_texts:
if context_files.file_texts:
messages.append(
_create_project_files_message(project_files, token_counter=None)
_create_context_files_message(context_files, token_counter=None)
)
if project_files.file_metadata_for_tool and token_counter:
if context_files.file_metadata_for_tool and token_counter:
messages.append(
_create_file_tool_metadata_message(
project_files.file_metadata_for_tool, token_counter
context_files.file_metadata_for_tool, token_counter
)
)
return messages
@@ -243,7 +273,7 @@ def construct_message_history(
custom_agent_prompt: ChatMessageSimple | None,
simple_chat_history: list[ChatMessageSimple],
reminder_message: ChatMessageSimple | None,
project_files: ExtractedProjectFiles | None,
context_files: ExtractedContextFiles | None,
available_tokens: int,
last_n_user_messages: int | None = None,
token_counter: Callable[[str], int] | None = None,
@@ -257,7 +287,7 @@ def construct_message_history(
# Build the project / file-metadata messages up front so we can use their
# actual token counts for the budget.
project_messages = _build_project_message(project_files, token_counter)
project_messages = _build_project_message(context_files, token_counter)
project_messages_tokens = sum(m.token_count for m in project_messages)
history_token_budget = available_tokens
@@ -413,17 +443,17 @@ def construct_message_history(
)
# Attach project images to the last user message
if project_files and project_files.project_image_files:
if context_files and context_files.image_files:
existing_images = last_user_message.image_files or []
last_user_message = ChatMessageSimple(
message=last_user_message.message,
token_count=last_user_message.token_count,
message_type=last_user_message.message_type,
image_files=existing_images + project_files.project_image_files,
image_files=existing_images + context_files.image_files,
)
# Build the final message list according to README ordering:
# [system], [history_before_last_user], [custom_agent], [project_files],
# [system], [history_before_last_user], [custom_agent], [context_files],
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
result = [system_prompt] if system_prompt else []
@@ -434,14 +464,14 @@ def construct_message_history(
if custom_agent_prompt:
result.append(custom_agent_prompt)
# 3. Add project files / file-metadata messages (inserted before last user message)
# 3. Add context files / file-metadata messages (inserted before last user message)
result.extend(project_messages)
# 4. Add forgotten-files metadata (right before the user's question)
if forgotten_files_message:
result.append(forgotten_files_message)
# 5. Add last user message (with project images attached)
# 5. Add last user message (with context images attached)
result.append(last_user_message)
# 6. Add messages after last user message (tool calls, responses, etc.)
@@ -451,7 +481,42 @@ def construct_message_history(
if reminder_message:
result.append(reminder_message)
return result
return _drop_orphaned_tool_call_responses(result)
def _drop_orphaned_tool_call_responses(
messages: list[ChatMessageSimple],
) -> list[ChatMessageSimple]:
"""Drop tool response messages whose tool_call_id is not in prior assistant tool calls.
This can happen when history truncation drops an ASSISTANT tool-call message but
leaves a later TOOL_CALL_RESPONSE message in context. Some providers (e.g. Ollama)
reject such history with an "unexpected tool call id" error.
"""
known_tool_call_ids: set[str] = set()
sanitized: list[ChatMessageSimple] = []
for msg in messages:
if msg.message_type == MessageType.ASSISTANT and msg.tool_calls:
for tool_call in msg.tool_calls:
known_tool_call_ids.add(tool_call.tool_call_id)
sanitized.append(msg)
continue
if msg.message_type == MessageType.TOOL_CALL_RESPONSE:
if msg.tool_call_id and msg.tool_call_id in known_tool_call_ids:
sanitized.append(msg)
else:
logger.debug(
"Dropping orphaned tool response with tool_call_id=%s while "
"constructing message history",
msg.tool_call_id,
)
continue
sanitized.append(msg)
return sanitized
def _create_file_tool_metadata_message(
@@ -480,11 +545,11 @@ def _create_file_tool_metadata_message(
)
def _create_project_files_message(
project_files: ExtractedProjectFiles,
def _create_context_files_message(
context_files: ExtractedContextFiles,
token_counter: Callable[[str], int] | None, # noqa: ARG001
) -> ChatMessageSimple:
"""Convert project files to a ChatMessageSimple message.
"""Convert context files to a ChatMessageSimple message.
Format follows the README specification for document representation.
"""
@@ -492,7 +557,7 @@ def _create_project_files_message(
# Format as documents JSON as described in README
documents_list = []
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
for idx, file_text in enumerate(context_files.file_texts, start=1):
documents_list.append(
{
"document": idx,
@@ -503,10 +568,10 @@ def _create_project_files_message(
documents_json = json.dumps({"documents": documents_list}, indent=2)
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
# Use pre-calculated token count from project_files
# Use pre-calculated token count from context_files
return ChatMessageSimple(
message=message_content,
token_count=project_files.total_token_count,
token_count=context_files.total_token_count,
message_type=MessageType.USER,
)
@@ -517,7 +582,7 @@ def run_llm_loop(
simple_chat_history: list[ChatMessageSimple],
tools: list[Tool],
custom_agent_prompt: str | None,
project_files: ExtractedProjectFiles,
context_files: ExtractedContextFiles,
persona: Persona | None,
user_memory_context: UserMemoryContext | None,
llm: LLM,
@@ -526,6 +591,7 @@ def run_llm_loop(
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
chat_files: list[ChatFile] | None = None,
include_citations: bool = True,
all_injected_file_metadata: dict[str, FileToolMetadata] | None = None,
inject_memories_in_prompt: bool = True,
@@ -559,9 +625,9 @@ def run_llm_loop(
# Add project file citation mappings if project files are present
project_citation_mapping: CitationMapping = {}
if project_files.project_file_metadata:
project_citation_mapping = _build_project_file_citation_mapping(
project_files.project_file_metadata
if context_files.file_metadata:
project_citation_mapping = _build_context_file_citation_mapping(
context_files.file_metadata
)
citation_processor.update_citation_mapping(project_citation_mapping)
@@ -579,21 +645,28 @@ def run_llm_loop(
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
# One future workaround is to include the images as separate user messages with citation information and process those.
always_cite_documents: bool = bool(
project_files.project_as_filter or project_files.project_file_texts
context_files.use_as_search_filter or context_files.file_texts
)
should_cite_documents: bool = False
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
# Fetch this in a short-lived session so the long-running stream loop does
# not pin a connection just to keep read state alive.
with get_session_with_current_tenant() as prompt_db_session:
default_base_system_prompt: str = get_default_base_system_prompt(
prompt_db_session
)
system_prompt = None
custom_agent_prompt_msg = None
reasoning_cycles = 0
for llm_cycle_count in range(MAX_LLM_CYCLES):
# Handling tool calls based on cycle count and past cycle conditions
out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1
if forced_tool_id:
# Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary
@@ -615,6 +688,7 @@ def run_llm_loop(
tool_choice = ToolChoiceOptions.AUTO
final_tools = tools
# Handling the system prompt and custom agent prompt
# The section below calculates the available tokens for history a bit more accurately
# now that project files are loaded in.
if persona and persona.replace_base_system_prompt:
@@ -632,12 +706,14 @@ def run_llm_loop(
else:
# If it's an empty string, we assume the user does not want to include it as an empty System message
if default_base_system_prompt:
open_ai_formatting_enabled = model_needs_formatting_reenabled(
llm.config.model_name
)
prompt_memory_context = (
user_memory_context if inject_memories_in_prompt else None
user_memory_context
if inject_memories_in_prompt
else (
user_memory_context.without_memories()
if user_memory_context
else None
)
)
system_prompt_str = build_system_prompt(
base_system_prompt=default_base_system_prompt,
@@ -646,7 +722,6 @@ def run_llm_loop(
tools=tools,
should_cite_documents=should_cite_documents
or always_cite_documents,
open_ai_formatting_enabled=open_ai_formatting_enabled,
)
system_prompt = ChatMessageSimple(
message=system_prompt_str,
@@ -692,6 +767,7 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -710,7 +786,7 @@ def run_llm_loop(
custom_agent_prompt=custom_agent_prompt_msg,
simple_chat_history=simple_chat_history,
reminder_message=reminder_msg,
project_files=project_files,
context_files=context_files,
available_tokens=available_tokens,
token_counter=token_counter,
all_injected_file_metadata=all_injected_file_metadata,
@@ -804,6 +880,7 @@ def run_llm_loop(
next_citation_num=citation_processor.get_next_citation_number(),
max_concurrent_tools=None,
skip_search_query_expansion=has_called_search_tool,
chat_files=chat_files,
url_snippet_map=extract_url_snippet_map(gathered_documents or []),
inject_memories_in_prompt=inject_memories_in_prompt,
)
@@ -830,6 +907,18 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

Some files were not shown because too many files have changed in this diff Show More