Compare commits

..

167 Commits

Author SHA1 Message Date
Dane Urban
2036847a2d Lookup by name 2026-02-27 12:40:18 -08:00
Dane Urban
9ff1ac862e . 2026-02-27 12:35:38 -08:00
Dane Urban
e4179629ae Test 2026-02-27 10:29:10 -08:00
Dane Urban
4077c20def Add python tool to default persona 2026-02-27 10:14:53 -08:00
Yuhong Sun
4d256c5666 chore: remove instance of Assistant from frontend (#8848)
Co-authored-by: Nik <nikolas.garza5@gmail.com>
2026-02-27 04:22:28 +00:00
Danelegend
2e53496f46 feat: Code interpreter admin page visuals (#8729) 2026-02-27 04:01:02 +00:00
acaprau
63a206706a docs(best practices): Add comment about import-time side effects and main.py files (#8820)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-27 01:29:56 +00:00
Nikolas Garza
28427b3e5f fix(metrics): restore default HTTP request counter and histogram metrics (#8842) 2026-02-27 00:53:22 +00:00
Justin Tahara
3cafcd8a5e chore(llm): add OpenRouter nightly tests (#8818) 2026-02-26 23:54:25 +00:00
Justin Tahara
f2c50b7bb5 chore(llm): add Ollama nightly tests (#8817) 2026-02-26 23:28:40 +00:00
Jamison Lahman
6b28c6bbfc fix(fe): Search Actions popover has consistent hover states (#8826) 2026-02-26 23:16:09 +00:00
Justin Tahara
226e801665 chore(llm): add Azure nightly tests (#8816) 2026-02-26 23:05:03 +00:00
Justin Tahara
be13aa1310 chore(llm): add Vertex AI nightly tests (#8813) 2026-02-26 22:38:05 +00:00
Nikolas Garza
45d38c4906 feat(metrics): add per-tenant Prometheus metrics (#8822) 2026-02-26 22:37:35 +00:00
Danelegend
8aab518532 fix: Admin page modal centering excludes sidebar (#8823) 2026-02-26 22:27:58 +00:00
Nikolas Garza
da6ce10e86 test(scim): add integration tests for SCIM token management (#8819) 2026-02-26 22:22:16 +00:00
Nikolas Garza
aaf8253520 fix(ee): show subscription text on expired access page for cloud users (#8804) 2026-02-26 22:15:44 +00:00
Jamison Lahman
7c7f81b164 chore(fe): add feature agent to editor page (#8814) 2026-02-26 22:12:20 +00:00
Justin Tahara
2d4a3c72e9 chore(llm): Nightly Bedrock Tests (#8812) 2026-02-26 22:10:31 +00:00
acaprau
7c51712018 fix(db ssl): Remove import-time side effect of creating SSL context if IAM enabled (#8811) 2026-02-26 21:37:13 +00:00
Evan Lohn
aa5614695d feat: sharepoint tenant avoid org get (#8802) 2026-02-26 21:28:56 +00:00
Jamison Lahman
8d7255d3c4 chore(fe): support featured agents w/o being public (#8809) 2026-02-26 21:16:23 +00:00
Evan Lohn
d403498f48 feat: context injection unification (#8687) 2026-02-26 21:11:19 +00:00
dependabot[bot]
9ef3095c17 chore(deps): bump pypdf from 6.6.2 to 6.7.3 (#8808)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-26 20:42:01 +00:00
Justin Tahara
a39e93a0cb chore(llm): LLM Integration Tests Generic Setup (#8803) 2026-02-26 19:59:19 +00:00
Jamison Lahman
46d73cdfee fix(docker): prefer user runtime docker socket (#8799) 2026-02-26 10:55:44 -08:00
Raunak Bhagat
1e04ce78e0 feat(opal): add Hoverable compound component (#8798) 2026-02-26 17:08:53 +00:00
Jamison Lahman
f9b81c1725 feat(agents): share agents with labels or featured (#8742)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-26 16:21:05 +00:00
SubashMohan
3bc1b89fee fix(memory): timeline UI alignment issues and highlighting issue (#8753) 2026-02-26 08:46:43 +00:00
Nikolas Garza
01743d99d4 fix(billing): handle manual license users without Stripe subscription (#8787) 2026-02-26 08:07:14 +00:00
acaprau
092c1db7e0 chore(opensearch): Allow programatic schema updates (#8794) 2026-02-26 07:49:56 +00:00
acaprau
40ac0d859a chore(opensearch): OpenSearchClient implements context manager, also closes on del (#8781)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-26 07:38:16 +00:00
SubashMohan
929e58361f fix: resolve OAuth token manager using masked secrets (#8673) 2026-02-26 07:06:51 +00:00
SubashMohan
6d472df7c5 fix(timeline): Fix double-collapse and improve tool status messages (#8751) 2026-02-26 07:05:48 +00:00
acaprau
cfa7acd904 chore(opensearch): MT cloud should verify index on document index init, and do cluster setup once at start (#8776) 2026-02-26 06:42:06 +00:00
Danelegend
5c5a6f943b chore: deprecate llm provider fields (#8783) 2026-02-26 05:27:28 +00:00
Evan Lohn
d04128b8b1 fix: sharepoint unquote (#8786) 2026-02-26 03:38:46 +00:00
Nikolas Garza
bbebdf8f78 feat(scim): Entra ID enterprise extension support [3/3] (#8747) 2026-02-26 02:32:04 +00:00
Nikolas Garza
161279a2d5 feat(scim): field round-tripping for IdP attribute preservation [2/3] (#8746) 2026-02-26 02:01:13 +00:00
Jamison Lahman
e5ebb45a20 chore(devtools): upgrade ods: v0.6.1->v0.6.2 (#8773) 2026-02-26 01:57:25 +00:00
Evan Lohn
320ba9cb1b refactor: filter by persona id during search (#8683) 2026-02-26 01:51:00 +00:00
Nikolas Garza
f2e8cb3114 fix(slack): sanitize HTML tags and broken citation links in bot responses (#8767) 2026-02-26 01:47:44 +00:00
Nikolas Garza
43054a28ec feat(scim): SCIM 2.0 protocol compliance fixes [1/3] (#8745) 2026-02-26 01:33:08 +00:00
Justin Tahara
dc74aa7b1f chore(llm): Add OpenAI Integration Tests (#8711) 2026-02-26 00:58:28 +00:00
Raunak Bhagat
bd773191c2 feat(opal): add more icons (#8778) 2026-02-26 00:38:54 +00:00
Evan Lohn
66dbff41e6 refactor: extend sync mechanism to persona files (#8682) 2026-02-26 00:32:30 +00:00
roshan
1dcffe38bc fix: Invoke generate_agents_md.py in K8s to populate knowledge sources (#8768)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-26 00:04:10 +00:00
Evan Lohn
c35e883564 refactor: persona id in vector db by indexing (#8681) 2026-02-25 22:51:57 +00:00
Jamison Lahman
fefcd58481 chore(devtools): ods web to run web/package.json scripts (#8766)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-25 14:05:29 -08:00
Jamison Lahman
bdc89d9e3f chore(fe): opal button implements responsiveHideText (#8764) 2026-02-25 21:05:08 +00:00
Evan Lohn
f4d777b80d refactor: persona id in vector db (#8680) 2026-02-25 20:42:38 +00:00
acaprau
da4d57b5e3 chore(devtools): Make AGENTS.md reference contributing_guides/best_practices.md (#8760)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-02-25 20:27:12 +00:00
Evan Lohn
dcdcd067bd fix: drive 403 rate limits (#8762) 2026-02-25 20:12:36 +00:00
Evan Lohn
8b15a29723 feat: slab connector validation (#8758) 2026-02-25 20:00:42 +00:00
Danelegend
763853674f feat(ci): Add preview modal for data types (#8752) 2026-02-25 19:52:19 +00:00
Jamison Lahman
429b6f3465 fix(fe): modal aligning with detached element after navigation (#8676) 2026-02-25 19:33:07 +00:00
Danelegend
37d5be1b40 feat: python tool not added when no code interpretter server (#8749) 2026-02-25 19:17:42 +00:00
Jamison Lahman
8ab99dbb06 chore(fe): add hover style to AgentCard (#8689) 2026-02-25 19:08:00 +00:00
Jamison Lahman
52799e9c7a fix(fe): middle align human chat message text (#8756) 2026-02-25 19:00:01 +00:00
Jamison Lahman
aef009cc97 chore(fe): foldable buttons display text via tooltip when disabled (#8735) 2026-02-25 18:39:53 +00:00
Evan Lohn
18d1ea1770 fix: sharepoint driveItem perm sync (#8698) 2026-02-25 18:29:26 +00:00
Bo-Onyx
f336ad00f4 fix(user invitation): failed but no warning. (#8731)
Co-authored-by: Bo Yang <boyang@Bos-MacBook-Pro.local>
2026-02-25 17:23:39 +00:00
SubashMohan
0558e687d9 fix: persist onboarding dismissal in localStorage with user-specific keys (#8674) 2026-02-25 06:22:17 +00:00
roshan
784a99e24a updated demo data (#8748) 2026-02-24 19:59:46 -08:00
Justin Tahara
da1f5a11f4 chore(cherry-pick): Alerting on Failed Cherry-Picks (#8744) 2026-02-25 02:09:19 +00:00
Justin Tahara
5633805890 chore(devtools): Upgrade ods from 0.6.0 -> 0.6.1 (#8743) 2026-02-25 02:01:20 +00:00
Danelegend
0817b45ae1 feat: Get code interpreter config route (#8739) 2026-02-25 01:49:30 +00:00
Justin Tahara
af0e4bdebc fix(slack): Cleaning up URL Links (#8569) 2026-02-25 01:42:12 +00:00
Justin Tahara
4cd2320732 chore(cherry-pick): Add Github Label for PRs (#8736) 2026-02-25 00:46:12 +00:00
Danelegend
90a361f0e1 feat: code interpreter routes (#8670) 2026-02-24 16:27:10 -08:00
Justin Tahara
194efde97b chore(llm): Scaffolding for Nightly LLM Tests (#8704) 2026-02-25 00:06:24 +00:00
Danelegend
d922a42262 feat: code interpreter docker default deploy (#8672) 2026-02-24 23:51:19 +00:00
Danelegend
f00c3a486e feat: default deploy code interpreter - helm & bump version 0.3.0 (#8685) 2026-02-24 23:40:46 +00:00
Danelegend
192080c9e4 feat: default deploy code interpreter - restart_script (#8686) 2026-02-24 23:40:36 +00:00
Justin Tahara
c5787dc073 chore(image): Update test to be for Dall E 3 instead of 2 (#8732) 2026-02-24 22:53:31 +00:00
Justin Tahara
d424d6462c fix(sanitization): Centralizing DB Filters (#8730) 2026-02-24 22:28:25 +00:00
Jamison Lahman
ecea86deb6 chore(fe): only left input items flex (#8734) 2026-02-24 22:25:04 +00:00
Jamison Lahman
a5c1f50a8a chore(fe): update disabled "select" button color (#8733) 2026-02-24 22:03:52 +00:00
roshan
4a04cfd486 feat(craft): make output/ files downloadable from Artifacts tab (#8721)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2026-02-24 21:49:59 +00:00
Nikolas Garza
f22e9628db feat(scim): add additional entra id fields to ScimUserMapping (#8728) 2026-02-24 20:23:21 +00:00
Jamison Lahman
255ba10af6 chore(chat): consolidate chat message whitespacing style (#8696) 2026-02-24 20:02:28 +00:00
Justin Tahara
563202a080 feat(image): support Azure historical image context edits (#8726) 2026-02-24 19:21:30 +00:00
Evan Lohn
1062dc0743 fix: graph client env (#8727) 2026-02-24 18:46:49 +00:00
Justin Tahara
0826348568 feat(image): support OpenAI historical image context edits (#8725) 2026-02-24 18:45:56 +00:00
Justin Tahara
375079136d chore(cherry-pick): Assign merged-by user on beta cherry-pick PR (#8723) 2026-02-24 18:27:48 +00:00
Jamison Lahman
82aad5e253 fix(welcome): add back agent description (#8716) 2026-02-24 17:27:23 +00:00
Jamison Lahman
beb1c49c69 fix(fe): inline code-blocks respect header font-size (#8691) 2026-02-24 17:03:21 +00:00
Jamison Lahman
c4556515be fix(fe): rm non-admin-confirmation max-width (#8693) 2026-02-24 17:03:05 +00:00
SubashMohan
a4387f230b fix(popover): prevent viewport overflow with dynamic max-height and collision padding (#8675) 2026-02-24 10:27:36 +00:00
Evan Lohn
d91e452658 chore: version bumps for client libs (#8720) 2026-02-24 08:13:37 +00:00
Danelegend
dd274f8667 feat: code interpreter supports streaming (#8663) 2026-02-24 06:07:36 +00:00
roshan
2c82f0da16 fix(craft): delete S3 snapshot files when deleting a craft (#8718)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 05:58:29 +00:00
Raunak Bhagat
26101636f2 refactor: add new ContentAction component (#8695) 2026-02-24 05:13:18 +00:00
roshan
5e2c0c6cf4 fix(nrf): hide search toggle when search mode is unavailable (#8717)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 20:43:19 -08:00
roshan
33b64db498 fix(extensions): fix base url for chrome extension to (#8714) 2026-02-23 20:18:05 -08:00
roshan
b925cc1a56 feat(chrome-extension): add tab reading to side panel (#8571)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-24 01:17:57 +00:00
Danelegend
bac4b7c945 fix: preview markdown formatting (#8667) 2026-02-24 01:13:52 +00:00
Evan Lohn
6f6ef1e657 chore: coerce doc metadata (#8703) 2026-02-24 01:12:11 +00:00
Danelegend
885c69f460 feat: Improve csv preview modal (#8702) 2026-02-24 01:00:20 +00:00
Danelegend
4b837303ff feat(code-interpreter): Seed code interpreter server row (#8701) 2026-02-24 00:59:49 +00:00
Justin Tahara
d856a9befb fix(projects): Guardrails for Project User Files (#8644) 2026-02-24 00:21:57 +00:00
Justin Tahara
adade353c5 fix(api): Improving the API handling of threads (#8573) 2026-02-24 00:04:21 +00:00
Nikolas Garza
3cb6ec2f85 fix: patch prometheus metrics in daily test fixture (#8699) 2026-02-24 00:02:56 +00:00
Wenxi
691eebf00a fix: remove user info requirement for craft onboarding modal (#8697) 2026-02-23 23:52:17 +00:00
Danelegend
905b6633e6 chore: preview modal (#8665) 2026-02-23 23:40:55 +00:00
Justin Tahara
fd088196ff fix(search): Improve Speed (#8430) 2026-02-23 22:45:18 +00:00
Jamison Lahman
cafbf5b8be chore(playwright): warn user if setup takes longer than usual (#8690) 2026-02-23 22:23:58 +00:00
roshan
1235181559 fix(ui): Clean up NRF settings button styling (#8678)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 21:25:43 +00:00
Justin Tahara
caa2e45632 fix(db): Multitenant Schema migration update (#8679) 2026-02-23 21:25:26 +00:00
Justin Tahara
9c62e03120 chore(ods): Automated Cherry-pick backport (#8642) 2026-02-23 21:15:09 +00:00
Nikolas Garza
0937305064 feat(scim): Okta compatibility + provider abstraction (#8568) 2026-02-23 21:09:18 +00:00
Wenxi
e4c06570e3 fix: domain rules for signup on cloud (#8671) 2026-02-23 20:27:37 +00:00
roshan
78fc7c86d7 fix: Handle unauthenticated state gracefully on NRF page (#8491)
Co-authored-by: Claude <noreply@anthropic.com>
2026-02-23 19:26:38 +00:00
Raunak Bhagat
84d3aea847 refactor: migrate Web Search page to SettingsLayouts + Content (#8662) 2026-02-23 13:38:37 +00:00
Danelegend
00a404d3cd feat: Add code interpreter server db model (#8669) 2026-02-23 05:09:59 +00:00
Wenxi
787cf90d96 chore: set trial api usage to 0 and show ui (#8664) 2026-02-23 01:41:23 +00:00
SubashMohan
15fe47adc5 fix: improve connector status display, agent search tool detection, and prompt formatting (#8579) 2026-02-22 07:22:06 +00:00
SubashMohan
29958f1a52 perf(open-url): parallelize URL fetching with split connect/read timeouts (#8580) 2026-02-22 06:45:26 +00:00
SubashMohan
ac7f9838bc feat: add mixed content handler for chat and image generation packets (#8494) 2026-02-22 06:12:48 +00:00
Raunak Bhagat
d0fa4b3319 fix: limit connectors section to 3 items in Chat Preferences (#8661) 2026-02-22 01:01:42 +00:00
Raunak Bhagat
3fb4fb422e feat: refreshed admin Chat Preferences page (#8488)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-21 04:52:04 +00:00
acaprau
ba5da22ea1 chore(documentation): Add comment in contributing_guides/best_practices.md about accountability in TODOs (#8659) 2026-02-21 04:47:48 +00:00
Evan Lohn
9909049047 fix: improve eager loading behavior (#8576) 2026-02-21 04:02:34 +00:00
Nikolas Garza
c516aa3e3c fix(search): eliminate DB pool exhaustion from SearchTool parallel queries (#8651) 2026-02-21 03:56:00 +00:00
Evan Lohn
5cc6220417 feat: checkpointing mid drive (#8606) 2026-02-21 03:47:28 +00:00
acaprau
15da1e0a88 feat(opensearch): Add OpenSearch in docker compose (#8611) 2026-02-21 02:57:31 +00:00
Jamison Lahman
e9ff00890b chore(fe): fix Pinned icon hover background when card is hovered (#8656) 2026-02-21 02:37:18 +00:00
Wenxi
67747a9d93 fix: unify provider retrieval into swr hook (#8609) 2026-02-21 02:15:28 +00:00
Jamison Lahman
edfc51b439 fix(fe): Truncated tooltip is disabled when not needed to be shown (#8629) 2026-02-21 02:07:18 +00:00
Evan Lohn
ac4fba947e feat: sharepoint scalability5 (#8631) 2026-02-21 01:56:11 +00:00
Nikolas Garza
c142b2db02 chore: special sauce (#8646) 2026-02-21 01:30:36 +00:00
Jamison Lahman
fb7e7e4395 fix(fe): keep focus on input on empty (#8627) 2026-02-21 00:13:26 +00:00
Justin Tahara
113f23398e fix(celery): Guardrail for User File Processing (#8633) 2026-02-20 22:43:01 +00:00
Danelegend
5a8716026a feat: Python tool call replay packets (#8649) 2026-02-20 22:28:06 +00:00
Jamison Lahman
3389140bfd chore(deps): @radix-ui/react-tooltip: v1.1.3->v1.2.8 (#8647) 2026-02-20 22:27:49 +00:00
Justin Tahara
13109e7b81 chore(llm): Removal of Retired Models + Cleanup (#8645) 2026-02-20 21:47:53 +00:00
Jessica Singh
56ad457168 chore(auth): tests cleanup (#8559) 2026-02-20 20:51:29 +00:00
Nikolas Garza
a81aea2afc feat(scim): add scim_username column to ScimUserMapping (#8635) 2026-02-20 12:44:16 -08:00
Jamison Lahman
7cb5c9c4a6 chore(fe): replace BlinkingDot with a BlinkingBar (#8640) 2026-02-20 20:02:19 +00:00
Evan Lohn
3520c58a22 feat: sharepoint scalability4 (#8551) 2026-02-20 19:35:36 +00:00
Nikolas Garza
bd9d1bfa27 chore(helm): update code-interpreter chart repo URL to python-sandbox (#8625) 2026-02-20 11:39:08 -08:00
Evan Lohn
14416cc3db fix: search tool enabled when nothing selected (#8637) 2026-02-20 19:22:28 +00:00
Jamison Lahman
d7fce14d26 chore(devtools): upgrade ods: 0.5.7->0.6.0 (#8628) 2026-02-20 10:18:27 -08:00
Wenxi
39a8d8ed05 fix: boolean form field click on text, open url tool checkbox in default assistant, and simple tooltip rendering (#8480) 2026-02-20 17:38:29 +00:00
Jamison Lahman
82f735a434 chore(fe): rm settings page top padding (#8621) 2026-02-20 17:37:19 +00:00
Jamison Lahman
aadb58518b fix(fe): popover width can fit trigger element (#8624) 2026-02-20 07:29:21 +00:00
Jamison Lahman
0755499e0f chore(fe): whitelabel logo layout followup (#8626) 2026-02-19 23:27:37 -08:00
Danelegend
27aaf977a2 feat(code interpreter): improve produced code artifacts (#8507)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2026-02-20 07:10:09 +00:00
Evan Lohn
9f707f195e feat: sharepoint scalability3 (#8537) 2026-02-20 05:34:50 +00:00
Justin Tahara
3e35570f70 feat(vertex ai): Image Gen Historical Context (#8603) 2026-02-20 05:31:41 +00:00
Jamison Lahman
53b1bf3b2c chore(fe): fix nested button hydration error (#8599) 2026-02-20 05:14:24 +00:00
Jamison Lahman
5a3fa6b648 chore(fe): fix whitelabel logo moving on sidebar close (#8577) 2026-02-20 04:30:01 +00:00
Evan Lohn
fc6a37850b feat: delta sync sharepoint (#8532)
Co-authored-by: CE11-Kishan <CE11-Kishan@users.noreply.github.com>
2026-02-20 03:26:54 +00:00
Raunak Bhagat
aa6fec3d58 Fix/pat UI (#8617) 2026-02-19 19:23:26 -08:00
Raunak Bhagat
efa6005e36 feat(opal): Add widthVariant to Interactive.Container (#8610) 2026-02-20 03:14:30 +00:00
Nikolas Garza
921bfc72f4 fix(helm): route /openapi.json to api_server in nginx config (#8612) 2026-02-19 19:05:14 -08:00
Justin Tahara
812603152d feat(web): FE Changes for Brave Web Search 3/3 (#8597) 2026-02-20 02:43:30 +00:00
Raunak Bhagat
6779d8fbd7 feat(opal): Add Content layout component (#8534)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 02:35:27 +00:00
Justin Tahara
2c9826e4a9 feat(web): Logical Hardening for Brave Web Search 2/3 (#8595) 2026-02-20 02:09:49 +00:00
Justin Tahara
5b54687077 feat(web): Initial Framework for Brave Web Search 1/3 (#8594) 2026-02-20 01:38:19 +00:00
Raunak Bhagat
0f7e2ee674 feat(opal): Add Tag component and resync colors with Figma (#8533)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-20 01:14:52 +00:00
Danelegend
ea466648d9 feat: file preview from llm (#8604) 2026-02-20 00:47:03 +00:00
Evan Lohn
a402911ee6 feat: sharepoint scalability 1 (#8531) 2026-02-20 00:41:48 +00:00
Wenxi
7ae9ba807d fix(manage-users): exclude slack users from /users list (#8602) 2026-02-19 23:41:25 +00:00
Nikolas Garza
1f79223c42 feat(ods): add --continue flag and cp alias to cherry-pick (#8601) 2026-02-19 22:31:06 +00:00
Danelegend
c0c2247d5a feat: Modal Header no icon (#8596) 2026-02-19 21:59:18 +00:00
Danelegend
2989ceda41 feat: add file formatting reminder (#8524) 2026-02-19 21:51:06 +00:00
519 changed files with 26519 additions and 9535 deletions

View File

@@ -0,0 +1,73 @@
name: "Build Backend Image"
description: "Builds and pushes the backend Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
docker-no-cache:
description: "Set to 'true' to disable docker build cache"
required: false
default: "false"
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Backend Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
type=registry,ref=onyxdotapp/onyx-backend:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
no-cache: ${{ inputs.docker-no-cache == 'true' }}

View File

@@ -0,0 +1,75 @@
name: "Build Integration Image"
description: "Builds and pushes the integration test image with docker bake"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Build and push integration test image with Docker Bake
shell: bash
env:
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
TAG: nightly-llm-it-${{ inputs.run-id }}
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
HEAD_SHA: ${{ inputs.github-sha }}
run: |
docker buildx bake --push \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
--set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
--set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
--set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
--set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
--set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
integration

View File

@@ -0,0 +1,68 @@
name: "Build Model Server Image"
description: "Builds and pushes the model server Docker image with cache reuse"
inputs:
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
ref-name:
description: "Git ref name used for cache suffix fallback"
required: true
pr-number:
description: "Optional PR number for cache suffix"
required: false
default: ""
github-sha:
description: "Commit SHA used for cache keys"
required: true
run-id:
description: "GitHub run ID used in output image tag"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Format branch name for cache
id: format-branch
shell: bash
env:
PR_NUMBER: ${{ inputs.pr-number }}
REF_NAME: ${{ inputs.ref-name }}
run: |
if [ -n "${PR_NUMBER}" ]; then
CACHE_SUFFIX="${PR_NUMBER}"
else
# shellcheck disable=SC2001
CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
fi
echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Build and push Model Server Docker image
uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
push: true
tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
cache-from: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
type=registry,ref=onyxdotapp/onyx-model-server:latest
cache-to: |
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max

View File

@@ -0,0 +1,130 @@
name: "Run Nightly Provider Chat Test"
description: "Starts required compose services and runs nightly provider integration test"
inputs:
provider:
description: "Provider slug for NIGHTLY_LLM_PROVIDER"
required: true
models:
description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
required: true
provider-api-key:
description: "API key for NIGHTLY_LLM_API_KEY"
required: false
default: ""
strict:
description: "String true/false for NIGHTLY_LLM_STRICT"
required: true
api-base:
description: "Optional NIGHTLY_LLM_API_BASE"
required: false
default: ""
api-version:
description: "Optional NIGHTLY_LLM_API_VERSION"
required: false
default: ""
deployment-name:
description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
required: false
default: ""
custom-config-json:
description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
required: false
default: ""
runs-on-ecr-cache:
description: "ECR cache registry from runs-on/action"
required: true
run-id:
description: "GitHub run ID used in image tags"
required: true
docker-username:
description: "Docker Hub username"
required: true
docker-token:
description: "Docker Hub token"
required: true
runs:
using: "composite"
steps:
- name: Login to Docker Hub
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
with:
username: ${{ inputs.docker-username }}
password: ${{ inputs.docker-token }}
- name: Create .env file for Docker Compose
shell: bash
env:
ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
run: |
cat <<EOF2 > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
REQUIRE_EMAIL_VERIFICATION=false
DISABLE_TELEMETRY=true
INTEGRATION_TESTS_MODE=true
AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
AWS_REGION_NAME=us-west-2
ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
EOF2
- name: Start Docker containers
shell: bash
run: |
cd deployment/docker_compose
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
relational_db \
index \
cache \
minio \
api_server \
inference_model_server
- name: Run nightly provider integration test
uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
env:
MODELS: ${{ inputs.models }}
NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
RUN_ID: ${{ inputs.run-id }}
with:
timeout_minutes: 20
max_attempts: 2
retry_wait_seconds: 10
command: |
docker run --rm --network onyx_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e DB_READONLY_USER=db_readonly_user \
-e DB_READONLY_PASSWORD=password \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e TEST_WEB_HOSTNAME=test-runner \
-e AWS_REGION_NAME=us-west-2 \
-e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
-e NIGHTLY_LLM_MODELS="${MODELS}" \
-e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
-e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
-e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
-e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
-e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
-e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
/app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py

View File

@@ -8,5 +8,5 @@
## Additional Options
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Optional] Override Linear Check

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Build chart dependencies

View File

@@ -0,0 +1,56 @@
name: Nightly LLM Provider Chat Tests
concurrency:
group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
on:
schedule:
# Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
- cron: "30 10 * * *"
workflow_dispatch:
permissions:
contents: read
jobs:
provider-chat-test:
uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
with:
openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
strict: true
secrets:
openai_api_key: ${{ secrets.OPENAI_API_KEY }}
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
bedrock_api_key: ${{ secrets.BEDROCK_API_KEY }}
vertex_ai_custom_config_json: ${{ secrets.NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON }}
azure_api_key: ${{ secrets.AZURE_API_KEY }}
ollama_api_key: ${{ secrets.OLLAMA_API_KEY }}
openrouter_api_key: ${{ secrets.OPENROUTER_API_KEY }}
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
notify-slack-on-failure:
needs: [provider-chat-test]
if: failure() && github.event_name == 'schedule'
runs-on: ubuntu-slim
timeout-minutes: 5
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Send Slack notification
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.SLACK_WEBHOOK }}
failed-jobs: provider-chat-test
title: "🚨 Scheduled LLM Provider Chat Tests failed!"
ref-name: ${{ github.ref_name }}

View File

@@ -0,0 +1,161 @@
name: Post-Merge Beta Cherry-Pick
on:
push:
branches:
- main
permissions:
contents: write
pull-requests: write
jobs:
cherry-pick-to-latest-release:
outputs:
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
pr_number: ${{ steps.gate.outputs.pr_number }}
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Resolve merged PR and checkbox state
id: gate
env:
GH_TOKEN: ${{ github.token }}
run: |
# For the commit that triggered this workflow (HEAD on main), fetch all
# associated PRs and keep only the PR that was actually merged into main
# with this exact merge commit SHA.
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
if [ "${match_count}" -gt 1 ]; then
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
fi
if [ -z "$pr_number" ]; then
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
exit 0
fi
# Read the PR once so we can gate behavior and infer preferred actor.
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox checked for PR #${pr_number}."
exit 0
fi
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
- name: Checkout repository
if: steps.gate.outputs.should_cherrypick == 'true'
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
fetch-depth: 0
persist-credentials: true
ref: main
- name: Install the latest version of uv
if: steps.gate.outputs.should_cherrypick == 'true'
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
- name: Configure git identity
if: steps.gate.outputs.should_cherrypick == 'true'
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
- name: Create cherry-pick PR to latest release
id: run_cherry_pick
if: steps.gate.outputs.should_cherrypick == 'true'
continue-on-error: true
env:
GH_TOKEN: ${{ github.token }}
GITHUB_TOKEN: ${{ github.token }}
CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
run: |
set -o pipefail
output_file="$(mktemp)"
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
exit_code="${PIPESTATUS[0]}"
if [ "${exit_code}" -eq 0 ]; then
echo "status=success" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "status=failure" >> "$GITHUB_OUTPUT"
reason="command-failed"
if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
reason="merge-conflict"
fi
echo "reason=${reason}" >> "$GITHUB_OUTPUT"
{
echo "details<<EOF"
tail -n 40 "$output_file"
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
exit 1
notify-slack-on-cherry-pick-failure:
needs:
- cherry-pick-to-latest-release
if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
runs-on: ubuntu-slim
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build cherry-pick failure summary
id: failure-summary
env:
SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
run: |
source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"
reason_text="cherry-pick command failed"
if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
reason_text="merge conflict during cherry-pick"
fi
details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
if [ -n "${details_excerpt}" ]; then
failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
fi
echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"
- name: Notify #cherry-pick-prs about cherry-pick failure
uses: ./.github/actions/slack-notify
with:
webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
title: "🚨 Automated Cherry-Pick Failed"
ref-name: ${{ github.ref_name }}

View File

@@ -1,28 +0,0 @@
name: Require beta cherry-pick consideration
concurrency:
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
pull_request:
types: [opened, edited, reopened, synchronize]
permissions:
contents: read
jobs:
beta-cherrypick-check:
runs-on: ubuntu-latest
timeout-minutes: 45
steps:
- name: Check PR body for beta cherry-pick consideration
env:
PR_BODY: ${{ github.event.pull_request.body }}
run: |
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
echo "Cherry-pick consideration box is checked. Check passed."
exit 0
fi
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
exit 1

View File

@@ -45,9 +45,6 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -118,9 +115,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -129,7 +126,6 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -91,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo update
- name: Install Redis operator

View File

@@ -20,6 +20,7 @@ env:
# Test Environment Variables
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -423,6 +424,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -443,6 +445,7 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
-e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
/app/tests/integration/${{ matrix.test-dir.path }}
@@ -701,6 +704,7 @@ jobs:
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e EXA_API_KEY=${EXA_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \

View File

@@ -89,6 +89,10 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}

View File

@@ -0,0 +1,282 @@
name: Reusable Nightly LLM Provider Chat Tests
on:
workflow_call:
inputs:
openai_models:
description: "Comma-separated models for openai"
required: false
default: ""
type: string
anthropic_models:
description: "Comma-separated models for anthropic"
required: false
default: ""
type: string
bedrock_models:
description: "Comma-separated models for bedrock"
required: false
default: ""
type: string
vertex_ai_models:
description: "Comma-separated models for vertex_ai"
required: false
default: ""
type: string
azure_models:
description: "Comma-separated models for azure"
required: false
default: ""
type: string
ollama_models:
description: "Comma-separated models for ollama_chat"
required: false
default: ""
type: string
openrouter_models:
description: "Comma-separated models for openrouter"
required: false
default: ""
type: string
azure_api_base:
description: "API base for azure provider"
required: false
default: ""
type: string
strict:
description: "Default NIGHTLY_LLM_STRICT passed to tests"
required: false
default: true
type: boolean
secrets:
openai_api_key:
required: false
anthropic_api_key:
required: false
bedrock_api_key:
required: false
vertex_ai_custom_config_json:
required: false
azure_api_key:
required: false
ollama_api_key:
required: false
openrouter_api_key:
required: false
DOCKER_USERNAME:
required: true
DOCKER_TOKEN:
required: true
permissions:
contents: read
jobs:
build-backend-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-backend-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build backend image
uses: ./.github/actions/build-backend-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
build-model-server-image:
runs-on:
[
runs-on,
runner=1cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-model-server-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build model server image
uses: ./.github/actions/build-model-server-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
build-integration-image:
runs-on:
[
runs-on,
runner=2cpu-linux-arm64,
"run-id=${{ github.run_id }}-build-integration-image",
"extras=ecr-cache",
]
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Build integration image
uses: ./.github/actions/build-integration-image
with:
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
ref-name: ${{ github.ref_name }}
pr-number: ${{ github.event.pull_request.number }}
github-sha: ${{ github.sha }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
provider-chat-test:
needs:
[
build-backend-image,
build-model-server-image,
build-integration-image,
]
strategy:
fail-fast: false
matrix:
include:
- provider: openai
models: ${{ inputs.openai_models }}
api_key_secret: openai_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: anthropic
models: ${{ inputs.anthropic_models }}
api_key_secret: anthropic_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: true
- provider: bedrock
models: ${{ inputs.bedrock_models }}
api_key_secret: bedrock_api_key
custom_config_secret: ""
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: vertex_ai
models: ${{ inputs.vertex_ai_models }}
api_key_secret: ""
custom_config_secret: vertex_ai_custom_config_json
api_base: ""
api_version: ""
deployment_name: ""
required: false
- provider: azure
models: ${{ inputs.azure_models }}
api_key_secret: azure_api_key
custom_config_secret: ""
api_base: ${{ inputs.azure_api_base }}
api_version: "2025-04-01-preview"
deployment_name: ""
required: false
- provider: ollama_chat
models: ${{ inputs.ollama_models }}
api_key_secret: ollama_api_key
custom_config_secret: ""
api_base: "https://ollama.com"
api_version: ""
deployment_name: ""
required: false
- provider: openrouter
models: ${{ inputs.openrouter_models }}
api_key_secret: openrouter_api_key
custom_config_secret: ""
api_base: "https://openrouter.ai/api/v1"
api_version: ""
deployment_name: ""
required: false
runs-on:
- runs-on
- runner=4cpu-linux-arm64
- "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
- extras=ecr-cache
timeout-minutes: 45
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Run nightly provider chat test
uses: ./.github/actions/run-nightly-provider-chat-test
with:
provider: ${{ matrix.provider }}
models: ${{ matrix.models }}
provider-api-key: ${{ matrix.api_key_secret && secrets[matrix.api_key_secret] || '' }}
strict: ${{ inputs.strict && 'true' || 'false' }}
api-base: ${{ matrix.api_base }}
api-version: ${{ matrix.api_version }}
deployment-name: ${{ matrix.deployment_name }}
custom-config-json: ${{ matrix.custom_config_secret && secrets[matrix.custom_config_secret] || '' }}
runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
run-id: ${{ github.run_id }}
docker-username: ${{ secrets.DOCKER_USERNAME }}
docker-token: ${{ secrets.DOCKER_TOKEN }}
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs
if: always()
run: |
cd deployment/docker_compose
docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
with:
name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
path: |
${{ github.workspace }}/api_server.log
${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose down -v

View File

@@ -548,7 +548,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).
A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.
To run them:
@@ -616,3 +616,9 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.
Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.
## Best Practices
In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.

View File

@@ -21,15 +21,14 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, NamedTuple
from typing import NamedTuple
from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
return script.get_current_head()
def get_schemas_needing_migration(
tenant_schemas: List[str], head_rev: str
) -> List[str]:
"""Return only schemas whose current alembic version is not at head."""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Find which schemas actually have an alembic_version table
rows = conn.execute(
text(
"SELECT table_schema FROM information_schema.tables "
"WHERE table_name = 'alembic_version' "
"AND table_schema = ANY(:schemas)"
),
{"schemas": tenant_schemas},
)
schemas_with_table = set(row[0] for row in rows)
# Schemas without the table definitely need migration
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
if not schemas_with_table:
return needs_migration
# Validate schema names before interpolating into SQL
for schema in schemas_with_table:
if not is_valid_schema_name(schema):
raise ValueError(f"Invalid schema name: {schema}")
# Single query to get every schema's current revision at once.
# Use integer tags instead of interpolating schema names into
# string literals to avoid quoting issues.
schema_list = list(schemas_with_table)
union_parts = [
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
for i, schema in enumerate(schema_list)
]
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
needs_migration.extend(
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
)
return needs_migration
def run_migrations_parallel(
schemas: list[str],
max_workers: int,

View File

@@ -0,0 +1,29 @@
"""code interpreter seed
Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Seed the single instance of code_interpreter_server
# NOTE: There should only exist at most and at minimum 1 code_interpreter_server row
op.execute(
sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
)
def downgrade() -> None:
op.execute(sa.text("DELETE FROM code_interpreter_server"))

View File

@@ -0,0 +1,28 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -0,0 +1,69 @@
"""add python tool on default
Revision ID: 57122d037335
Revises: c0c937d5c9e5
Create Date: 2026-02-27 10:10:40.124925
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "57122d037335"
down_revision = "c0c937d5c9e5"
branch_labels = None
depends_on = None
PYTHON_TOOL_NAME = "python"
def upgrade() -> None:
conn = op.get_bind()
# Look up the PythonTool id
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
tool_id = result[0]
# Attach to the default persona (id=0) if not already attached
conn.execute(
sa.text(
"""
INSERT INTO persona__tool (persona_id, tool_id)
VALUES (0, :tool_id)
ON CONFLICT DO NOTHING
"""
),
{"tool_id": tool_id},
)
def downgrade() -> None:
conn = op.get_bind()
result = conn.execute(
sa.text("SELECT id FROM tool WHERE name = :name"),
{"name": PYTHON_TOOL_NAME},
).fetchone()
if not result:
return
conn.execute(
sa.text(
"""
DELETE FROM persona__tool
WHERE persona_id = 0 AND tool_id = :tool_id
"""
),
{"tool_id": result[0]},
)

View File

@@ -0,0 +1,48 @@
"""add enterprise and name fields to scim_user_mapping
Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("department", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("manager", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("given_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("family_name", sa.String(), nullable=True),
)
op.add_column(
"scim_user_mapping",
sa.Column("scim_emails_json", sa.Text(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_emails_json")
op.drop_column("scim_user_mapping", "family_name")
op.drop_column("scim_user_mapping", "given_name")
op.drop_column("scim_user_mapping", "manager")
op.drop_column("scim_user_mapping", "department")

View File

@@ -0,0 +1,31 @@
"""code interpreter server model
Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"code_interpreter_server",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
),
)
def downgrade() -> None:
op.drop_table("code_interpreter_server")

View File

@@ -0,0 +1,33 @@
"""add needs_persona_sync to user_file
Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user_file",
sa.Column(
"needs_persona_sync",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
)
def downgrade() -> None:
op.drop_column("user_file", "needs_persona_sync")

View File

@@ -0,0 +1,70 @@
"""llm provider deprecate fields
Revision ID: c0c937d5c9e5
Revises: 8ffcc2bcfc11
Create Date: 2026-02-25 17:35:46.125102
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c0c937d5c9e5"
down_revision = "8ffcc2bcfc11"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Make default_model_name nullable (was NOT NULL)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=True,
)
# Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
op.drop_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
type_="unique",
)
# Remove server_default from is_default_vision_provider (was server_default=false())
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=None,
)
def downgrade() -> None:
# Restore default_model_name to NOT NULL (set empty string for any NULLs first)
op.execute(
"UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
)
op.alter_column(
"llm_provider",
"default_model_name",
existing_type=sa.String(),
nullable=False,
)
# Restore unique constraint on is_default_provider
op.create_unique_constraint(
"llm_provider_is_default_provider_key",
"llm_provider",
["is_default_provider"],
)
# Restore server_default for is_default_vision_provider
op.alter_column(
"llm_provider",
"is_default_vision_provider",
existing_type=sa.Boolean(),
server_default=sa.false(),
)

View File

@@ -34,6 +34,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert
from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
@@ -127,9 +128,21 @@ class ScimDAL(DAL):
self,
external_id: str,
user_id: UUID,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserMapping:
"""Create a mapping between a SCIM externalId and an Onyx user."""
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
f = fields or ScimMappingFields()
mapping = ScimUserMapping(
external_id=external_id,
user_id=user_id,
scim_username=scim_username,
department=f.department,
manager=f.manager,
given_name=f.given_name,
family_name=f.family_name,
scim_emails_json=f.scim_emails_json,
)
self._session.add(mapping)
self._session.flush()
return mapping
@@ -248,11 +261,11 @@ class ScimDAL(DAL):
scim_filter: ScimFilter | None,
start_index: int = 1,
count: int = 100,
) -> tuple[list[tuple[User, str | None]], int]:
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
"""Query users with optional SCIM filter and pagination.
Returns:
A tuple of (list of (user, external_id) pairs, total_count).
A tuple of (list of (user, mapping) pairs, total_count).
Raises:
ValueError: If the filter uses an unsupported attribute.
@@ -292,33 +305,117 @@ class ScimDAL(DAL):
users = list(
self._session.scalars(
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
).all()
)
.unique()
.all()
)
# Batch-fetch external IDs to avoid N+1 queries
ext_id_map = self._get_user_external_ids([u.id for u in users])
return [(u, ext_id_map.get(u.id)) for u in users], total
# Batch-fetch SCIM mappings to avoid N+1 queries
mapping_map = self._get_user_mappings_batch([u.id for u in users])
return [(u, mapping_map.get(u.id)) for u in users], total
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
"""Create, update, or delete the external ID mapping for a user."""
def sync_user_external_id(
self,
user_id: UUID,
new_external_id: str | None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> None:
"""Create, update, or delete the external ID mapping for a user.
When *fields* is provided, all mapping fields are written
unconditionally — including ``None`` values — so that a caller can
clear a previously-set field (e.g. removing a department).
"""
mapping = self.get_user_mapping_by_user_id(user_id)
if new_external_id:
if mapping:
if mapping.external_id != new_external_id:
mapping.external_id = new_external_id
if scim_username is not None:
mapping.scim_username = scim_username
if fields is not None:
mapping.department = fields.department
mapping.manager = fields.manager
mapping.given_name = fields.given_name
mapping.family_name = fields.family_name
mapping.scim_emails_json = fields.scim_emails_json
else:
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
self.create_user_mapping(
external_id=new_external_id,
user_id=user_id,
scim_username=scim_username,
fields=fields,
)
elif mapping:
self.delete_user_mapping(mapping.id)
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
"""Batch-fetch external IDs for a list of user IDs."""
def _get_user_mappings_batch(
self, user_ids: list[UUID]
) -> dict[UUID, ScimUserMapping]:
"""Batch-fetch SCIM user mappings keyed by user ID."""
if not user_ids:
return {}
mappings = self._session.scalars(
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
).all()
return {m.user_id: m.external_id for m in mappings}
return {m.user_id: m for m in mappings}
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
Excludes groups marked for deletion.
"""
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
).all()
group_ids = [r.user_group_id for r in rels]
if not group_ids:
return []
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
return [(g.id, g.name) for g in groups]
def get_users_groups_batch(
self, user_ids: list[UUID]
) -> dict[UUID, list[tuple[int, str]]]:
"""Batch-fetch group memberships for multiple users.
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
Avoids N+1 queries when building user list responses.
"""
if not user_ids:
return {}
rels = self._session.scalars(
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
).all()
group_ids = list({r.user_group_id for r in rels})
if not group_ids:
return {}
groups = self._session.scalars(
select(UserGroup).where(
UserGroup.id.in_(group_ids),
UserGroup.is_up_for_deletion.is_(False),
)
).all()
groups_by_id = {g.id: g.name for g in groups}
result: dict[UUID, list[tuple[int, str]]] = {}
for r in rels:
if r.user_id and r.user_group_id in groups_by_id:
result.setdefault(r.user_id, []).append(
(r.user_group_id, groups_by_id[r.user_group_id])
)
return result
# ------------------------------------------------------------------
# Group mapping operations
@@ -483,9 +580,13 @@ class ScimDAL(DAL):
if not user_ids:
return []
users = self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
).all()
users = (
self._session.scalars(
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
users_by_id = {u.id: u for u in users}
return [
@@ -504,9 +605,13 @@ class ScimDAL(DAL):
"""
if not uuids:
return []
existing_users = self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
).all()
existing_users = (
self._session.scalars(
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
)
.unique()
.all()
)
existing_ids = {u.id for u in existing_users}
return [uid for uid in uuids if uid not in existing_ids]

View File

@@ -9,6 +9,7 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -18,11 +19,15 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -195,8 +200,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session, only_up_to_date: bool = True
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -209,6 +266,8 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -216,11 +275,16 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
return db_session.scalars(stmt).all()
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID, only_curator_groups: bool = False
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -230,7 +294,9 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
return db_session.scalars(stmt).all()
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
def construct_document_id_select_by_usergroup(

View File

@@ -1,9 +1,13 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -43,14 +47,27 @@ def sharepoint_group_sync(
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
# Process each site
enumerate_all = connector_config.get(
"exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION
)
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
sp_domain_suffix = connector.sharepoint_domain_suffix
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
ctx = connector._create_rest_client_context(site_descriptor.url)
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
external_groups = get_sharepoint_external_groups(
ctx,
connector.graph_client,
graph_api_base=connector.graph_api_base,
get_access_token=connector._get_graph_access_token,
enumerate_all_ad_groups=enumerate_all,
)
# Yield each group
for group in external_groups:

View File

@@ -1,9 +1,12 @@
import re
import time
from collections import deque
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
import requests as _requests
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
@@ -14,7 +17,10 @@ from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES
from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES
from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
@@ -33,6 +39,70 @@ LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
AD_GROUP_ENUMERATION_THRESHOLD = 100_000
def _graph_api_get(
url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> dict[str, Any]:
"""Authenticated Graph API GET with retry on transient errors."""
for attempt in range(GRAPH_API_MAX_RETRIES + 1):
access_token = get_access_token()
headers = {"Authorization": f"Bearer {access_token}"}
try:
resp = _requests.get(
url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS
)
if (
resp.status_code in GRAPH_API_RETRYABLE_STATUSES
and attempt < GRAPH_API_MAX_RETRIES
):
wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60)
logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
f"Graph API connection error on attempt {attempt + 1}, "
f"retrying in {wait}s: {url}"
)
time.sleep(wait)
continue
raise
raise RuntimeError(
f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
)
def _iter_graph_collection(
initial_url: str,
get_access_token: Callable[[], str],
params: dict[str, str] | None = None,
) -> Generator[dict[str, Any], None, None]:
"""Paginate through a Graph API collection, yielding items one at a time."""
url: str | None = initial_url
while url:
data = _graph_api_get(url, get_access_token, params)
params = None
yield from data.get("value", [])
url = data.get("@odata.nextLink")
def _normalize_email(email: str) -> str:
if MICROSOFT_DOMAIN in email:
return email.replace(MICROSOFT_DOMAIN, "")
return email
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
@@ -527,8 +597,12 @@ def get_external_access_from_sharepoint(
)
elif site_page:
site_url = site_page.get("webUrl")
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
# Keep percent-encoding intact so the path matches the encoding
# used by the Office365 library's SPResPath.create_relative(),
# which compares against urlparse(context.base_url).path.
# Decoding (e.g. %27 → ') causes a mismatch that duplicates
# the site prefix in the constructed URL.
server_relative_url = urlparse(site_url).path
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
@@ -572,8 +646,65 @@ def get_external_access_from_sharepoint(
)
def _enumerate_ad_groups_paginated(
get_access_token: Callable[[], str],
already_resolved: set[str],
graph_api_base: str,
) -> Generator[ExternalUserGroup, None, None]:
"""Paginate through all Azure AD groups and yield ExternalUserGroup for each.
Skips groups whose suffixed name is already in *already_resolved*.
Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
"""
groups_url = f"{graph_api_base}/groups"
groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"}
total_groups = 0
for group_json in _iter_graph_collection(
groups_url, get_access_token, groups_params
):
group_id: str = group_json.get("id", "")
display_name: str = group_json.get("displayName", "")
if not group_id or not display_name:
continue
total_groups += 1
if total_groups > AD_GROUP_ENUMERATION_THRESHOLD:
logger.warning(
f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} "
"groups — stopping to avoid excessive memory/API usage. "
"Remaining groups will be resolved from role assignments only."
)
return
name = f"{display_name}_{group_id}"
if name in already_resolved:
continue
member_emails: list[str] = []
members_url = f"{graph_api_base}/groups/{group_id}/members"
members_params: dict[str, str] = {
"$select": "userPrincipalName,mail",
"$top": "999",
}
for member_json in _iter_graph_collection(
members_url, get_access_token, members_params
):
email = member_json.get("userPrincipalName") or member_json.get("mail")
if email:
member_emails.append(_normalize_email(email))
yield ExternalUserGroup(id=name, user_emails=member_emails)
logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API")
def get_sharepoint_external_groups(
client_context: ClientContext, graph_client: GraphClient
client_context: ClientContext,
graph_client: GraphClient,
graph_api_base: str,
get_access_token: Callable[[], str] | None = None,
enumerate_all_ad_groups: bool = False,
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
@@ -629,57 +760,22 @@ def get_sharepoint_external_groups(
client_context, graph_client, groups, is_group_sync=True
)
# get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
# We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# AD groups allows same display name for multiple groups, so we need to add the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
external_user_groups: list[ExternalUserGroup] = [
ExternalUserGroup(id=group_name, user_emails=list(emails))
for group_name, emails in groups_and_members.groups_to_emails.items()
]
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
if not enumerate_all_ad_groups or get_access_token is None:
logger.info(
"Skipping exhaustive Azure AD group enumeration. "
"Only groups found in site role assignments are included."
)
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
return external_user_groups
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
already_resolved = set(groups_and_members.groups_to_emails.keys())
for group in _enumerate_ad_groups_paginated(
get_access_token, already_resolved, graph_api_base
):
external_user_groups.append(group)
return external_user_groups

View File

@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
filters: BaseFilters | None = None
num_docs_fed_to_llm_selection: int | None = None
run_query_expansion: bool = False
num_hits: int = 50
num_hits: int = 30
include_content: bool = False
stream: bool = False

View File

@@ -24,7 +24,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.llm.factory import get_default_llm
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.server.usage_limits import check_llm_cost_limit_for_provider
from onyx.server.utils import get_json_line
from onyx.server.utils_vector_db import require_vector_db
@@ -35,10 +34,6 @@ logger = setup_logger()
router = APIRouter(prefix="/search")
# Temporary hardcode for latency testing of search-flow classification.
CLASSIFICATION_LLM_PROVIDER_NAME = "OpenAI"
CLASSIFICATION_LLM_MODEL_NAME = "gpt-5-nano"
@router.post("/search-flow-classification")
def search_flow_classification(
@@ -52,18 +47,7 @@ def search_flow_classification(
if len(query) > 200:
return SearchFlowClassificationResponse(is_search_flow=False)
try:
llm = get_llm_for_contextual_rag(
model_name=CLASSIFICATION_LLM_MODEL_NAME,
model_provider=CLASSIFICATION_LLM_PROVIDER_NAME,
)
except Exception:
logger.exception(
"Failed to initialize hardcoded classification LLM (%s/%s); falling back to default LLM",
CLASSIFICATION_LLM_PROVIDER_NAME,
CLASSIFICATION_LLM_MODEL_NAME,
)
llm = get_default_llm()
llm = get_default_llm()
check_llm_cost_limit_for_provider(
db_session=db_session,

View File

@@ -26,21 +26,23 @@ from sqlalchemy.orm import Session
from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.base import serialize_emails
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -48,21 +50,45 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
media_type = "application/scim+json"
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
_pw_helper = PasswordHelper()
def _get_provider(
_token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
"""Resolve the SCIM provider for the current request.
Currently returns OktaProvider for all requests. When multi-provider
support is added (ENG-3652), this will resolve based on token metadata
or tenant configuration — no endpoint changes required.
"""
return get_default_provider()
# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------
@@ -75,15 +101,39 @@ def get_service_provider_config() -> ScimServiceProviderConfig:
@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
"""List available SCIM resource types (RFC 7643 §6)."""
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
def get_resource_types() -> ScimJSONResponse:
"""List available SCIM resource types (RFC 7643 §6).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(resources),
"Resources": [
r.model_dump(exclude_none=True, by_alias=True) for r in resources
],
}
)
@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
"""Return SCIM schema definitions (RFC 7643 §7)."""
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
def get_schemas() -> ScimJSONResponse:
"""Return SCIM schema definitions (RFC 7643 §7).
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
like Entra ID expect a JSON object, not a bare array.
"""
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
return ScimJSONResponse(
content={
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
"totalResults": len(schemas),
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
}
)
# ---------------------------------------------------------------------------
@@ -91,35 +141,43 @@ def get_schemas() -> list[ScimSchemaDefinition]:
# ---------------------------------------------------------------------------
def _scim_error_response(status: int, detail: str) -> JSONResponse:
def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
"""Build a SCIM-compliant error response (RFC 7644 §3.12)."""
logger.warning("SCIM error response: status=%s detail=%s", status, detail)
body = ScimError(status=str(status), detail=detail)
return JSONResponse(
return ScimJSONResponse(
status_code=status,
content=body.model_dump(exclude_none=True),
)
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
"""Convert an Onyx User to a SCIM User resource representation."""
name = None
if user.personal_name:
parts = user.personal_name.split(" ", 1)
name = ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def _parse_excluded_attributes(raw: str | None) -> set[str]:
"""Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
return ScimUserResource(
id=str(user.id),
externalId=external_id,
userName=user.email,
name=name,
emails=[ScimEmail(value=user.email, type="work", primary=True)],
active=user.is_active,
meta=ScimMeta(resourceType="User"),
)
Returns a set of lowercased attribute names to omit from responses.
"""
if not raw:
return set()
return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
def _apply_exclusions(
resource: ScimUserResource | ScimGroupResource,
excluded: set[str],
) -> dict:
"""Serialize a SCIM resource, omitting attributes the IdP excluded.
RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
to reduce response payload size. We strip those fields after serialization
so the rest of the pipeline doesn't need to know about them.
"""
data = resource.model_dump(exclude_none=True, by_alias=True)
for attr in excluded:
# Match case-insensitively against the camelCase field names
keys_to_remove = [k for k in data if k.lower() == attr]
for k in keys_to_remove:
del data[k]
return data
def _check_seat_availability(dal: ScimDAL) -> str | None:
@@ -135,7 +193,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
return None
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
"""Parse *user_id* as UUID, look up the user, or return a 404 error."""
try:
uid = UUID(user_id)
@@ -155,8 +213,94 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
"""
if not name:
return None
return name.formatted or " ".join(
part for part in [name.givenName, name.familyName] if part
# If the client explicitly provides ``formatted``, prefer it — the client
# knows what display string it wants. Otherwise build from components.
if name.formatted:
return name.formatted
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
return parts or None
def _scim_resource_response(
resource: ScimUserResource | ScimGroupResource | ScimListResponse,
status_code: int = 200,
) -> ScimJSONResponse:
"""Serialize a SCIM resource as ``application/scim+json``."""
content = resource.model_dump(exclude_none=True, by_alias=True)
return ScimJSONResponse(
status_code=status_code,
content=content,
)
def _build_list_response(
resources: list[ScimUserResource | ScimGroupResource],
total: int,
start_index: int,
count: int,
excluded: set[str] | None = None,
) -> ScimListResponse | ScimJSONResponse:
"""Build a SCIM list response, optionally applying attribute exclusions.
RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
the ``excludedAttributes`` query parameter.
"""
if excluded:
envelope = ScimListResponse(
totalResults=total,
startIndex=start_index,
itemsPerPage=count,
)
data = envelope.model_dump(exclude_none=True)
data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
return ScimJSONResponse(content=data)
return _scim_resource_response(
ScimListResponse(
totalResults=total,
startIndex=start_index,
itemsPerPage=count,
Resources=resources,
)
)
def _extract_enterprise_fields(
resource: ScimUserResource,
) -> tuple[str | None, str | None]:
"""Extract department and manager from enterprise extension."""
ext = resource.enterprise_extension
if not ext:
return None, None
department = ext.department
manager = ext.manager.value if ext.manager else None
return department, manager
def _mapping_to_fields(
mapping: ScimUserMapping | None,
) -> ScimMappingFields | None:
"""Extract round-trip fields from a SCIM user mapping."""
if not mapping:
return None
return ScimMappingFields(
department=mapping.department,
manager=mapping.manager,
given_name=mapping.given_name,
family_name=mapping.family_name,
scim_emails_json=mapping.scim_emails_json,
)
def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
"""Build mapping fields from an incoming SCIM user resource."""
department, manager = _extract_enterprise_fields(resource)
return ScimMappingFields(
department=department,
manager=manager,
given_name=resource.name.givenName if resource.name else None,
family_name=resource.name.familyName if resource.name else None,
scim_emails_json=serialize_emails(resource.emails),
)
@@ -168,14 +312,17 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
@scim_router.get("/Users", response_model=None)
def list_users(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
) -> ScimListResponse | ScimJSONResponse:
"""List users with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
@@ -183,52 +330,79 @@ def list_users(
return _scim_error_response(400, str(e))
try:
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
except ValueError as e:
return _scim_error_response(400, str(e))
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
resources: list[ScimUserResource | ScimGroupResource] = [
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=user_groups_map.get(user.id, []),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
for user, mapping in users_with_mappings
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
)
@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
user_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Get a single user by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
return _user_to_scim(user, mapping.external_id if mapping else None)
resource = provider.build_user_resource(
user,
mapping.external_id if mapping else None,
groups=dal.get_user_groups(user.id),
scim_username=mapping.scim_username if mapping else None,
fields=_mapping_to_fields(mapping),
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Create a new user from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
email = user_resource.userName.strip().lower()
email = user_resource.userName.strip()
# externalId is how the IdP correlates this user on subsequent requests.
# Without it, the IdP can't find the user and will try to re-create,
@@ -264,11 +438,26 @@ def create_user(
# Create SCIM mapping (externalId is validated above, always present)
external_id = user_resource.externalId
dal.create_user_mapping(external_id=external_id, user_id=user.id)
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.create_user_mapping(
external_id=external_id,
user_id=user.id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
return _user_to_scim(user, external_id)
return _scim_resource_response(
provider.build_user_resource(
user,
external_id,
scim_username=scim_username,
fields=fields,
),
status_code=201,
)
@scim_router.put("/Users/{user_id}", response_model=None)
@@ -276,14 +465,15 @@ def replace_user(
user_id: str,
user_resource: ScimUserResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Replace a user entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
@@ -293,19 +483,36 @@ def replace_user(
if seat_error:
return _scim_error_response(403, seat_error)
personal_name = _scim_name_to_str(user_resource.name)
dal.update_user(
user,
email=user_resource.userName.strip().lower(),
email=user_resource.userName.strip(),
is_active=user_resource.active,
personal_name=_scim_name_to_str(user_resource.name),
personal_name=personal_name,
)
new_external_id = user_resource.externalId
dal.sync_user_external_id(user.id, new_external_id)
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
dal.sync_user_external_id(
user.id,
new_external_id,
scim_username=scim_username,
fields=fields,
)
dal.commit()
return _user_to_scim(user, new_external_id)
return _scim_resource_response(
provider.build_user_resource(
user,
new_external_id,
groups=dal.get_user_groups(user.id),
scim_username=scim_username,
fields=fields,
)
)
@scim_router.patch("/Users/{user_id}", response_model=None)
@@ -313,8 +520,9 @@ def patch_user(
user_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimUserResource | JSONResponse:
) -> ScimUserResource | ScimJSONResponse:
"""Partially update a user (RFC 7644 §3.5.2).
This is the primary endpoint for user deprovisioning — Okta sends
@@ -324,17 +532,27 @@ def patch_user(
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
mapping = dal.get_user_mapping_by_user_id(user.id)
external_id = mapping.external_id if mapping else None
current_scim_username = mapping.scim_username if mapping else None
current_fields = _mapping_to_fields(mapping)
current = _user_to_scim(user, external_id)
current = provider.build_user_resource(
user,
external_id,
groups=dal.get_user_groups(user.id),
scim_username=current_scim_username,
fields=current_fields,
)
try:
patched = apply_user_patch(patch_request.Operations, current)
patched, ent_data = apply_user_patch(
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -345,22 +563,60 @@ def patch_user(
if seat_error:
return _scim_error_response(403, seat_error)
# Track the scim_username — if userName was patched, update it
new_scim_username = patched.userName.strip() if patched.userName else None
# If displayName was explicitly patched (different from the original), use
# it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None
if patched.displayName and patched.displayName != current.displayName:
personal_name = patched.displayName
else:
personal_name = _scim_name_to_str(patched.name)
dal.update_user(
user,
email=(
patched.userName.strip().lower()
if patched.userName.lower() != user.email
patched.userName.strip()
if patched.userName.strip().lower() != user.email.lower()
else None
),
is_active=patched.active if patched.active != user.is_active else None,
personal_name=_scim_name_to_str(patched.name),
personal_name=personal_name,
)
dal.sync_user_external_id(user.id, patched.externalId)
# Build updated fields by merging PATCH enterprise data with current values
cf = current_fields or ScimMappingFields()
fields = ScimMappingFields(
department=ent_data.get("department", cf.department),
manager=ent_data.get("manager", cf.manager),
given_name=patched.name.givenName if patched.name else cf.given_name,
family_name=patched.name.familyName if patched.name else cf.family_name,
scim_emails_json=(
serialize_emails(patched.emails)
if patched.emails is not None
else cf.scim_emails_json
),
)
dal.sync_user_external_id(
user.id,
patched.externalId,
scim_username=new_scim_username,
fields=fields,
)
dal.commit()
return _user_to_scim(user, patched.externalId)
return _scim_resource_response(
provider.build_user_resource(
user,
patched.externalId,
groups=dal.get_user_groups(user.id),
scim_username=new_scim_username,
fields=fields,
)
)
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
@@ -368,25 +624,29 @@ def delete_user(
user_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
) -> Response | ScimJSONResponse:
"""Delete a user (RFC 7644 §3.6).
Deactivates the user and removes the SCIM mapping. Note that Okta
typically uses PATCH active=false instead of DELETE.
A second DELETE returns 404 per RFC 7644 §3.6.
"""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_user_or_404(user_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
user = result
dal.deactivate_user(user)
# If no SCIM mapping exists, the user was already deleted from
# SCIM's perspective — return 404 per RFC 7644 §3.6.
mapping = dal.get_user_mapping_by_user_id(user.id)
if mapping:
dal.delete_user_mapping(mapping.id)
if not mapping:
return _scim_error_response(404, f"User {user_id} not found")
dal.deactivate_user(user)
dal.delete_user_mapping(mapping.id)
dal.commit()
@@ -398,25 +658,7 @@ def delete_user(
# ---------------------------------------------------------------------------
def _group_to_scim(
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Convert an Onyx UserGroup to a SCIM Group resource."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
"""Parse *group_id* as int, look up the group, or return a 404 error."""
try:
gid = int(group_id)
@@ -471,14 +713,17 @@ def _validate_and_parse_members(
@scim_router.get("/Groups", response_model=None)
def list_groups(
filter: str | None = Query(None),
excludedAttributes: str | None = None,
startIndex: int = Query(1, ge=1),
count: int = Query(100, ge=0, le=500),
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
) -> ScimListResponse | ScimJSONResponse:
"""List groups with optional SCIM filter and pagination."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
try:
scim_filter = parse_scim_filter(filter)
@@ -491,45 +736,59 @@ def list_groups(
return _scim_error_response(400, str(e))
resources: list[ScimUserResource | ScimGroupResource] = [
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
for group, ext_id in groups_with_ext_ids
]
return ScimListResponse(
totalResults=total,
startIndex=startIndex,
itemsPerPage=count,
Resources=resources,
return _build_list_response(
resources,
total,
startIndex,
count,
excluded=_parse_excluded_attributes(excludedAttributes),
)
@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
group_id: str,
excludedAttributes: str | None = None,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Get a single group by ID."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
dal.commit()
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result
mapping = dal.get_group_mapping_by_group_id(group.id)
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, mapping.external_id if mapping else None)
resource = provider.build_group_resource(
group, members, mapping.external_id if mapping else None
)
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
excluded = _parse_excluded_attributes(excludedAttributes)
if excluded:
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
return _scim_resource_response(resource)
@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Create a new group from a SCIM provisioning request."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
@@ -565,7 +824,10 @@ def create_group(
dal.commit()
members = dal.get_group_members(db_group.id)
return _group_to_scim(db_group, members, external_id)
return _scim_resource_response(
provider.build_group_resource(db_group, members, external_id),
status_code=201,
)
@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -573,14 +835,15 @@ def replace_group(
group_id: str,
group_resource: ScimGroupResource,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Replace a group entirely (RFC 7644 §3.5.1)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result
@@ -595,7 +858,9 @@ def replace_group(
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, group_resource.externalId)
return _scim_resource_response(
provider.build_group_resource(group, members, group_resource.externalId)
)
@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -603,8 +868,9 @@ def patch_group(
group_id: str,
patch_request: ScimPatchRequest,
_token: ScimToken = Depends(verify_scim_token),
provider: ScimProvider = Depends(_get_provider),
db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
"""Partially update a group (RFC 7644 §3.5.2).
Handles member add/remove operations from Okta and Azure AD.
@@ -613,7 +879,7 @@ def patch_group(
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result
@@ -621,11 +887,11 @@ def patch_group(
external_id = mapping.external_id if mapping else None
current_members = dal.get_group_members(group.id)
current = _group_to_scim(group, current_members, external_id)
current = provider.build_group_resource(group, current_members, external_id)
try:
patched, added_ids, removed_ids = apply_group_patch(
patch_request.Operations, current
patch_request.Operations, current, provider.ignored_patch_paths
)
except ScimPatchError as e:
return _scim_error_response(e.status, e.detail)
@@ -652,7 +918,9 @@ def patch_group(
dal.commit()
members = dal.get_group_members(group.id)
return _group_to_scim(group, members, patched.externalId)
return _scim_resource_response(
provider.build_group_resource(group, members, patched.externalId)
)
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
@@ -660,13 +928,13 @@ def delete_group(
group_id: str,
_token: ScimToken = Depends(verify_scim_token),
db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
) -> Response | ScimJSONResponse:
"""Delete a group (RFC 7644 §3.6)."""
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
result = _fetch_group_or_404(group_id, dal)
if isinstance(result, JSONResponse):
if isinstance(result, ScimJSONResponse):
return result
group = result

View File

@@ -7,12 +7,14 @@ SCIM protocol schemas follow the wire format defined in:
Admin API schemas are internal to Onyx and used for SCIM token management.
"""
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator
# ---------------------------------------------------------------------------
@@ -31,6 +33,9 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_ENTERPRISE_USER_SCHEMA = (
"urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)
# ---------------------------------------------------------------------------
@@ -63,6 +68,43 @@ class ScimMeta(BaseModel):
location: str | None = None
class ScimUserGroupRef(BaseModel):
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
value: str
display: str | None = None
class ScimManagerRef(BaseModel):
"""Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""
value: str | None = None
class ScimEnterpriseExtension(BaseModel):
"""Enterprise User extension attributes (RFC 7643 §4.3)."""
department: str | None = None
manager: ScimManagerRef | None = None
@dataclass
class ScimMappingFields:
"""Stored SCIM mapping fields that need to round-trip through the IdP.
Entra ID sends structured name components, email metadata, and enterprise
extension attributes that must be returned verbatim in subsequent GET
responses. These fields are persisted on ScimUserMapping and threaded
through the DAL, provider, and endpoint layers.
"""
department: str | None = None
manager: str | None = None
given_name: str | None = None
family_name: str | None = None
scim_emails_json: str | None = None
class ScimUserResource(BaseModel):
"""SCIM User resource representation (RFC 7643 §4.1).
@@ -71,14 +113,22 @@ class ScimUserResource(BaseModel):
to match the SCIM wire format (not Python convention).
"""
model_config = ConfigDict(populate_by_name=True)
schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
id: str | None = None # Onyx's internal user ID, set on responses
externalId: str | None = None # IdP's identifier for this user
userName: str # Typically the user's email address
name: ScimName | None = None
displayName: str | None = None
emails: list[ScimEmail] = Field(default_factory=list)
active: bool = True
groups: list[ScimUserGroupRef] = Field(default_factory=list)
meta: ScimMeta | None = None
enterprise_extension: ScimEnterpriseExtension | None = Field(
default=None,
alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
)
class ScimGroupMember(BaseModel):
@@ -121,12 +171,53 @@ class ScimPatchOperationType(str, Enum):
REMOVE = "remove"
class ScimPatchResourceValue(BaseModel):
"""Partial resource dict for path-less PATCH replace operations.
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
of resource attributes to set. IdPs may include read-only fields
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
stripped by the provider's ``ignored_patch_paths`` before processing.
``extra="allow"`` lets unknown attributes pass through so the patch
handler can decide what to do with them (ignore or reject).
"""
model_config = ConfigDict(extra="allow")
active: bool | None = None
userName: str | None = None
displayName: str | None = None
externalId: str | None = None
name: ScimName | None = None
members: list[ScimGroupMember] | None = None
id: str | None = None
schemas: list[str] | None = None
meta: ScimMeta | None = None
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
class ScimPatchOperation(BaseModel):
"""Single PATCH operation (RFC 7644 §3.5.2)."""
op: ScimPatchOperationType
path: str | None = None
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
value: ScimPatchValue = None
@field_validator("op", mode="before")
@classmethod
def normalize_operation(cls, v: object) -> object:
"""Normalize op to lowercase for case-insensitive matching.
Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
instead of ``"replace"``. This is safe for all providers since the
enum values are lowercase. If a future provider requires other
pre-processing quirks, move patch deserialization into the provider
subclass instead of adding more special cases here.
"""
return v.lower() if isinstance(v, str) else v
class ScimPatchRequest(BaseModel):

View File

@@ -14,13 +14,70 @@ responsible for persisting changes.
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from dataclasses import field
from typing import Any
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource
logger = logging.getLogger(__name__)
# Lowercased enterprise extension URN for case-insensitive matching
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()
# Pattern for email filter paths, e.g.:
# emails[primary eq true].value (Okta)
# emails[type eq "work"].value (Azure AD / Entra ID)
_EMAIL_FILTER_RE = re.compile(
r"^emails\[.+\]\.value$",
re.IGNORECASE,
)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------
_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"active": ("active", "data"),
"username": ("userName", "data"),
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
}
_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
"externalid": ("externalId", "data"),
"name.givenname": ("givenName", "name"),
"name.familyname": ("familyName", "name"),
"name.formatted": ("formatted", "name"),
"displayname": ("displayName", "data"),
}
_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
"displayname": ("displayName", "data"),
"externalid": ("externalId", "data"),
}
class ScimPatchError(Exception):
"""Raised when a PATCH operation cannot be applied."""
@@ -31,94 +88,223 @@ class ScimPatchError(Exception):
super().__init__(detail)
# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
r'^members\[value\s+eq\s+"([^"]+)"\]$',
re.IGNORECASE,
)
@dataclass
class _UserPatchCtx:
"""Bundles the mutable state for user PATCH operations."""
data: dict[str, Any]
name_data: dict[str, Any]
ent_data: dict[str, str | None] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------
def apply_user_patch(
operations: list[ScimPatchOperation],
current: ScimUserResource,
) -> ScimUserResource:
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimUserResource, dict[str, str | None]]:
"""Apply SCIM PATCH operations to a user resource.
Returns a new ``ScimUserResource`` with the modifications applied.
The original object is not mutated.
Args:
operations: The PATCH operations to apply.
current: The current user resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified user resource, enterprise extension data dict).
The enterprise dict has keys ``"department"`` and ``"manager"``
with values set only when a PATCH operation touched them.
Raises:
ScimPatchError: If an operation targets an unsupported path.
"""
data = current.model_dump()
name_data = data.get("name") or {}
ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_user_replace(op, data, name_data)
elif op.op == ScimPatchOperationType.ADD:
_apply_user_replace(op, data, name_data)
if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
_apply_user_replace(op, ctx, ignored_paths)
elif op.op == ScimPatchOperationType.REMOVE:
_apply_user_remove(op, ctx, ignored_paths)
else:
raise ScimPatchError(
f"Unsupported operation '{op.op.value}' on User resource"
)
data["name"] = name_data
return ScimUserResource.model_validate(data)
ctx.data["name"] = ctx.name_data
return ScimUserResource.model_validate(ctx.data), ctx.ent_data
def _apply_user_replace(
op: ScimPatchOperation,
data: dict,
name_data: dict,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace/add operation to user data."""
path = (op.path or "").lower()
if not path:
# No path — value is a dict of top-level attributes to set
if isinstance(op.value, dict):
for key, val in op.value.items():
_set_user_field(key.lower(), val, data, name_data)
# No path — value is a resource dict of top-level attributes to set.
if isinstance(op.value, ScimPatchResourceValue):
for key, val in op.value.model_dump(exclude_unset=True).items():
_set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
_set_user_field(path, op.value, data, name_data)
_set_user_field(path, op.value, ctx, ignored_paths)
def _apply_user_remove(
op: ScimPatchOperation,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
) -> None:
"""Apply a remove operation to user data — clears the target field."""
path = (op.path or "").lower()
if not path:
raise ScimPatchError("Remove operation requires a path")
if path in ignored_paths:
return
entry = _USER_REMOVE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = None
return
raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")
def _set_user_field(
path: str,
value: str | bool | dict | list | None,
data: dict,
name_data: dict,
value: ScimPatchValue,
ctx: _UserPatchCtx,
ignored_paths: frozenset[str],
*,
strict: bool = True,
) -> None:
"""Set a single field on user data by SCIM path."""
if path == "active":
data["active"] = value
elif path == "username":
data["userName"] = value
elif path == "externalid":
data["externalId"] = value
elif path == "name.givenname":
name_data["givenName"] = value
elif path == "name.familyname":
name_data["familyName"] = value
elif path == "name.formatted":
name_data["formatted"] = value
elif path == "displayname":
# Some IdPs send displayName on users; map to formatted name
name_data["formatted"] = value
"""Set a single field on user data by SCIM path.
Args:
strict: When ``False`` (path-less replace), unknown attributes are
silently skipped. When ``True`` (explicit path), they raise.
"""
if path in ignored_paths:
return
# Simple field writes handled by the dispatch table
entry = _USER_REPLACE_PATHS.get(path)
if entry:
key, target = entry
target_dict = ctx.data if target == "data" else ctx.name_data
target_dict[key] = value
return
# displayName sets both the top-level field and the name.formatted sub-field
if path == "displayname":
ctx.data["displayName"] = value
ctx.name_data["formatted"] = value
elif path == "name":
if isinstance(value, dict):
for k, v in value.items():
ctx.name_data[k] = v
elif path == "emails":
if isinstance(value, list):
ctx.data["emails"] = value
elif _EMAIL_FILTER_RE.match(path):
_update_primary_email(ctx.data, value)
elif path.startswith(_ENTERPRISE_URN_LOWER):
_set_enterprise_field(path, value, ctx.ent_data)
elif not strict:
return
else:
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
"""Update the primary email entry via an email filter path."""
emails: list[dict] = data.get("emails") or []
for email_entry in emails:
if email_entry.get("primary"):
email_entry["value"] = value
break
else:
emails.append({"value": value, "type": "work", "primary": True})
data["emails"] = emails
def _to_dict(value: ScimPatchValue) -> dict | None:
"""Coerce a SCIM patch value to a plain dict if possible.
Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
``extra="allow"``), so we also dump those back to a dict.
"""
if isinstance(value, dict):
return value
if isinstance(value, ScimPatchResourceValue):
return value.model_dump(exclude_unset=True)
return None
def _set_enterprise_field(
path: str,
value: ScimPatchValue,
ent_data: dict[str, str | None],
) -> None:
"""Handle enterprise extension URN paths or value dicts."""
# Full URN as key with dict value (path-less PATCH)
# e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
if path == _ENTERPRISE_URN_LOWER:
d = _to_dict(value)
if d is not None:
if "department" in d:
ent_data["department"] = d["department"]
if "manager" in d:
mgr = d["manager"]
if isinstance(mgr, dict):
ent_data["manager"] = mgr.get("value")
return
# Dotted URN path, e.g. "urn:...:user:department"
suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
if suffix == "department":
ent_data["department"] = str(value) if value is not None else None
elif suffix == "manager":
d = _to_dict(value)
if d is not None:
ent_data["manager"] = d.get("value")
elif isinstance(value, str):
ent_data["manager"] = value
else:
# Unknown enterprise attributes are silently ignored rather than
# rejected — IdPs may send attributes we don't model yet.
logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)
# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------
def apply_group_patch(
operations: list[ScimPatchOperation],
current: ScimGroupResource,
ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
"""Apply SCIM PATCH operations to a group resource.
Args:
operations: The PATCH operations to apply.
current: The current group resource state.
ignored_paths: SCIM attribute paths to silently skip (from provider).
Returns:
A tuple of (modified group, added member IDs, removed member IDs).
The caller uses the member ID lists to update the database.
@@ -133,7 +319,9 @@ def apply_group_patch(
for op in operations:
if op.op == ScimPatchOperationType.REPLACE:
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
_apply_group_replace(
op, data, current_members, added_ids, removed_ids, ignored_paths
)
elif op.op == ScimPatchOperationType.ADD:
_apply_group_add(op, current_members, added_ids)
elif op.op == ScimPatchOperationType.REMOVE:
@@ -154,38 +342,48 @@ def _apply_group_replace(
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
ignored_paths: frozenset[str],
) -> None:
"""Apply a replace operation to group data."""
path = (op.path or "").lower()
if not path:
if isinstance(op.value, dict):
for key, val in op.value.items():
if isinstance(op.value, ScimPatchResourceValue):
dumped = op.value.model_dump(exclude_unset=True)
for key, val in dumped.items():
if key.lower() == "members":
_replace_members(val, current_members, added_ids, removed_ids)
else:
_set_group_field(key.lower(), val, data)
_set_group_field(key.lower(), val, data, ignored_paths)
else:
raise ScimPatchError("Replace without path requires a dict value")
return
if path == "members":
_replace_members(op.value, current_members, added_ids, removed_ids)
_replace_members(
_members_to_dicts(op.value), current_members, added_ids, removed_ids
)
return
_set_group_field(path, op.value, data)
_set_group_field(path, op.value, data, ignored_paths)
def _members_to_dicts(
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
"""Convert a member list value to a list of dicts for internal processing."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
return [m.model_dump(exclude_none=True) for m in value]
def _replace_members(
value: str | list | dict | bool | None,
value: list[dict],
current_members: list[dict],
added_ids: list[str],
removed_ids: list[str],
) -> None:
"""Replace the entire group member list."""
if not isinstance(value, list):
raise ScimPatchError("Replace members requires a list value")
old_ids = {m["value"] for m in current_members}
new_ids = {m.get("value", "") for m in value}
@@ -197,16 +395,21 @@ def _replace_members(
def _set_group_field(
path: str,
value: str | bool | dict | list | None,
value: ScimPatchValue,
data: dict,
ignored_paths: frozenset[str],
) -> None:
"""Set a single field on group data by SCIM path."""
if path == "displayname":
data["displayName"] = value
elif path == "externalid":
data["externalId"] = value
else:
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
if path in ignored_paths:
return
entry = _GROUP_REPLACE_PATHS.get(path)
if entry:
key, _ = entry
data[key] = value
return
raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
def _apply_group_add(
@@ -223,8 +426,10 @@ def _apply_group_add(
if not isinstance(op.value, list):
raise ScimPatchError("Add members requires a list value")
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
existing_ids = {m["value"] for m in members}
for member_data in op.value:
for member_data in member_dicts:
member_id = member_data.get("value", "")
if member_id and member_id not in existing_ids:
members.append(member_data)

View File

@@ -0,0 +1,210 @@
"""Base SCIM provider abstraction."""
from __future__ import annotations
import json
import logging
from abc import ABC
from abc import abstractmethod
from uuid import UUID
from pydantic import ValidationError
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup
logger = logging.getLogger(__name__)
COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
{
"id",
"schemas",
"meta",
}
)
class ScimProvider(ABC):
"""Base class for provider-specific SCIM behavior.
Subclass this to handle IdP-specific quirks. The base class provides
RFC 7643-compliant response builders that populate all standard fields.
"""
@property
@abstractmethod
def name(self) -> str:
"""Short identifier for this provider (e.g. ``"okta"``)."""
...
@property
@abstractmethod
def ignored_patch_paths(self) -> frozenset[str]:
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
IdPs may include read-only or meta fields alongside actual changes
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
here are silently dropped instead of raising an error.
"""
...
@property
def user_schemas(self) -> list[str]:
"""Schema URIs to include in User resource responses.
Override in subclasses to advertise additional schemas (e.g. the
enterprise extension for Entra ID).
"""
return [SCIM_USER_SCHEMA]
def build_user_resource(
self,
user: User,
external_id: str | None = None,
groups: list[tuple[int, str]] | None = None,
scim_username: str | None = None,
fields: ScimMappingFields | None = None,
) -> ScimUserResource:
"""Build a SCIM User response from an Onyx User.
Args:
user: The Onyx user model.
external_id: The IdP's external identifier for this user.
groups: List of ``(group_id, group_name)`` tuples for the
``groups`` read-only attribute. Pass ``None`` or ``[]``
for newly-created users.
scim_username: The original-case userName from the IdP. Falls
back to ``user.email`` (lowercase) when not available.
fields: Stored mapping fields that the IdP expects round-tripped.
"""
f = fields or ScimMappingFields()
group_refs = [
ScimUserGroupRef(value=str(gid), display=gname)
for gid, gname in (groups or [])
]
username = scim_username or user.email
# Build enterprise extension when at least one value is present.
# Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
enterprise_ext: ScimEnterpriseExtension | None = None
schemas = list(self.user_schemas)
if f.department is not None or f.manager is not None:
manager_ref = (
ScimManagerRef(value=f.manager) if f.manager is not None else None
)
enterprise_ext = ScimEnterpriseExtension(
department=f.department,
manager=manager_ref,
)
if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)
name = self.build_scim_name(user, f)
emails = _deserialize_emails(f.scim_emails_json, username)
resource = ScimUserResource(
schemas=schemas,
id=str(user.id),
externalId=external_id,
userName=username,
name=name,
displayName=user.personal_name,
emails=emails,
active=user.is_active,
groups=group_refs,
meta=ScimMeta(resourceType="User"),
)
resource.enterprise_extension = enterprise_ext
return resource
def build_group_resource(
self,
group: UserGroup,
members: list[tuple[UUID, str | None]],
external_id: str | None = None,
) -> ScimGroupResource:
"""Build a SCIM Group response from an Onyx UserGroup."""
scim_members = [
ScimGroupMember(value=str(uid), display=email) for uid, email in members
]
return ScimGroupResource(
id=str(group.id),
externalId=external_id,
displayName=group.name,
members=scim_members,
meta=ScimMeta(resourceType="Group"),
)
def build_scim_name(
self,
user: User,
fields: ScimMappingFields,
) -> ScimName | None:
"""Build SCIM name components for the response.
Round-trips stored ``given_name``/``family_name`` when available (so
the IdP gets back what it sent). Falls back to splitting
``personal_name`` for users provisioned before we stored components.
Providers may override for custom behavior.
"""
if fields.given_name is not None or fields.family_name is not None:
return ScimName(
givenName=fields.given_name,
familyName=fields.family_name,
formatted=user.personal_name,
)
if not user.personal_name:
return None
parts = user.personal_name.split(" ", 1)
return ScimName(
givenName=parts[0],
familyName=parts[1] if len(parts) > 1 else None,
formatted=user.personal_name,
)
def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
"""Deserialize stored email entries or build a default work email."""
if stored_json:
try:
entries = json.loads(stored_json)
if isinstance(entries, list) and entries:
return [ScimEmail(**e) for e in entries]
except (json.JSONDecodeError, TypeError, ValidationError):
logger.warning(
"Corrupt scim_emails_json, falling back to default: %s", stored_json
)
return [ScimEmail(value=username, type="work", primary=True)]
def serialize_emails(emails: list[ScimEmail]) -> str | None:
"""Serialize SCIM email entries to JSON for storage."""
if not emails:
return None
return json.dumps([e.model_dump(exclude_none=True) for e in emails])
def get_default_provider() -> ScimProvider:
"""Return the default SCIM provider.
Currently returns ``OktaProvider`` since Okta is the primary supported
IdP. When provider detection is added (via token metadata or tenant
config), this can be replaced with dynamic resolution.
"""
from ee.onyx.server.scim.providers.okta import OktaProvider
return OktaProvider()

View File

@@ -0,0 +1,36 @@
"""Entra ID (Azure AD) SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS
class EntraProvider(ScimProvider):
"""Entra ID (Azure AD) SCIM provider.
Entra behavioral notes:
- Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
— handled by ``ScimPatchOperation.normalize_op`` validator.
- Sends the enterprise extension URN as a key in path-less PATCH value
dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
store department/manager values.
- Expects the enterprise extension schema in ``schemas`` arrays and
``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
"""
@property
def name(self) -> str:
return "entra"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return _ENTRA_IGNORED_PATCH_PATHS
@property
def user_schemas(self) -> list[str]:
return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]

View File

@@ -0,0 +1,26 @@
"""Okta SCIM provider."""
from __future__ import annotations
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider
class OktaProvider(ScimProvider):
"""Okta SCIM provider.
Okta behavioral notes:
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
- Sends path-less PATCH with value dicts containing extra fields
(``id``, ``schemas``)
- Expects ``displayName`` and ``groups`` in user responses
- Only uses ``eq`` operator for ``userName`` filter
"""
@property
def name(self) -> str:
return "okta"
@property
def ignored_patch_paths(self) -> frozenset[str]:
return COMMON_IGNORED_PATCH_PATHS

View File

@@ -4,6 +4,7 @@ Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""
from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
@@ -20,6 +21,9 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
"endpoint": "/scim/v2/Users",
"description": "SCIM User resource",
"schema": SCIM_USER_SCHEMA,
"schemaExtensions": [
{"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
],
}
)
@@ -104,6 +108,31 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
],
)
ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_ENTERPRISE_USER_SCHEMA,
name="EnterpriseUser",
description="Enterprise User extension (RFC 7643 §4.3)",
attributes=[
ScimSchemaAttribute(
name="department",
type="string",
description="Department.",
),
ScimSchemaAttribute(
name="manager",
type="complex",
description="The user's manager.",
subAttributes=[
ScimSchemaAttribute(
name="value",
type="string",
description="Manager user ID.",
),
],
),
],
)
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
id=SCIM_GROUP_SCHEMA,
name="Group",

View File

@@ -37,12 +37,15 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -58,16 +58,27 @@ class OAuthTokenManager:
if not user_token.token_data:
raise ValueError("No token data available for refresh")
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for token refresh"
)
token_data = self._unwrap_token_data(user_token.token_data)
data: dict[str, str] = {
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "refresh_token",
"refresh_token": token_data["refresh_token"],
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -115,15 +126,26 @@ class OAuthTokenManager:
def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
"""Exchange authorization code for access token"""
if (
self.oauth_config.client_id is None
or self.oauth_config.client_secret is None
):
raise ValueError(
"OAuth client_id and client_secret are required for code exchange"
)
data: dict[str, str] = {
"grant_type": "authorization_code",
"code": code,
"client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
"client_secret": self._unwrap_sensitive_str(
self.oauth_config.client_secret
),
"redirect_uri": redirect_uri,
}
response = requests.post(
self.oauth_config.token_url,
data={
"grant_type": "authorization_code",
"code": code,
"client_id": self.oauth_config.client_id,
"client_secret": self.oauth_config.client_secret,
"redirect_uri": redirect_uri,
},
data=data,
headers={"Accept": "application/json"},
)
response.raise_for_status()
@@ -141,8 +163,13 @@ class OAuthTokenManager:
oauth_config: OAuthConfig, redirect_uri: str, state: str
) -> str:
"""Build OAuth authorization URL"""
if oauth_config.client_id is None:
raise ValueError("OAuth client_id is required to build authorization URL")
params: dict[str, Any] = {
"client_id": oauth_config.client_id,
"client_id": OAuthTokenManager._unwrap_sensitive_str(
oauth_config.client_id
),
"redirect_uri": redirect_uri,
"response_type": "code",
"state": state,
@@ -161,6 +188,12 @@ class OAuthTokenManager:
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
if isinstance(value, SensitiveValue):
return value.get_value(apply_mask=False)
return value
@staticmethod
def _unwrap_token_data(
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],

View File

@@ -277,13 +277,32 @@ def verify_email_domain(email: str) -> None:
detail="Email is not valid",
)
domain = email.split("@")[-1].lower()
local_part, domain = email.split("@")
domain = domain.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
)
# Check domain whitelist if configured
@@ -1671,7 +1690,10 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
next_url = request.query_params.get("next", "/")

View File

@@ -48,6 +48,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -149,8 +150,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchDocumentIndex(
index_name=search_settings.index_name, tenant_state=tenant_state
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
)
vespa_document_index = VespaDocumentIndex(
index_name=search_settings.index_name,

View File

@@ -22,6 +22,7 @@ from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
@@ -58,6 +59,7 @@ FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PERSONAS,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
@@ -276,6 +278,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
)
)
user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT)
personas: list[int] | None = vespa_chunk.get(PERSONAS)
primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS)
secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS)
@@ -325,6 +328,7 @@ def transform_vespa_chunks_to_opensearch_chunks(
metadata_suffix=metadata_suffix,
document_sets=document_sets,
user_projects=user_projects,
personas=personas,
primary_owners=primary_owners,
secondary_owners=secondary_owners,
tenant_id=tenant_state,

View File

@@ -5,14 +5,18 @@ from uuid import UUID
import httpx
import sqlalchemy as sa
from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from retry import retry
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -21,12 +25,16 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -57,14 +65,73 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}"
def _user_file_delete_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}"
def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int:
redis_celery: Redis = celery_app.broker_connection().channel().client # type: ignore
return celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery
)
def enqueue_user_file_project_sync_task(
*,
celery_app: Celery,
redis_client: Redis,
user_file_id: str | UUID,
tenant_id: str,
priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH,
) -> bool:
"""Enqueue a project-sync task if no matching queued task already exists."""
queued_key = _user_file_project_sync_queued_key(user_file_id)
# NX+EX gives us atomic dedupe and a self-healing TTL.
queued_guard_set = redis_client.set(
queued_key,
1,
nx=True,
ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
if not queued_guard_set:
return False
try:
celery_app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
priority=priority,
expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES,
)
except Exception:
# Roll back the queued guard if task publish fails.
redis_client.delete(queued_key)
raise
return True
@retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
def _visit_chunks(
*,
@@ -120,7 +187,24 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Uses direct Redis locks to avoid overlapping runs.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -135,7 +219,21 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -148,12 +246,35 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -161,7 +282,8 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
)
return None
@@ -304,6 +426,12 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
@@ -557,8 +685,8 @@ def process_single_user_file_delete(
ignore_result=True,
)
def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROJECT_SYNC status and enqueue per-file tasks."""
task_logger.info("check_for_user_file_project_sync - Starting")
"""Scan for user files needing project sync and enqueue per-file tasks."""
task_logger.info("Starting")
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
@@ -570,13 +698,25 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
queue_depth = get_user_file_project_sync_queue_depth(self.app)
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
task_logger.warning(
f"Queue depth {queue_depth} exceeds "
f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
select(UserFile.id).where(
sa.and_(
UserFile.needs_project_sync.is_(True),
sa.or_(
UserFile.needs_project_sync.is_(True),
UserFile.needs_persona_sync.is_(True),
),
UserFile.status == UserFileStatus.COMPLETED,
)
)
@@ -586,19 +726,23 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
if not enqueue_user_file_project_sync_task(
celery_app=self.app,
redis_client=redis_client,
user_file_id=user_file_id,
tenant_id=tenant_id,
priority=OnyxCeleryPriority.HIGH,
)
):
skipped_guard += 1
continue
enqueued += 1
finally:
if lock.owned():
lock.release()
task_logger.info(
f"check_for_user_file_project_sync - Enqueued {enqueued} tasks for tenant={tenant_id}"
f"Enqueued {enqueued} "
f"Skipped guard {skipped_guard} tasks for tenant={tenant_id}"
)
return None
@@ -617,8 +761,10 @@ def process_single_user_file_project_sync(
)
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_project_sync_lock_key(user_file_id),
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
@@ -630,7 +776,11 @@ def process_single_user_file_project_sync(
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
user_file = db_session.execute(
select(UserFile)
.where(UserFile.id == _as_uuid(user_file_id))
.options(selectinload(UserFile.assistants))
).scalar_one_or_none()
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
@@ -658,13 +808,17 @@ def process_single_user_file_project_sync(
]
project_ids = [project.id for project in user_file.projects]
persona_ids = [p.id for p in user_file.assistants if not p.deleted]
for retry_document_index in retry_document_indices:
retry_document_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
fields=None,
user_fields=VespaDocumentUserFields(user_projects=project_ids),
user_fields=VespaDocumentUserFields(
user_projects=project_ids,
personas=persona_ids,
),
)
task_logger.info(
@@ -672,6 +826,7 @@ def process_single_user_file_project_sync(
)
user_file.needs_project_sync = False
user_file.needs_persona_sync = False
user_file.last_project_sync_at = datetime.datetime.now(
datetime.timezone.utc
)

View File

@@ -58,6 +58,8 @@ from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.postgres_sanitization import sanitize_document_for_postgres
from onyx.indexing.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import get_node_id_from_raw_id
@@ -156,36 +158,7 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
logger.warning(
f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}"
)
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
cleaned_batch.append(sanitize_document_for_postgres(doc))
return cleaned_batch
@@ -602,10 +575,13 @@ def connector_document_extraction(
# Process hierarchy nodes batch - upsert to Postgres and cache in Redis
if hierarchy_node_batch:
hierarchy_node_batch_cleaned = (
sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch)
)
with get_session_with_current_tenant() as db_session:
upserted_nodes = upsert_hierarchy_nodes_batch(
db_session=db_session,
nodes=hierarchy_node_batch,
nodes=hierarchy_node_batch_cleaned,
source=db_connector.source,
commit=True,
is_connector_public=is_connector_public,
@@ -624,7 +600,7 @@ def connector_document_extraction(
)
logger.debug(
f"Persisted and cached {len(hierarchy_node_batch)} hierarchy nodes "
f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes "
f"for attempt={index_attempt_id}"
)

View File

@@ -3,7 +3,6 @@ import time
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
@@ -163,13 +162,11 @@ class ChatStateContainer:
def run_chat_loop_with_state_containers(
func: Callable[..., None],
chat_loop_func: Callable[[Emitter, ChatStateContainer], None],
completion_callback: Callable[[ChatStateContainer], None],
is_connected: Callable[[], bool],
emitter: Emitter,
state_container: ChatStateContainer,
*args: Any,
**kwargs: Any,
) -> Generator[Packet, None]:
"""
Explicit wrapper function that runs a function in a background thread
@@ -180,19 +177,18 @@ def run_chat_loop_with_state_containers(
Args:
func: The function to wrap (should accept emitter and state_container as first and second args)
completion_callback: Callback function to call when the function completes
emitter: Emitter instance for sending packets
state_container: ChatStateContainer instance for accumulating state
is_connected: Callable that returns False when stop signal is set
*args: Additional positional arguments for func
**kwargs: Additional keyword arguments for func
Usage:
packets = run_chat_loop_with_state_containers(
my_func,
completion_callback=completion_callback,
emitter=emitter,
state_container=state_container,
is_connected=check_func,
arg1, arg2, kwarg1=value1
)
for packet in packets:
# Process packets
@@ -201,9 +197,7 @@ def run_chat_loop_with_state_containers(
def run_with_exception_capture() -> None:
try:
# Ensure state_container is passed explicitly, removing it from kwargs if present
kwargs_with_state = {**kwargs, "state_container": state_container}
func(emitter, *args, **kwargs_with_state)
chat_loop_func(emitter, state_container)
except Exception as e:
# If execution fails, emit an exception packet
emitter.emit(

View File

@@ -1,3 +1,4 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -45,6 +46,7 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -422,10 +424,44 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
project_image_files: list[ChatLoadedFile],
context_image_files: list[ChatLoadedFile],
additional_context: str | None,
token_counter: Callable[[str], int],
tool_id_to_name_map: dict[int, str],
@@ -505,11 +541,11 @@ def convert_chat_history(
)
# Add the user message with image files attached
# If this is the last USER message, also include project_image_files
# Note: project image file tokens are NOT counted in the token count
# If this is the last USER message, also include context_image_files
# Note: context image file tokens are NOT counted in the token count
if idx == last_user_message_idx:
if project_image_files:
image_files.extend(project_image_files)
if context_image_files:
image_files.extend(context_image_files)
if additional_context:
simple_messages.append(
@@ -582,10 +618,24 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -15,10 +15,10 @@ from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.chat.llm_step import run_llm_step
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import LlmStepResult
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ToolCallSimple
from onyx.chat.prompt_utils import build_reminder_message
from onyx.chat.prompt_utils import build_system_prompt
@@ -30,6 +30,7 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.memory import add_memory
from onyx.db.memory import update_memory_at_index
from onyx.db.memory import UserMemoryContext
@@ -57,6 +58,7 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -201,17 +203,17 @@ def _try_fallback_tool_extraction(
MAX_LLM_CYCLES = 6
def _build_project_file_citation_mapping(
project_file_metadata: list[ProjectFileMetadata],
def _build_context_file_citation_mapping(
file_metadata: list[ContextFileMetadata],
starting_citation_num: int = 1,
) -> CitationMapping:
"""Build citation mapping for project files.
"""Build citation mapping for context files.
Converts project file metadata into SearchDoc objects that can be cited.
Converts context file metadata into SearchDoc objects that can be cited.
Citation numbers start from the provided starting number.
Args:
project_file_metadata: List of project file metadata
file_metadata: List of context file metadata
starting_citation_num: Starting citation number (default: 1)
Returns:
@@ -219,8 +221,7 @@ def _build_project_file_citation_mapping(
"""
citation_mapping: CitationMapping = {}
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
# Create a SearchDoc for each project file
for idx, file_meta in enumerate(file_metadata, start=starting_citation_num):
search_doc = SearchDoc(
document_id=file_meta.file_id,
chunk_ind=0,
@@ -240,29 +241,28 @@ def _build_project_file_citation_mapping(
def _build_project_message(
project_files: ExtractedProjectFiles | None,
context_files: ExtractedContextFiles | None,
token_counter: Callable[[str], int] | None,
) -> list[ChatMessageSimple]:
"""Build messages for project / tool-backed files.
"""Build messages for context-injected / tool-backed files.
Returns up to two messages:
1. The full-text project files message (if project_file_texts is populated).
1. The full-text files message (if file_texts is populated).
2. A lightweight metadata message for files the LLM should access via the
FileReaderTool (e.g. oversized chat-attached files or project files that
don't fit in context).
FileReaderTool (e.g. oversized files that don't fit in context).
"""
if not project_files:
if not context_files:
return []
messages: list[ChatMessageSimple] = []
if project_files.project_file_texts:
if context_files.file_texts:
messages.append(
_create_project_files_message(project_files, token_counter=None)
_create_context_files_message(context_files, token_counter=None)
)
if project_files.file_metadata_for_tool and token_counter:
if context_files.file_metadata_for_tool and token_counter:
messages.append(
_create_file_tool_metadata_message(
project_files.file_metadata_for_tool, token_counter
context_files.file_metadata_for_tool, token_counter
)
)
return messages
@@ -273,7 +273,7 @@ def construct_message_history(
custom_agent_prompt: ChatMessageSimple | None,
simple_chat_history: list[ChatMessageSimple],
reminder_message: ChatMessageSimple | None,
project_files: ExtractedProjectFiles | None,
context_files: ExtractedContextFiles | None,
available_tokens: int,
last_n_user_messages: int | None = None,
token_counter: Callable[[str], int] | None = None,
@@ -287,7 +287,7 @@ def construct_message_history(
# Build the project / file-metadata messages up front so we can use their
# actual token counts for the budget.
project_messages = _build_project_message(project_files, token_counter)
project_messages = _build_project_message(context_files, token_counter)
project_messages_tokens = sum(m.token_count for m in project_messages)
history_token_budget = available_tokens
@@ -443,17 +443,17 @@ def construct_message_history(
)
# Attach project images to the last user message
if project_files and project_files.project_image_files:
if context_files and context_files.image_files:
existing_images = last_user_message.image_files or []
last_user_message = ChatMessageSimple(
message=last_user_message.message,
token_count=last_user_message.token_count,
message_type=last_user_message.message_type,
image_files=existing_images + project_files.project_image_files,
image_files=existing_images + context_files.image_files,
)
# Build the final message list according to README ordering:
# [system], [history_before_last_user], [custom_agent], [project_files],
# [system], [history_before_last_user], [custom_agent], [context_files],
# [forgotten_files], [last_user_message], [messages_after_last_user], [reminder]
result = [system_prompt] if system_prompt else []
@@ -464,14 +464,14 @@ def construct_message_history(
if custom_agent_prompt:
result.append(custom_agent_prompt)
# 3. Add project files / file-metadata messages (inserted before last user message)
# 3. Add context files / file-metadata messages (inserted before last user message)
result.extend(project_messages)
# 4. Add forgotten-files metadata (right before the user's question)
if forgotten_files_message:
result.append(forgotten_files_message)
# 5. Add last user message (with project images attached)
# 5. Add last user message (with context images attached)
result.append(last_user_message)
# 6. Add messages after last user message (tool calls, responses, etc.)
@@ -545,11 +545,11 @@ def _create_file_tool_metadata_message(
)
def _create_project_files_message(
project_files: ExtractedProjectFiles,
def _create_context_files_message(
context_files: ExtractedContextFiles,
token_counter: Callable[[str], int] | None, # noqa: ARG001
) -> ChatMessageSimple:
"""Convert project files to a ChatMessageSimple message.
"""Convert context files to a ChatMessageSimple message.
Format follows the README specification for document representation.
"""
@@ -557,7 +557,7 @@ def _create_project_files_message(
# Format as documents JSON as described in README
documents_list = []
for idx, file_text in enumerate(project_files.project_file_texts, start=1):
for idx, file_text in enumerate(context_files.file_texts, start=1):
documents_list.append(
{
"document": idx,
@@ -568,10 +568,10 @@ def _create_project_files_message(
documents_json = json.dumps({"documents": documents_list}, indent=2)
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
# Use pre-calculated token count from project_files
# Use pre-calculated token count from context_files
return ChatMessageSimple(
message=message_content,
token_count=project_files.total_token_count,
token_count=context_files.total_token_count,
message_type=MessageType.USER,
)
@@ -582,7 +582,7 @@ def run_llm_loop(
simple_chat_history: list[ChatMessageSimple],
tools: list[Tool],
custom_agent_prompt: str | None,
project_files: ExtractedProjectFiles,
context_files: ExtractedContextFiles,
persona: Persona | None,
user_memory_context: UserMemoryContext | None,
llm: LLM,
@@ -625,9 +625,9 @@ def run_llm_loop(
# Add project file citation mappings if project files are present
project_citation_mapping: CitationMapping = {}
if project_files.project_file_metadata:
project_citation_mapping = _build_project_file_citation_mapping(
project_files.project_file_metadata
if context_files.file_metadata:
project_citation_mapping = _build_context_file_citation_mapping(
context_files.file_metadata
)
citation_processor.update_citation_mapping(project_citation_mapping)
@@ -645,16 +645,22 @@ def run_llm_loop(
# TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it.
# One future workaround is to include the images as separate user messages with citation information and process those.
always_cite_documents: bool = bool(
project_files.project_as_filter or project_files.project_file_texts
context_files.use_as_search_filter or context_files.file_texts
)
should_cite_documents: bool = False
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
default_base_system_prompt: str = get_default_base_system_prompt(db_session)
# Fetch this in a short-lived session so the long-running stream loop does
# not pin a connection just to keep read state alive.
with get_session_with_current_tenant() as prompt_db_session:
default_base_system_prompt: str = get_default_base_system_prompt(
prompt_db_session
)
system_prompt = None
custom_agent_prompt_msg = None
@@ -761,6 +767,7 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -779,7 +786,7 @@ def run_llm_loop(
custom_agent_prompt=custom_agent_prompt_msg,
simple_chat_history=simple_chat_history,
reminder_message=reminder_msg,
project_files=project_files,
context_files=context_files,
available_tokens=available_tokens,
token_counter=token_counter,
all_injected_file_metadata=all_injected_file_metadata,
@@ -900,6 +907,18 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -31,13 +31,6 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ProjectSearchConfig(BaseModel):
"""Configuration for search tool availability in project context."""
search_usage: SearchToolUsage
disable_forced_tool: bool
class CreateChatSessionID(BaseModel):
chat_session_id: UUID
@@ -132,8 +125,8 @@ class ChatMessageSimple(BaseModel):
file_id: str | None = None
class ProjectFileMetadata(BaseModel):
"""Metadata for a project file to enable citation support."""
class ContextFileMetadata(BaseModel):
"""Metadata for a context-injected file to enable citation support."""
file_id: str
filename: str
@@ -167,20 +160,28 @@ class ChatHistoryResult(BaseModel):
all_injected_file_metadata: dict[str, FileToolMetadata]
class ExtractedProjectFiles(BaseModel):
project_file_texts: list[str]
project_image_files: list[ChatLoadedFile]
project_as_filter: bool
class ExtractedContextFiles(BaseModel):
"""Result of attempting to load user files (from a project or persona) into context."""
file_texts: list[str]
image_files: list[ChatLoadedFile]
use_as_search_filter: bool
total_token_count: int
# Metadata for project files to enable citations
project_file_metadata: list[ProjectFileMetadata]
# None if not a project
project_uncapped_token_count: int | None
# Lightweight metadata for files exposed via FileReaderTool
# (populated when files don't fit in context and vector DB is disabled)
# (populated when files don't fit in context and vector DB is disabled).
file_metadata: list[ContextFileMetadata]
uncapped_token_count: int | None
file_metadata_for_tool: list[FileToolMetadata] = []
class SearchParams(BaseModel):
"""Resolved search filter IDs and search-tool usage for a chat turn."""
search_project_id: int | None
search_persona_id: int | None
search_usage: SearchToolUsage
class LlmStepResult(BaseModel):
reasoning: str | None
answer: str | None

View File

@@ -3,6 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
An overview can be found in the README.md file in this directory.
"""
import io
import re
import traceback
from collections.abc import Callable
@@ -33,11 +34,11 @@ from onyx.chat.models import ChatBasicResponse
from onyx.chat.models import ChatFullResponse
from onyx.chat.models import ChatLoadedFile
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ContextFileMetadata
from onyx.chat.models import CreateChatSessionID
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import ExtractedContextFiles
from onyx.chat.models import FileToolMetadata
from onyx.chat.models import ProjectFileMetadata
from onyx.chat.models import ProjectSearchConfig
from onyx.chat.models import SearchParams
from onyx.chat.models import StreamingError
from onyx.chat.models import ToolCallResponse
from onyx.chat.prompt_utils import calculate_reserved_tokens
@@ -62,11 +63,12 @@ from onyx.db.models import ChatSession
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.projects import get_project_token_count
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_for_persona
@@ -139,12 +141,12 @@ def _collect_available_file_ids(
pass
if project_id:
project_files = get_user_files_from_project(
user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
for uf in project_files:
for uf in user_files:
user_file_ids.add(uf.id)
return _AvailableFiles(
@@ -192,9 +194,67 @@ def _convert_loaded_files_to_chat_files(
return chat_files
def _extract_project_file_texts_and_images(
def resolve_context_user_files(
persona: Persona,
project_id: int | None,
user_id: UUID | None,
db_session: Session,
) -> list[UserFile]:
"""Apply the precedence rule to decide which user files to load.
A custom persona fully supersedes the project. When a chat uses a
custom persona, the project is purely organisational — its files are
never loaded and never made searchable.
Custom persona → persona's own user_files (may be empty).
Default persona inside a project → project files.
Otherwise → empty list.
"""
if persona.id != DEFAULT_PERSONA_ID:
return list(persona.user_files) if persona.user_files else []
if project_id:
return get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return []
def _empty_extracted_context_files() -> ExtractedContextFiles:
return ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=False,
total_token_count=0,
file_metadata=[],
uncapped_token_count=None,
)
def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None:
"""Extract text content from an InMemoryChatFile.
PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during
ingestion — decode directly.
DOC / CSV / other text types: the content is the original file bytes —
use extract_file_text which handles encoding detection and format parsing.
"""
try:
if f.file_type == ChatFileType.PLAIN_TEXT:
return f.content.decode("utf-8", errors="ignore").replace("\x00", "")
return extract_file_text(
file=io.BytesIO(f.content),
file_name=f.filename or "",
break_on_unprocessable=False,
)
except Exception:
logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True)
return None
def extract_context_files(
user_files: list[UserFile],
llm_max_context_window: int,
reserved_token_count: int,
db_session: Session,
@@ -203,8 +263,12 @@ def _extract_project_file_texts_and_images(
# 60% of the LLM's max context window. The other benefit is that for projects with
# more files, this makes it so that we don't throw away the history too quickly every time.
max_llm_context_percentage: float = 0.6,
) -> ExtractedProjectFiles:
"""Extract text content from project files if they fit within the context window.
) -> ExtractedContextFiles:
"""Load user files into context if they fit; otherwise flag for search.
The caller is responsible for deciding *which* user files to pass in
(project files, persona files, etc.). This function only cares about
the all-or-nothing fit check and the actual content loading.
Args:
project_id: The project ID to load files from
@@ -213,160 +277,95 @@ def _extract_project_file_texts_and_images(
reserved_token_count: Number of tokens to reserve for other content
db_session: Database session
max_llm_context_percentage: Maximum percentage of the LLM context window to use.
Returns:
ExtractedProjectFiles containing:
- List of text content strings from project files (text files only)
- List of image files from project (ChatLoadedFile objects)
- Project id if the the project should be provided as a filter in search or None if not.
ExtractedContextFiles containing:
- List of text content strings from context files (text files only)
- List of image files from context (ChatLoadedFile objects)
- Total token count of all extracted files
- File metadata for context files
- Uncapped token count of all extracted files
- File metadata for files that don't fit in context and vector DB is disabled
"""
# TODO I believe this is not handling all file types correctly.
project_as_filter = False
if not project_id:
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=None,
)
# TODO(yuhong): I believe this is not handling all file types correctly.
if not user_files:
return _empty_extracted_context_files()
aggregate_tokens = sum(uf.token_count or 0 for uf in user_files)
max_actual_tokens = (
llm_max_context_window - reserved_token_count
) * max_llm_context_percentage
# Calculate total token count for all user files in the project
project_tokens = get_project_token_count(
project_id=project_id,
user_id=user_id,
if aggregate_tokens >= max_actual_tokens:
tool_metadata = []
use_as_search_filter = not DISABLE_VECTOR_DB
if DISABLE_VECTOR_DB:
tool_metadata = _build_file_tool_metadata_for_user_files(user_files)
return ExtractedContextFiles(
file_texts=[],
image_files=[],
use_as_search_filter=use_as_search_filter,
total_token_count=0,
file_metadata=[],
uncapped_token_count=aggregate_tokens,
file_metadata_for_tool=tool_metadata,
)
# Files fit — load them into context
user_file_map = {str(uf.id): uf for uf in user_files}
in_memory_files = load_in_memory_chat_files(
user_file_ids=[uf.id for uf in user_files],
db_session=db_session,
)
project_file_texts: list[str] = []
project_image_files: list[ChatLoadedFile] = []
project_file_metadata: list[ProjectFileMetadata] = []
file_texts: list[str] = []
image_files: list[ChatLoadedFile] = []
file_metadata: list[ContextFileMetadata] = []
total_token_count = 0
if project_tokens < max_actual_tokens:
# Load project files into memory using cached plaintext when available
project_user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
if project_user_files:
# Create a mapping from file_id to UserFile for token count lookup
user_file_map = {str(file.id): file for file in project_user_files}
project_file_ids = [file.id for file in project_user_files]
in_memory_project_files = load_in_memory_chat_files(
user_file_ids=project_file_ids,
db_session=db_session,
for f in in_memory_files:
uf = user_file_map.get(str(f.file_id))
if f.file_type.is_text_file():
text_content = _extract_text_from_in_memory_file(f)
if not text_content:
continue
file_texts.append(text_content)
file_metadata.append(
ContextFileMetadata(
file_id=str(f.file_id),
filename=f.filename or f"file_{f.file_id}",
file_content=text_content,
)
)
if uf and uf.token_count:
total_token_count += uf.token_count
elif f.file_type == ChatFileType.IMAGE:
token_count = uf.token_count if uf and uf.token_count else 0
total_token_count += token_count
image_files.append(
ChatLoadedFile(
file_id=f.file_id,
content=f.content,
file_type=f.file_type,
filename=f.filename,
content_text=None,
token_count=token_count,
)
)
# Extract text content from loaded files
for file in in_memory_project_files:
if file.file_type.is_text_file():
try:
text_content = file.content.decode("utf-8", errors="ignore")
# Strip null bytes
text_content = text_content.replace("\x00", "")
if text_content:
project_file_texts.append(text_content)
# Add metadata for citation support
project_file_metadata.append(
ProjectFileMetadata(
file_id=str(file.file_id),
filename=file.filename or f"file_{file.file_id}",
file_content=text_content,
)
)
# Add token count for text file
user_file = user_file_map.get(str(file.file_id))
if user_file and user_file.token_count:
total_token_count += user_file.token_count
except Exception:
# Skip files that can't be decoded
pass
elif file.file_type == ChatFileType.IMAGE:
# Convert InMemoryChatFile to ChatLoadedFile
user_file = user_file_map.get(str(file.file_id))
token_count = (
user_file.token_count
if user_file and user_file.token_count
else 0
)
total_token_count += token_count
chat_loaded_file = ChatLoadedFile(
file_id=file.file_id,
content=file.content,
file_type=file.file_type,
filename=file.filename,
content_text=None, # Images don't have text content
token_count=token_count,
)
project_image_files.append(chat_loaded_file)
else:
if DISABLE_VECTOR_DB:
# Without a vector DB we can't use project-as-filter search.
# Instead, build lightweight metadata so the LLM can call the
# FileReaderTool to inspect individual files on demand.
file_metadata_for_tool = _build_file_tool_metadata_for_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return ExtractedProjectFiles(
project_file_texts=[],
project_image_files=[],
project_as_filter=False,
total_token_count=0,
project_file_metadata=[],
project_uncapped_token_count=project_tokens,
file_metadata_for_tool=file_metadata_for_tool,
)
project_as_filter = True
return ExtractedProjectFiles(
project_file_texts=project_file_texts,
project_image_files=project_image_files,
project_as_filter=project_as_filter,
return ExtractedContextFiles(
file_texts=file_texts,
image_files=image_files,
use_as_search_filter=False,
total_token_count=total_token_count,
project_file_metadata=project_file_metadata,
project_uncapped_token_count=project_tokens,
file_metadata=file_metadata,
uncapped_token_count=aggregate_tokens,
)
APPROX_CHARS_PER_TOKEN = 4
def _build_file_tool_metadata_for_project(
project_id: int,
user_id: UUID | None,
db_session: Session,
) -> list[FileToolMetadata]:
"""Build lightweight FileToolMetadata for every file in a project.
Used when files are too large to fit in context and the vector DB is
disabled, so the LLM needs to know which files it can read via the
FileReaderTool.
"""
project_user_files = get_user_files_from_project(
project_id=project_id,
user_id=user_id,
db_session=db_session,
)
return [
FileToolMetadata(
file_id=str(uf.id),
filename=uf.name,
approx_char_count=(uf.token_count or 0) * APPROX_CHARS_PER_TOKEN,
)
for uf in project_user_files
]
def _build_file_tool_metadata_for_user_files(
user_files: list[UserFile],
) -> list[FileToolMetadata]:
@@ -381,55 +380,46 @@ def _build_file_tool_metadata_for_user_files(
]
def _get_project_search_availability(
def determine_search_params(
persona_id: int,
project_id: int | None,
persona_id: int | None,
loaded_project_files: bool,
project_has_files: bool,
forced_tool_id: int | None,
search_tool_id: int | None,
) -> ProjectSearchConfig:
"""Determine search tool availability based on project context.
extracted_context_files: ExtractedContextFiles,
) -> SearchParams:
"""Decide which search filter IDs and search-tool usage apply for a chat turn.
Search is disabled when ALL of the following are true:
- User is in a project
- Using the default persona (not a custom agent)
- Project files are already loaded in context
A custom persona fully supersedes the project — project files are never
searchable and the search tool config is entirely controlled by the
persona. The project_id filter is only set for the default persona.
When search is disabled and the user tried to force the search tool,
that forcing is also disabled.
Returns AUTO (follow persona config) in all other cases.
For the default persona inside a project:
- Files overflow → ENABLED (vector DB scopes to these files)
- Files fit → DISABLED (content already in prompt)
- No files at all → DISABLED (nothing to search)
"""
# Not in a project, this should have no impact on search tool availability
if not project_id:
return ProjectSearchConfig(
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
)
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
# Custom persona in project - let persona config decide
# Even if there are no files in the project, it's still guided by the persona config.
if persona_id != DEFAULT_PERSONA_ID:
return ProjectSearchConfig(
search_usage=SearchToolUsage.AUTO, disable_forced_tool=False
)
search_project_id: int | None = None
search_persona_id: int | None = None
if extracted_context_files.use_as_search_filter:
if is_custom_persona:
search_persona_id = persona_id
else:
search_project_id = project_id
# If in a project with the default persona and the files have been already loaded into the context or
# there are no files in the project, disable search as there is nothing to search for.
if loaded_project_files or not project_has_files:
user_forced_search = (
forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
)
return ProjectSearchConfig(
search_usage=SearchToolUsage.DISABLED,
disable_forced_tool=user_forced_search,
)
search_usage = SearchToolUsage.AUTO
if not is_custom_persona and project_id:
has_context_files = bool(extracted_context_files.uncapped_token_count)
files_loaded_in_context = bool(extracted_context_files.file_texts)
# Default persona in a project with files, but also the files have not been loaded into the context already.
return ProjectSearchConfig(
search_usage=SearchToolUsage.ENABLED, disable_forced_tool=False
if extracted_context_files.use_as_search_filter:
search_usage = SearchToolUsage.ENABLED
elif files_loaded_in_context or not has_context_files:
search_usage = SearchToolUsage.DISABLED
return SearchParams(
search_project_id=search_project_id,
search_persona_id=search_persona_id,
search_usage=search_usage,
)
@@ -661,26 +651,37 @@ def handle_stream_message_objects(
user_memory_context=prompt_memory_context,
)
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
extracted_project_files = _extract_project_file_texts_and_images(
# Determine which user files to use. A custom persona fully
# supersedes the project — project files are never loaded or
# searchable when a custom persona is in play. Only the default
# persona inside a project uses the project's files.
context_user_files = resolve_context_user_files(
persona=persona,
project_id=chat_session.project_id,
user_id=user_id,
db_session=db_session,
)
extracted_context_files = extract_context_files(
user_files=context_user_files,
llm_max_context_window=llm.config.max_input_tokens,
reserved_token_count=reserved_token_count,
db_session=db_session,
)
# When the vector DB is disabled, persona-attached user_files have no
# search pipeline path. Inject them as file_metadata_for_tool so the
# LLM can read them via the FileReaderTool.
if DISABLE_VECTOR_DB and persona.user_files:
persona_file_metadata = _build_file_tool_metadata_for_user_files(
persona.user_files
)
# Merge persona file metadata into the extracted project files
extracted_project_files.file_metadata_for_tool.extend(persona_file_metadata)
search_params = determine_search_params(
persona_id=persona.id,
project_id=chat_session.project_id,
extracted_context_files=extracted_context_files,
)
# Also grant access to persona-attached user files for FileReaderTool
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# Build a mapping of tool_id to tool_name for history reconstruction
all_tools = get_tools(db_session)
tool_id_to_name_map = {tool.id: tool.name for tool in all_tools}
@@ -689,30 +690,17 @@ def handle_stream_message_objects(
None,
)
# Determine if search should be disabled for this project context
forced_tool_id = new_msg_req.forced_tool_id
project_search_config = _get_project_search_availability(
project_id=chat_session.project_id,
persona_id=persona.id,
loaded_project_files=bool(extracted_project_files.project_file_texts),
project_has_files=bool(
extracted_project_files.project_uncapped_token_count
),
forced_tool_id=new_msg_req.forced_tool_id,
search_tool_id=search_tool_id,
)
if project_search_config.disable_forced_tool:
if (
search_params.search_usage == SearchToolUsage.DISABLED
and forced_tool_id is not None
and search_tool_id is not None
and forced_tool_id == search_tool_id
):
forced_tool_id = None
emitter = get_default_emitter()
# Also grant access to persona-attached user files
if persona.user_files:
existing = set(available_files.user_file_ids)
for uf in persona.user_files:
if uf.id not in existing:
available_files.user_file_ids.append(uf.id)
# Construct tools based on the persona configurations
tool_dict = construct_tools(
persona=persona,
@@ -722,11 +710,8 @@ def handle_stream_message_objects(
llm=llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id=(
chat_session.project_id
if extracted_project_files.project_as_filter
else None
),
project_id=search_params.search_project_id,
persona_id=search_params.search_persona_id,
bypass_acl=bypass_acl,
slack_context=slack_context,
enable_slack_search=_should_enable_slack_search(
@@ -744,7 +729,7 @@ def handle_stream_message_objects(
chat_file_ids=available_files.chat_file_ids,
),
allowed_tool_ids=new_msg_req.allowed_tool_ids,
search_usage_forcing_setting=project_search_config.search_usage,
search_usage_forcing_setting=search_params.search_usage,
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
@@ -783,7 +768,7 @@ def handle_stream_message_objects(
chat_history_result = convert_chat_history(
chat_history=chat_history,
files=files,
project_image_files=extracted_project_files.project_image_files,
context_image_files=extracted_context_files.image_files,
additional_context=additional_context,
token_counter=token_counter,
tool_id_to_name_map=tool_id_to_name_map,
@@ -856,6 +841,11 @@ def handle_stream_message_objects(
reserved_tokens=reserved_token_count,
)
# Release any read transaction before entering the long-running LLM stream.
# Without this, the request-scoped session can keep a connection checked out
# for the full stream duration.
db_session.commit()
# The stream generator can resume on a different worker thread after early yields.
# Set this right before launching the LLM loop so run_in_background copies the right context.
if new_msg_req.mock_llm_response is not None:
@@ -874,46 +864,54 @@ def handle_stream_message_objects(
# (user has already responded to a clarification question)
skip_clarification = is_last_assistant_message_clarification(chat_history)
# NOTE: we _could_ pass in a zero argument function since emitter and state_container
# are just passed in immediately anyways, but the abstraction is cleaner this way.
yield from run_chat_loop_with_state_containers(
run_deep_research_llm_loop,
lambda emitter, state_container: run_deep_research_llm_loop(
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
llm=llm,
token_counter=token_counter,
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
),
llm_loop_completion_callback,
is_connected=check_is_connected,
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
llm=llm,
token_counter=token_counter,
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
all_injected_file_metadata=all_injected_file_metadata,
)
else:
yield from run_chat_loop_with_state_containers(
run_llm_loop,
lambda emitter, state_container: run_llm_loop(
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
context_files=extracted_context_files,
persona=persona,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
),
llm_loop_completion_callback,
is_connected=check_is_connected, # Not passed through to run_llm_loop
emitter=emitter,
state_container=state_container,
simple_chat_history=simple_chat_history,
tools=tools,
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,
forced_tool_id=forced_tool_id,
user_identity=user_identity,
chat_session_id=str(chat_session.id),
chat_files=chat_files_for_tools,
include_citations=new_msg_req.include_citations,
all_injected_file_metadata=all_injected_file_metadata,
inject_memories_in_prompt=user.use_memories,
)
except ValueError as e:

View File

@@ -10,6 +10,7 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -125,6 +126,7 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -132,6 +134,8 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None
@@ -186,7 +190,7 @@ def _build_user_information_section(
if not sections:
return ""
return USER_INFORMATION_HEADER + "".join(sections)
return USER_INFORMATION_HEADER + "\n".join(sections)
def build_system_prompt(
@@ -224,23 +228,21 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE
if include_all_guidance:
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
return system_prompt
if tools:
system_prompt += TOOL_SECTION_HEADER
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -250,12 +252,14 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
tool_guidance_sections: list[str] = []
if has_web_search or has_internal_search or include_all_guidance:
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
system_prompt += INTERNAL_SEARCH_GUIDANCE
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -265,20 +269,23 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
)
if has_open_urls or include_all_guidance:
system_prompt += OPEN_URLS_GUIDANCE
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
if has_python or include_all_guidance:
system_prompt += PYTHON_TOOL_GUIDANCE
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
if has_generate_image or include_all_guidance:
system_prompt += GENERATE_IMAGE_GUIDANCE
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
if has_memory or include_all_guidance:
system_prompt += MEMORY_GUIDANCE
tool_guidance_sections.append(MEMORY_GUIDANCE)
if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
return system_prompt

View File

@@ -210,10 +210,10 @@ AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_SERVER = os.environ.get("SMTP_SERVER") or ""
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
SMTP_USER = os.environ.get("SMTP_USER") or ""
SMTP_PASS = os.environ.get("SMTP_PASS") or ""
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or ""
@@ -251,7 +251,9 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -282,6 +284,9 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
@@ -289,6 +294,12 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
== "true"
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a
@@ -637,6 +648,14 @@ SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)
# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)
BLOB_STORAGE_SIZE_THRESHOLD = int(
os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

View File

@@ -157,6 +157,25 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
# How long a queued user-file-project-sync task remains valid.
# Should be short enough to discard stale queue entries under load while still
# allowing workers enough time to pick up new tasks.
CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Max queue depth before user-file-project-sync producers stop enqueuing.
# This applies backpressure when workers are falling behind.
USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -443,8 +462,12 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete"

View File

@@ -7,7 +7,6 @@ from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
from urllib.parse import urlparse
from github import Github
from github import RateLimitExceededException
@@ -71,78 +70,6 @@ class DocMetadata(BaseModel):
repo: str
def parse_github_url_if_needed(
repo_owner: str, repositories: str | None
) -> tuple[str, str | None]:
"""
Parse GitHub URL if present in repo_owner, otherwise return as-is.
This is a defensive helper that handles cases where users might paste a full
GitHub URL instead of just the owner name. The frontend should parse URLs,
but this provides a safety net.
Args:
repo_owner: Either an owner name (e.g., "octocat") or a full GitHub URL
repositories: Comma-separated repository names, or None for all repos
Returns:
Tuple of (parsed_owner, parsed_repositories)
Examples:
>>> parse_github_url_if_needed("octocat", "repo1")
("octocat", "repo1")
>>> parse_github_url_if_needed("https://github.com/octocat", None)
("octocat", None)
>>> parse_github_url_if_needed("https://github.com/octocat/repo1", None)
("octocat", "repo1")
>>> parse_github_url_if_needed("https://github.com/octocat/repo1", "repo2,repo3")
("octocat", "repo2,repo3")
"""
# If repo_owner doesn't look like a URL, return as-is
if not repo_owner or "github.com" not in repo_owner:
return (repo_owner, repositories)
try:
# Add protocol if missing for proper URL parsing
url_to_parse = repo_owner
if not url_to_parse.startswith(("http://", "https://")):
url_to_parse = "https://" + url_to_parse
parsed = urlparse(url_to_parse)
# Verify this is actually a GitHub URL
if "github.com" not in parsed.netloc:
return (repo_owner, repositories)
# Extract path components, removing empty strings and leading/trailing slashes
path_parts = [part for part in parsed.path.strip("/").split("/") if part]
if not path_parts:
# URL with no path, can't extract owner
return (repo_owner, repositories)
# First part of path is the owner
extracted_owner = path_parts[0]
# If there's a second part (repo name) and repositories is not specified, use it
if len(path_parts) >= 2 and repositories is None:
extracted_repo = path_parts[1]
return (extracted_owner, extracted_repo)
# Otherwise, just return the extracted owner with the original repositories
return (extracted_owner, repositories)
except Exception as e:
# If parsing fails for any reason, return original values
logger.warning(
f"Failed to parse potential GitHub URL '{repo_owner}': {e}. Using as-is."
)
return (repo_owner, repositories)
def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
if "_PaginatedList__nextUrl" in pag_list.__dict__:
return "_PaginatedList__nextUrl"
@@ -509,12 +436,8 @@ class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoin
include_prs: bool = True,
include_issues: bool = False,
) -> None:
# Defensive URL parsing: handle cases where URLs might be submitted directly
parsed_owner, parsed_repositories = parse_github_url_if_needed(
repo_owner, repositories
)
self.repo_owner = parsed_owner
self.repositories = parsed_repositories
self.repo_owner = repo_owner
self.repositories = repositories
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues

View File

@@ -16,6 +16,22 @@ from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
_RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"}
def _is_rate_limit_error(error: HttpError) -> bool:
"""Google sometimes returns rate-limit errors as 403 with reason
'userRateLimitExceeded' instead of 429. This helper detects both."""
if error.resp.status == 429:
return True
if error.resp.status != 403:
return False
error_details = getattr(error, "error_details", None) or []
for detail in error_details:
if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS:
return True
return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error)
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. This is now addressed by checkpointing.
@@ -57,7 +73,7 @@ def _execute_with_retry(request: Any) -> Any:
except HttpError as error:
attempt += 1
if error.resp.status == 429:
if _is_rate_limit_error(error):
# Attempt to get 'Retry-After' from headers
retry_after = error.resp.get("Retry-After")
if retry_after:
@@ -140,16 +156,16 @@ def _execute_single_retrieval(
)
logger.error(f"Error executing request: {e}")
raise e
elif _is_rate_limit_error(e):
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e
elif e.resp.status == 429:
results = _execute_with_retry(
lambda: retrieval_function(**request_kwargs).execute()
)
else:
logger.exception("Error executing request:")
raise e

View File

@@ -0,0 +1,96 @@
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.
The office365 library's GraphClient requires an ``AzureEnvironment`` string
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
cloud. Our connectors instead expose free-text ``authority_host`` and
``graph_api_host`` fields so the frontend doesn't need to know about SDK
internals.
This module bridges the gap: given the two host URLs the user configured, it
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
"""
from office365.graph_client import AzureEnvironment # type: ignore[import-untyped]
from pydantic import BaseModel
from onyx.connectors.exceptions import ConnectorValidationError
class MicrosoftGraphEnvironment(BaseModel):
"""One row of the inverse mapping."""
environment: str
graph_host: str
authority_host: str
sharepoint_domain_suffix: str
_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Global,
graph_host="https://graph.microsoft.com",
authority_host="https://login.microsoftonline.com",
sharepoint_domain_suffix="sharepoint.com",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentHigh,
graph_host="https://graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.USGovernmentDoD,
graph_host="https://dod-graph.microsoft.us",
authority_host="https://login.microsoftonline.us",
sharepoint_domain_suffix="sharepoint.us",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.China,
graph_host="https://microsoftgraph.chinacloudapi.cn",
authority_host="https://login.chinacloudapi.cn",
sharepoint_domain_suffix="sharepoint.cn",
),
MicrosoftGraphEnvironment(
environment=AzureEnvironment.Germany,
graph_host="https://graph.microsoft.de",
authority_host="https://login.microsoftonline.de",
sharepoint_domain_suffix="sharepoint.de",
),
]
_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
env.graph_host: env for env in _ENVIRONMENTS
}
def resolve_microsoft_environment(
graph_api_host: str,
authority_host: str,
) -> MicrosoftGraphEnvironment:
"""Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.
Raises ``ConnectorValidationError`` when the combination is unknown or
internally inconsistent (e.g. a GCC-High graph host paired with a
commercial authority host).
"""
graph_api_host = graph_api_host.rstrip("/")
authority_host = authority_host.rstrip("/")
env = _GRAPH_HOST_INDEX.get(graph_api_host)
if env is None:
known = ", ".join(sorted(_GRAPH_HOST_INDEX))
raise ConnectorValidationError(
f"Unsupported Microsoft Graph API host '{graph_api_host}'. "
f"Recognised hosts: {known}"
)
if env.authority_host != authority_host:
raise ConnectorValidationError(
f"Authority host '{authority_host}' is inconsistent with "
f"graph API host '{graph_api_host}'. "
f"Expected authority host '{env.authority_host}' "
f"for the {env.environment} environment."
)
return env

View File

@@ -6,6 +6,7 @@ from typing import cast
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator
from onyx.access.models import ExternalAccess
@@ -167,6 +168,14 @@ class DocumentBase(BaseModel):
# list of strings.
metadata: dict[str, str | list[str]]
@field_validator("metadata", mode="before")
@classmethod
def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
return {
key: [str(item) for item in val] if isinstance(val, list) else str(val)
for key, val in v.items()
}
# UTC time
doc_updated_at: datetime | None = None
chunk_count: int | None = None

File diff suppressed because it is too large Load Diff

View File

@@ -11,6 +11,7 @@ from dateutil import parser
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -258,3 +259,21 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch
def validate_connector_settings(self) -> None:
"""
Very basic validation, we could do more here
"""
if not self.base_url.startswith("https://") and not self.base_url.startswith(
"http://"
):
raise ConnectorValidationError(
"Base URL must start with https:// or http://"
)
try:
get_all_post_ids(self.slab_bot_token)
except ConnectorMissingCredentialError:
raise
except Exception as e:
raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}")

View File

@@ -23,6 +23,7 @@ from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
@@ -50,12 +51,15 @@ class TeamsCheckpoint(ConnectorCheckpoint):
todo_team_ids: list[str] | None = None
DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"
class TeamsConnector(
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
AUTHORITY_URL_PREFIX = "https://login.microsoftonline.com/"
def __init__(
self,
@@ -63,12 +67,19 @@ class TeamsConnector(
# are not necessarily guaranteed to be unique
teams: list[str] = [],
max_workers: int = MAX_WORKERS,
authority_host: str = DEFAULT_AUTHORITY_HOST,
graph_api_host: str = DEFAULT_GRAPH_API_HOST,
) -> None:
self.graph_client: GraphClient | None = None
self.msal_app: msal.ConfidentialClientApplication | None = None
self.max_workers = max_workers
self.requested_team_list: list[str] = teams
resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
self._azure_environment = resolved_env.environment
self.authority_host = resolved_env.authority_host
self.graph_api_host = resolved_env.graph_host
# impls for BaseConnector
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -76,7 +87,7 @@ class TeamsConnector(
teams_client_secret = credentials["teams_client_secret"]
teams_directory_id = credentials["teams_directory_id"]
authority_url = f"{TeamsConnector.AUTHORITY_URL_PREFIX}{teams_directory_id}"
authority_url = f"{self.authority_host}/{teams_directory_id}"
self.msal_app = msal.ConfidentialClientApplication(
authority=authority_url,
client_id=teams_client_id,
@@ -91,7 +102,7 @@ class TeamsConnector(
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
scopes=[f"{self.graph_api_host}/.default"]
)
if not isinstance(token, dict):
@@ -99,7 +110,9 @@ class TeamsConnector(
return token
self.graph_client = GraphClient(_acquire_token_func)
self.graph_client = GraphClient(
_acquire_token_func, environment=self._azure_environment
)
return None
def validate_connector_settings(self) -> None:

View File

@@ -32,6 +32,7 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -905,13 +906,15 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session,
db_session: Session | None = None,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -925,7 +928,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session
db_session: Database session (optional if search_settings provided)
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1153,7 +1156,10 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
search_settings = get_current_search_settings(db_session)
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -72,6 +72,7 @@ class BaseFilters(BaseModel):
class UserFileFilters(BaseModel):
user_file_ids: list[UUID] | None = None
project_id: int | None = None
persona_id: int | None = None
class AssistantKnowledgeFilters(BaseModel):

View File

@@ -18,8 +18,10 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -38,10 +40,11 @@ def _build_index_filters(
user_provided_filters: BaseFilters | None,
user: User, # Used for ACLs, anonymous users only see public docs
project_id: int | None,
persona_id: int | None,
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session,
db_session: Session | None = None,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -49,18 +52,19 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
base_filters = user_provided_filters or BaseFilters()
if (
user_provided_filters
and user_provided_filters.document_set is None
and persona_document_sets is not None
):
base_filters.document_set = persona_document_sets
document_set_filter = (
base_filters.document_set
if base_filters.document_set is not None
else persona_document_sets
)
time_filter = base_filters.time_cutoff or persona_time_cutoff
source_filter = base_filters.source_type
@@ -103,15 +107,21 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
project_id=project_id,
persona_id=persona_id,
source_type=source_filter,
document_set=persona_document_sets,
document_set=document_set_filter,
time_cutoff=time_filter,
tags=base_filters.tags,
access_control_list=user_acl_filters,
@@ -252,11 +262,17 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session,
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# If a persona_id is provided, search scopes to files attached to this persona
persona_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -287,6 +303,7 @@ def search_pipeline(
user_provided_filters=chunk_search_request.user_selected_filters,
user=user,
project_id=project_id,
persona_id=persona_id,
user_file_ids=user_uploaded_persona_files,
persona_document_sets=persona_document_sets,
persona_time_cutoff=persona_time_cutoff,
@@ -297,6 +314,7 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -315,6 +333,8 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,9 +14,11 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -50,9 +52,14 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(query_request.query, db_session)
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -78,7 +85,9 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -88,14 +97,22 @@ def search_chunks(
else None
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -114,7 +131,10 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(_embed_and_search, (query_request, document_index, db_session))
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,23 +64,34 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,6 +13,7 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -0,0 +1,21 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import CodeInterpreterServer
def fetch_code_interpreter_server(
db_session: Session,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
return server
def update_code_interpreter_server_enabled(
db_session: Session,
enabled: bool,
) -> CodeInterpreterServer:
server = db_session.scalars(select(CodeInterpreterServer)).one()
server.server_enabled = enabled
db_session.commit()
return server

View File

@@ -116,12 +116,15 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -130,7 +133,10 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -170,6 +176,7 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -183,6 +190,7 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -554,10 +554,19 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(selectinload(DocumentSetDBModel.federated_connectors))
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).all()
return db_session.scalars(stmt).unique().all()
def fetch_documents_for_document_set_paginated(

View File

@@ -21,8 +21,8 @@ from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine.iam_auth import ssl_context
from onyx.db.engine.sql_engine import ASYNC_DB_API
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import is_valid_schema_name
@@ -66,7 +66,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
if app_name:
connect_args["server_settings"] = {"application_name": app_name}
connect_args["ssl"] = ssl_context
connect_args["ssl"] = create_ssl_context_if_iam()
engine_kwargs = {
"connect_args": connect_args,
@@ -97,7 +97,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = ssl_context
cparams["ssl"] = create_ssl_context_if_iam()
return _ASYNC_ENGINE

View File

@@ -1,3 +1,4 @@
import functools
import os
import ssl
from typing import Any
@@ -48,11 +49,9 @@ def provide_iam_token(
configure_psycopg2_iam_auth(cparams, host, port, user, region)
@functools.cache
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
"""Create an SSL context if IAM authentication is enabled, else return None."""
if USE_IAM_AUTH:
return ssl.create_default_context(cafile=SSL_CERT_FILE)
return None
ssl_context = create_ssl_context_if_iam()

View File

@@ -1,11 +1,102 @@
from sqlalchemy import text
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
def get_schemas_needing_migration(
tenant_schemas: list[str], head_rev: str
) -> list[str]:
"""Return only schemas whose current alembic version is not at head.
Uses a server-side PL/pgSQL loop to collect each schema's alembic version
into a temp table one at a time. This avoids building a massive UNION ALL
query (which locks the DB and times out at 17k+ schemas) and instead
acquires locks sequentially, one schema per iteration.
"""
if not tenant_schemas:
return []
engine = SqlEngine.get_engine()
with engine.connect() as conn:
# Populate a temp input table with exactly the schemas we care about.
# The DO block reads from this table so it only iterates the requested
# schemas instead of every tenant_% schema in the database.
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
conn.execute(
text(
"INSERT INTO _tenant_schemas_input (schema_name) "
"SELECT unnest(CAST(:schemas AS text[]))"
),
{"schemas": tenant_schemas},
)
conn.execute(
text(
"CREATE TEMP TABLE _alembic_version_snapshot "
"(schema_name text, version_num text)"
)
)
conn.execute(
text(
"""
DO $$
DECLARE
s text;
schemas text[];
BEGIN
SELECT array_agg(schema_name) INTO schemas
FROM _tenant_schemas_input;
IF schemas IS NULL THEN
RAISE NOTICE 'No tenant schemas found.';
RETURN;
END IF;
FOREACH s IN ARRAY schemas LOOP
BEGIN
EXECUTE format(
'INSERT INTO _alembic_version_snapshot
SELECT %L, version_num FROM %I.alembic_version',
s, s
);
EXCEPTION
-- undefined_table: schema exists but has no alembic_version
-- table yet (new tenant, not yet migrated).
-- invalid_schema_name: tenant is registered but its
-- PostgreSQL schema does not exist yet (e.g. provisioning
-- incomplete). Both cases mean no version is available and
-- the schema will be included in the migration list.
WHEN undefined_table THEN NULL;
WHEN invalid_schema_name THEN NULL;
END;
END LOOP;
END;
$$
"""
)
)
rows = conn.execute(
text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
)
version_by_schema = {row[0]: row[1] for row in rows}
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot"))
conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input"))
# Schemas missing from the snapshot have no alembic_version table yet and
# also need migration. version_by_schema.get(s) returns None for those,
# and None != head_rev, so they are included automatically.
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev]
def get_all_tenant_ids() -> list[str]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""

View File

@@ -619,7 +619,7 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
_update_default_model(
db_session,
provider_id,
provider.default_model_name,
provider.default_model_name, # type: ignore[arg-type]
LLMModelFlowType.CHAT,
)

View File

@@ -287,7 +287,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
"Credential", back_populates="user"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -321,7 +321,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -2823,13 +2822,17 @@ class LLMProvider(Base):
custom_config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
default_model_name: Mapped[str] = mapped_column(String)
# Deprecated: use LLMModelFlow with CHAT flow type instead
default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Deprecated: use LLMModelFlow.is_default with VISION flow type instead
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
# Deprecated: use LLMModelFlow with VISION flow type instead
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@@ -2880,6 +2883,7 @@ class ModelConfiguration(Base):
# - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Deprecated: use LLMModelFlow with VISION flow type instead
supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# Human-readable display name for the model.
@@ -4271,6 +4275,9 @@ class UserFile(Base):
needs_project_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
needs_persona_sync: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
@@ -4940,6 +4947,12 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
department: Mapped[str | None] = mapped_column(String, nullable=True)
manager: Mapped[str | None] = mapped_column(String, nullable=True)
given_name: Mapped[str | None] = mapped_column(String, nullable=True)
family_name: Mapped[str | None] = mapped_column(String, nullable=True)
scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
@@ -4978,3 +4991,12 @@ class ScimGroupMapping(Base):
user_group: Mapped[UserGroup] = relationship(
"UserGroup", foreign_keys=[user_group_id]
)
class CodeInterpreterServer(Base):
"""Details about the code interpreter server"""
__tablename__ = "code_interpreter_server"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -8,6 +8,7 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -31,55 +32,61 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.limit(1)
.options(selectinload(User.memories))
)
row = result.first()
if not row:
if not user:
return None
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update_last_used())
_schedule_pat_last_used_update(hashed_token, now)
return user
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
def create_pat(
db_session: Session,
user_id: UUID,

View File

@@ -28,6 +28,7 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -255,9 +256,6 @@ def create_update_persona(
try:
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
# Curators can edit default personas, but not make them
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
pass
@@ -334,6 +332,7 @@ def update_persona_shared(
db_session: Session,
group_ids: list[int] | None = None,
is_public: bool | None = None,
label_ids: list[int] | None = None,
) -> None:
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
@@ -343,9 +342,7 @@ def update_persona_shared(
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise HTTPException(
status_code=403, detail="You don't have permission to modify this persona"
)
raise PermissionError("You don't have permission to modify this persona")
versioned_update_persona_access = fetch_versioned_implementation(
"onyx.db.persona", "update_persona_access"
@@ -359,6 +356,15 @@ def update_persona_shared(
group_ids=group_ids,
)
if label_ids is not None:
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
persona.labels.clear()
persona.labels = labels
db_session.commit()
@@ -420,9 +426,16 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -453,7 +466,16 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -550,9 +572,16 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
)
@@ -611,7 +640,16 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -732,6 +770,9 @@ def mark_persona_as_deleted(
) -> None:
persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
persona.deleted = True
affected_file_ids = [uf.id for uf in persona.user_files]
if affected_file_ids:
_mark_files_need_persona_sync(db_session, affected_file_ids)
db_session.commit()
@@ -743,11 +784,13 @@ def mark_persona_as_not_deleted(
persona = get_persona_by_id(
persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
)
if persona.deleted:
persona.deleted = False
db_session.commit()
else:
if not persona.deleted:
raise ValueError(f"Persona with ID {persona_id} is not deleted.")
persona.deleted = False
affected_file_ids = [uf.id for uf in persona.user_files]
if affected_file_ids:
_mark_files_need_persona_sync(db_session, affected_file_ids)
db_session.commit()
def mark_delete_persona_by_name(
@@ -813,6 +856,20 @@ def update_personas_display_priority(
db_session.commit()
def _mark_files_need_persona_sync(
db_session: Session,
user_file_ids: list[UUID],
) -> None:
"""Flag the given UserFile rows so the background sync task picks them up
and updates their persona metadata in the vector DB."""
if not user_file_ids:
return
db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
{UserFile.needs_persona_sync: True},
synchronize_session=False,
)
def upsert_persona(
user: User | None,
name: str,
@@ -913,6 +970,8 @@ def upsert_persona(
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
if len(labels) != len(label_ids):
raise ValueError("Some label IDs were not found in the database")
# Fetch and attach hierarchy_nodes by IDs
hierarchy_nodes = None
@@ -1001,8 +1060,13 @@ def upsert_persona(
existing_persona.tools = tools or []
if user_file_ids is not None:
old_file_ids = {uf.id for uf in existing_persona.user_files}
new_file_ids = {uf.id for uf in (user_files or [])}
affected_file_ids = old_file_ids | new_file_ids
existing_persona.user_files.clear()
existing_persona.user_files = user_files or []
if affected_file_ids:
_mark_files_need_persona_sync(db_session, list(affected_file_ids))
if hierarchy_node_ids is not None:
existing_persona.hierarchy_nodes.clear()
@@ -1056,6 +1120,8 @@ def upsert_persona(
attached_documents=attached_documents or [],
)
db_session.add(new_persona)
if user_files:
_mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
persona = new_persona
if commit:
db_session.commit()
@@ -1102,9 +1168,6 @@ def update_persona_is_default(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if not persona.is_public:
persona.is_public = True
persona.is_default_persona = is_default
db_session.commit()

View File

@@ -2,6 +2,7 @@ import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from uuid import UUID
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -13,18 +14,26 @@ from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
def seed_chat_history(
num_sessions: int,
num_messages: int,
days: int,
user_id: UUID | None = None,
persona_id: int | None = None,
) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
user_id: optional user to associate with sessions
persona_id: optional persona/assistant to associate with sessions
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")

View File

@@ -3,8 +3,10 @@ from uuid import UUID
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.db.models import Project__UserFile
from onyx.db.models import UserFile
@@ -56,10 +58,34 @@ def fetch_user_project_ids_for_user_files(
db_session: Session,
) -> dict[str, list[int]]:
"""Fetch user project ids for specified user files"""
stmt = select(UserFile).where(UserFile.id.in_(user_file_ids))
user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids]
stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where(
Project__UserFile.user_file_id.in_(user_file_uuid_ids)
)
rows = db_session.execute(stmt).all()
user_file_id_to_project_ids: dict[str, list[int]] = {
user_file_id: [] for user_file_id in user_file_ids
}
for user_file_id, project_id in rows:
user_file_id_to_project_ids[str(user_file_id)].append(project_id)
return user_file_id_to_project_ids
def fetch_persona_ids_for_user_files(
user_file_ids: list[str],
db_session: Session,
) -> dict[str, list[int]]:
"""Fetch persona (assistant) ids for specified user files."""
stmt = (
select(UserFile)
.where(UserFile.id.in_(user_file_ids))
.options(selectinload(UserFile.assistants))
)
results = db_session.execute(stmt).scalars().all()
return {
str(user_file.id): [project.id for project in user_file.projects]
str(user_file.id): [persona.id for persona in user_file.assistants]
for user_file in results
}

View File

@@ -139,7 +139,7 @@ def generate_final_report(
custom_agent_prompt=None,
simple_chat_history=history,
reminder_message=reminder_message,
project_files=None,
context_files=None,
available_tokens=llm.config.max_input_tokens,
all_injected_file_metadata=all_injected_file_metadata,
)
@@ -257,7 +257,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=None,
project_files=None,
context_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,
@@ -321,7 +321,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history + [reminder_message],
reminder_message=None,
project_files=None,
context_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1,
all_injected_file_metadata=all_injected_file_metadata,
@@ -485,7 +485,7 @@ def run_deep_research_llm_loop(
custom_agent_prompt=None,
simple_chat_history=simple_chat_history,
reminder_message=first_cycle_reminder_message,
project_files=None,
context_files=None,
available_tokens=available_tokens,
last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT,
all_injected_file_metadata=all_injected_file_metadata,

View File

@@ -11,6 +11,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from shared_configs.configs import MULTI_TENANT
@@ -49,8 +50,11 @@ def get_default_document_index(
opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session)
if opensearch_retrieval_enabled:
indexing_setting = IndexingSetting.from_db_model(search_settings)
return OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
@@ -118,8 +122,11 @@ def get_all_document_indices(
)
opensearch_document_index: OpenSearchOldDocumentIndex | None = None
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
indexing_setting = IndexingSetting.from_db_model(search_settings)
opensearch_document_index = OpenSearchOldDocumentIndex(
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,

View File

@@ -121,6 +121,7 @@ class VespaDocumentUserFields:
"""
user_projects: list[int] | None = None
personas: list[int] | None = None
@dataclass

View File

@@ -148,6 +148,7 @@ class MetadataUpdateRequest(BaseModel):
hidden: bool | None = None
secondary_index_updated: bool | None = None
project_ids: set[int] | None = None
persona_ids: set[int] | None = None
class IndexRetrievalFilters(BaseModel):

View File

@@ -1,5 +1,7 @@
import logging
import time
from contextlib import AbstractContextManager
from contextlib import nullcontext
from typing import Any
from typing import Generic
from typing import TypeVar
@@ -83,22 +85,26 @@ def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
return new_body
class OpenSearchClient:
"""Client for interacting with OpenSearch.
class OpenSearchClient(AbstractContextManager):
"""Client for interacting with OpenSearch for cluster-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
@@ -107,9 +113,8 @@ class OpenSearchClient:
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
self._index_name = index_name
logger.debug(
f"Creating OpenSearch client for index {index_name} with host {host} and port {port} and timeout {timeout} seconds."
f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds."
)
self._client = OpenSearch(
hosts=[{"host": host, "port": port}],
@@ -125,6 +130,142 @@ class OpenSearchClient:
# your request body that is less than this value.
timeout=timeout,
)
def __exit__(self, *_: Any) -> None:
self.close()
def __del__(self) -> None:
try:
self.close()
except Exception:
pass
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
class OpenSearchIndexClient(OpenSearchClient):
"""Client for interacting with OpenSearch for index-level operations.
OpenSearch's Python module has pretty bad typing support so this client
attempts to protect the rest of the codebase from this. As a consequence,
most methods here return the minimum data needed for the rest of Onyx, and
tend to rely on Exceptions to handle errors.
TODO(andrei): This class currently assumes the structure of the database
schema when it returns a DocumentChunk. Make the class, or at least the
search method, templated on the structure the caller can expect.
Args:
index_name: The name of the index to interact with.
host: The host of the OpenSearch cluster.
port: The port of the OpenSearch cluster.
auth: The authentication credentials for the OpenSearch cluster. A tuple
of (username, password).
use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to
True.
verify_certs: Whether to verify the SSL certificates for the OpenSearch
cluster. Defaults to False.
ssl_show_warn: Whether to show warnings for SSL certificates. Defaults
to False.
timeout: The timeout for the OpenSearch cluster. Defaults to
DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S.
"""
def __init__(
self,
index_name: str,
host: str = OPENSEARCH_HOST,
port: int = OPENSEARCH_REST_API_PORT,
auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD),
use_ssl: bool = True,
verify_certs: bool = False,
ssl_show_warn: bool = False,
timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S,
):
super().__init__(
host=host,
port=port,
auth=auth,
use_ssl=use_ssl,
verify_certs=verify_certs,
ssl_show_warn=ssl_show_warn,
timeout=timeout,
)
self._index_name = index_name
logger.debug(
f"OpenSearch client created successfully for index {self._index_name}."
)
@@ -192,6 +333,38 @@ class OpenSearchClient:
"""
return self._client.indices.exists(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_mapping(self, mappings: dict[str, Any]) -> None:
"""Updates the index mapping in an idempotent manner.
- Existing fields with the same definition: No-op (succeeds silently).
- New fields: Added to the index.
- Existing fields with different types: Raises exception (requires
reindex).
See the OpenSearch documentation for more information:
https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/
Args:
mappings: The complete mapping definition to apply. This will be
merged with existing mappings in the index.
Raises:
Exception: There was an error updating the mappings, such as
attempting to change the type of an existing field.
"""
logger.debug(
f"Putting mappings for index {self._index_name} with mappings {mappings}."
)
response = self._client.indices.put_mapping(
index=self._index_name, body=mappings
)
if not response.get("acknowledged", False):
raise RuntimeError(
f"Failed to put the mapping update for index {self._index_name}."
)
logger.debug(f"Successfully put mappings for index {self._index_name}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def validate_index(self, expected_mappings: dict[str, Any]) -> bool:
"""Validates the index.
@@ -610,43 +783,6 @@ class OpenSearchClient:
)
return DocumentChunk.model_validate(document_chunk_source)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def create_search_pipeline(
self,
pipeline_id: str,
pipeline_body: dict[str, Any],
) -> None:
"""Creates a search pipeline.
See the OpenSearch documentation for more information on the search
pipeline body.
https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/
Args:
pipeline_id: The ID of the search pipeline to create.
pipeline_body: The body of the search pipeline to create.
Raises:
Exception: There was an error creating the search pipeline.
"""
result = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True, include_args=True)
def delete_search_pipeline(self, pipeline_id: str) -> None:
"""Deletes a search pipeline.
Args:
pipeline_id: The ID of the search pipeline to delete.
Raises:
Exception: There was an error deleting the search pipeline.
"""
result = self._client.search_pipeline.delete(id=pipeline_id)
if not result.get("acknowledged", False):
raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.")
@log_function_time(print_only=True, debug_only=True)
def search(
self, body: dict[str, Any], search_pipeline_id: str | None
@@ -807,48 +943,6 @@ class OpenSearchClient:
"""
self._client.indices.refresh(index=self._index_name)
@log_function_time(print_only=True, debug_only=True, include_args=True)
def put_cluster_settings(self, settings: dict[str, Any]) -> bool:
"""Puts cluster settings.
Args:
settings: The settings to put.
Raises:
Exception: There was an error putting the cluster settings.
Returns:
True if the settings were put successfully, False otherwise.
"""
response = self._client.cluster.put_settings(body=settings)
if response.get("acknowledged", False):
logger.info("Successfully put cluster settings.")
return True
else:
logger.error(f"Failed to put cluster settings: {response}.")
return False
@log_function_time(print_only=True, debug_only=True)
def ping(self) -> bool:
"""Pings the OpenSearch cluster.
Returns:
True if OpenSearch could be reached, False if it could not.
"""
return self._client.ping()
@log_function_time(print_only=True, debug_only=True)
def close(self) -> None:
"""Closes the client.
TODO(andrei): Can we have some way to auto close when the client no
longer has any references?
Raises:
Exception: There was an error closing the client.
"""
self._client.close()
def _get_hits_and_profile_from_search_result(
self, result: dict[str, Any]
) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]:
@@ -945,14 +1039,7 @@ def wait_for_opensearch_with_timeout(
Returns:
True if OpenSearch is ready, False otherwise.
"""
made_client = False
try:
if client is None:
# NOTE: index_name does not matter because we are only using this object
# to ping.
# TODO(andrei): Make this better.
client = OpenSearchClient(index_name="")
made_client = True
with nullcontext(client) if client else OpenSearchClient() as client:
time_start = time.monotonic()
while True:
if client.ping():
@@ -969,7 +1056,3 @@ def wait_for_opensearch_with_timeout(
f"[OpenSearch] Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}"
)
time.sleep(wait_interval_s)
finally:
if made_client:
assert client is not None
client.close()

View File

@@ -7,6 +7,7 @@ from opensearchpy import NotFoundError
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.constants import PUBLIC_DOC_PAT
@@ -40,6 +41,7 @@ from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.client import OpenSearchClient
from onyx.document_index.opensearch.client import OpenSearchIndexClient
from onyx.document_index.opensearch.client import SearchHit
from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
@@ -50,6 +52,7 @@ from onyx.document_index.opensearch.schema import DocumentSchema
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from onyx.document_index.opensearch.search import (
@@ -92,6 +95,25 @@ def generate_opensearch_filtered_access_control_list(
return list(access_control_list)
def set_cluster_state(client: OpenSearchClient) -> None:
if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS):
logger.error(
"Failed to put cluster settings. If the settings have never been set before, "
"this may cause unexpected index creation when indexing documents into an "
"index that does not exist, or may cause expected logs to not appear. If this "
"is not the first time running Onyx against this instance of OpenSearch, these "
"settings have likely already been set. Not taking any further action..."
)
client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
chunk: DocumentChunk,
score: float | None,
@@ -215,6 +237,7 @@ def _convert_onyx_chunk_to_opensearch_document(
# OpenSearch and it will not store any data at all for this field, which
# is different from supplying an empty list.
user_projects=chunk.user_project or None,
personas=chunk.personas or None,
primary_owners=get_experts_stores_representations(
chunk.source_document.primary_owners
),
@@ -246,6 +269,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
def __init__(
self,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
secondary_index_name: str | None,
large_chunks_enabled: bool, # noqa: ARG002
secondary_large_chunks_enabled: bool | None, # noqa: ARG002
@@ -256,10 +281,6 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
index_name=index_name,
secondary_index_name=secondary_index_name,
)
if multitenant:
raise ValueError(
"Bug: OpenSearch is not yet ready for multitenant environments but something tried to use it."
)
if multitenant != MULTI_TENANT:
raise ValueError(
"Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. "
@@ -267,8 +288,10 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
)
tenant_id = get_current_tenant_id()
self._real_index = OpenSearchDocumentIndex(
index_name=index_name,
tenant_state=TenantState(tenant_id=tenant_id, multitenant=multitenant),
index_name=index_name,
embedding_dim=embedding_dim,
embedding_precision=embedding_precision,
)
@staticmethod
@@ -277,9 +300,8 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
# TODO(andrei): Implement.
raise NotImplementedError(
"Multitenant index registration is not yet implemented for OpenSearch."
"Bug: Multitenant index registration is not supported for OpenSearch."
)
def ensure_indices_exist(
@@ -362,6 +384,11 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
if user_fields and user_fields.user_projects
else None
),
persona_ids=(
set(user_fields.personas)
if user_fields and user_fields.personas
else None
),
)
try:
@@ -464,19 +491,37 @@ class OpenSearchDocumentIndex(DocumentIndex):
for an OpenSearch search engine instance. It handles the complete lifecycle
of document chunks within a specific OpenSearch index/schema.
Although not yet used in this way in the codebase, each kind of embedding
used should correspond to a different instance of this class, and therefore
a different index in OpenSearch.
Each kind of embedding used should correspond to a different instance of
this class, and therefore a different index in OpenSearch.
If in a multitenant environment and
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index
if necessary on initialization. This is because there is no logic which runs
on cluster restart which scans through all search settings over all tenants
and creates the relevant indices.
Args:
tenant_state: The tenant state of the caller.
index_name: The name of the index to interact with.
embedding_dim: The dimensionality of the embeddings used for the index.
embedding_precision: The precision of the embeddings used for the index.
"""
def __init__(
self,
index_name: str,
tenant_state: TenantState,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> None:
self._index_name: str = index_name
self._tenant_state: TenantState = tenant_state
self._os_client = OpenSearchClient(index_name=self._index_name)
self._client = OpenSearchIndexClient(index_name=self._index_name)
if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT:
self.verify_and_create_index_if_necessary(
embedding_dim=embedding_dim, embedding_precision=embedding_precision
)
def verify_and_create_index_if_necessary(
self,
@@ -485,10 +530,15 @@ class OpenSearchDocumentIndex(DocumentIndex):
) -> None:
"""Verifies and creates the index if necessary.
Also puts the desired cluster settings.
Also puts the desired cluster settings if not in a multitenant
environment.
Also puts the desired search pipeline state, creating the pipelines if
they do not exist and updating them otherwise.
Also puts the desired search pipeline state if not in a multitenant
environment, creating the pipelines if they do not exist and updating
them otherwise.
In a multitenant environment, the above steps happen explicitly on
setup.
Args:
embedding_dim: Vector dimensionality for the vector similarity part
@@ -501,47 +551,38 @@ class OpenSearchDocumentIndex(DocumentIndex):
search pipelines.
"""
logger.debug(
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
f"with embedding dimension {embedding_dim}."
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if "
f"necessary, with embedding dimension {embedding_dim}."
)
if not self._tenant_state.multitenant:
set_cluster_state(self._client)
expected_mappings = DocumentSchema.get_document_schema(
embedding_dim, self._tenant_state.multitenant
)
if not self._os_client.put_cluster_settings(
settings=OPENSEARCH_CLUSTER_SETTINGS
):
logger.error(
f"Failed to put cluster settings for index {self._index_name}. If the settings have never been set before this "
"may cause unexpected index creation when indexing documents into an index that does not exist, or may cause "
"expected logs to not appear. If this is not the first time running Onyx against this instance of OpenSearch, "
"these settings have likely already been set. Not taking any further action..."
)
if not self._os_client.index_exists():
if not self._client.index_exists():
if USING_AWS_MANAGED_OPENSEARCH:
index_settings = (
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
)
else:
index_settings = DocumentSchema.get_index_settings()
self._os_client.create_index(
self._client.create_index(
mappings=expected_mappings,
settings=index_settings,
)
if not self._os_client.validate_index(
expected_mappings=expected_mappings,
):
raise RuntimeError(
f"The index {self._index_name} is not valid. The expected mappings do not match the actual mappings."
)
self._os_client.create_search_pipeline(
pipeline_id=MIN_MAX_NORMALIZATION_PIPELINE_NAME,
pipeline_body=MIN_MAX_NORMALIZATION_PIPELINE_CONFIG,
)
self._os_client.create_search_pipeline(
pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
pipeline_body=ZSCORE_NORMALIZATION_PIPELINE_CONFIG,
)
else:
# Ensure schema is up to date by applying the current mappings.
try:
self._client.put_mapping(expected_mappings)
except Exception as e:
logger.error(
f"Failed to update mappings for index {self._index_name}. This likely means a "
f"field type was changed which requires reindexing. Error: {e}"
)
raise
def index(
self,
@@ -613,7 +654,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
)
# Now index. This will raise if a chunk of the same ID exists, which
# we do not expect because we should have deleted all chunks.
self._os_client.bulk_index_documents(
self._client.bulk_index_documents(
documents=chunk_batch,
tenant_state=self._tenant_state,
)
@@ -653,7 +694,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
tenant_state=self._tenant_state,
)
return self._os_client.delete_by_query(query_body)
return self._client.delete_by_query(query_body)
def update(
self,
@@ -709,6 +750,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
properties_to_update[USER_PROJECTS_FIELD_NAME] = list(
update_request.project_ids
)
if update_request.persona_ids is not None:
properties_to_update[PERSONAS_FIELD_NAME] = list(
update_request.persona_ids
)
if not properties_to_update:
if len(update_request.document_ids) > 1:
@@ -749,7 +794,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
document_id=doc_id,
chunk_index=chunk_index,
)
self._os_client.update_document(
self._client.update_document(
document_chunk_id=document_chunk_id,
properties_to_update=properties_to_update,
)
@@ -788,7 +833,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
min_chunk_index=chunk_request.min_chunk_ind,
max_chunk_index=chunk_request.max_chunk_ind,
)
search_hits = self._os_client.search(
search_hits = self._client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -838,7 +883,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
# NOTE: Using z-score normalization here because it's better for hybrid search from a theoretical standpoint.
# Empirically on a small dataset of up to 10K docs, it's not very different. Likely more impactful at scale.
# https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
@@ -870,7 +915,7 @@ class OpenSearchDocumentIndex(DocumentIndex):
index_filters=filters,
num_to_retrieve=num_to_retrieve,
)
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
search_hits: list[SearchHit[DocumentChunk]] = self._client.search(
body=query_body,
search_pipeline_id=None,
)
@@ -898,6 +943,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
# Do not raise if the document already exists, just update. This is
# because the document may already have been indexed during the
# OpenSearch transition period.
self._os_client.bulk_index_documents(
self._client.bulk_index_documents(
documents=chunks, tenant_state=self._tenant_state, update_if_exists=True
)

View File

@@ -41,6 +41,7 @@ IMAGE_FILE_ID_FIELD_NAME = "image_file_id"
SOURCE_LINKS_FIELD_NAME = "source_links"
DOCUMENT_SETS_FIELD_NAME = "document_sets"
USER_PROJECTS_FIELD_NAME = "user_projects"
PERSONAS_FIELD_NAME = "personas"
DOCUMENT_ID_FIELD_NAME = "document_id"
CHUNK_INDEX_FIELD_NAME = "chunk_index"
MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size"
@@ -156,6 +157,7 @@ class DocumentChunk(BaseModel):
document_sets: list[str] | None = None
user_projects: list[int] | None = None
personas: list[int] | None = None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@@ -485,6 +487,7 @@ class DocumentSchema:
# Product-specific fields.
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"},
USER_PROJECTS_FIELD_NAME: {"type": "integer"},
PERSONAS_FIELD_NAME: {"type": "integer"},
PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"},
SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"},
# OpenSearch metadata fields.

View File

@@ -28,6 +28,7 @@ from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
@@ -144,6 +145,7 @@ class DocumentQuery:
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=min_chunk_index,
max_chunk_index=max_chunk_index,
@@ -202,6 +204,7 @@ class DocumentQuery:
document_sets=[],
user_file_ids=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
@@ -267,6 +270,7 @@ class DocumentQuery:
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -334,6 +338,7 @@ class DocumentQuery:
document_sets=index_filters.document_set or [],
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -496,6 +501,7 @@ class DocumentQuery:
document_sets: list[str],
user_file_ids: list[UUID],
project_id: int | None,
persona_id: int | None,
time_cutoff: datetime | None,
min_chunk_index: int | None,
max_chunk_index: int | None,
@@ -530,6 +536,8 @@ class DocumentQuery:
retrieved.
project_id: If not None, only documents with this project ID in user
projects will be retrieved.
persona_id: If not None, only documents whose personas array
contains this persona ID will be retrieved.
time_cutoff: Time cutoff for the documents to retrieve. If not None,
Documents which were last updated before this date will not be
returned. For documents which do not have a value for their last
@@ -627,6 +635,9 @@ class DocumentQuery:
)
return user_project_filter
def _get_persona_filter(persona_id: int) -> dict[str, Any]:
return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}}
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
# Convert to UTC if not already so the cutoff is comparable to the
# document data.
@@ -780,6 +791,9 @@ class DocumentQuery:
# document's user projects list.
filter_clauses.append(_get_user_project_filter(project_id))
if persona_id is not None:
filter_clauses.append(_get_persona_filter(persona_id))
if time_cutoff is not None:
# If a time cutoff is provided, the caller will only retrieve
# documents where the document was last updated at or after the time

View File

@@ -181,6 +181,11 @@ schema {{ schema_name }} {
rank: filter
attribute: fast-search
}
field personas type array<int> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -689,6 +689,9 @@ class VespaIndex(DocumentIndex):
project_ids: set[int] | None = None
if user_fields is not None and user_fields.user_projects is not None:
project_ids = set(user_fields.user_projects)
persona_ids: set[int] | None = None
if user_fields is not None and user_fields.personas is not None:
persona_ids = set(user_fields.personas)
update_request = MetadataUpdateRequest(
document_ids=[doc_id],
doc_id_to_chunk_cnt={
@@ -699,6 +702,7 @@ class VespaIndex(DocumentIndex):
boost=fields.boost if fields is not None else None,
hidden=fields.hidden if fields is not None else None,
project_ids=project_ids,
persona_ids=persona_ids,
)
vespa_document_index.update([update_request])

View File

@@ -46,6 +46,7 @@ from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
@@ -218,6 +219,7 @@ def _index_vespa_chunk(
# still called `image_file_name` in Vespa for backwards compatibility
IMAGE_FILE_NAME: chunk.image_file_id,
USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
PERSONAS: chunk.personas if chunk.personas is not None else [],
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}

View File

@@ -12,6 +12,7 @@ from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_PROJECT
@@ -149,6 +150,18 @@ def build_vespa_filters(
# Vespa YQL 'contains' expects a string literal; quote the integer
return f'({USER_PROJECT} contains "{pid}") and '
def _build_persona_filter(
persona_id: int | None,
) -> str:
if persona_id is None:
return ""
try:
pid = int(persona_id)
except Exception:
logger.warning(f"Invalid persona ID: {persona_id}")
return ""
return f'({PERSONAS} contains "{pid}") and '
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
@@ -192,6 +205,9 @@ def build_vespa_filters(
# User project filter (array<int> attribute membership)
filter_str += _build_user_project_filter(filters.project_id)
# Persona filter (array<int> attribute membership)
filter_str += _build_persona_filter(filters.persona_id)
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)

View File

@@ -183,6 +183,10 @@ def _update_single_chunk(
model_config = {"frozen": True}
assign: list[int]
class _Personas(BaseModel):
model_config = {"frozen": True}
assign: list[int]
class _VespaPutFields(BaseModel):
model_config = {"frozen": True}
# The names of these fields are based the Vespa schema. Changes to the
@@ -193,6 +197,7 @@ def _update_single_chunk(
access_control_list: _AccessControl | None = None
hidden: _Hidden | None = None
user_project: _UserProjects | None = None
personas: _Personas | None = None
class _VespaPutRequest(BaseModel):
model_config = {"frozen": True}
@@ -227,6 +232,11 @@ def _update_single_chunk(
if update_request.project_ids is not None
else None
)
personas_update: _Personas | None = (
_Personas(assign=list(update_request.persona_ids))
if update_request.persona_ids is not None
else None
)
vespa_put_fields = _VespaPutFields(
boost=boost_update,
@@ -234,6 +244,7 @@ def _update_single_chunk(
access_control_list=access_update,
hidden=hidden_update,
user_project=user_projects_update,
personas=personas_update,
)
vespa_put_request = _VespaPutRequest(
@@ -554,10 +565,9 @@ class VespaDocumentIndex(DocumentIndex):
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the rerank-count value set in the
# Vespa schema config. Otherwise we would be getting fewer results than
# expected for reranking.
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
# Avoid over-fetching a very large candidate set for global-phase reranking.
# Keep enough headroom for quality while capping cost on larger indices.
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
yql = (
YQL_BASE.format(index_name=self._index_name)

View File

@@ -58,6 +58,7 @@ DOCUMENT_SETS = "document_sets"
USER_FILE = "user_file"
USER_FOLDER = "user_folder"
USER_PROJECT = "user_project"
PERSONAS = "personas"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -20,7 +20,20 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -63,6 +76,7 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,12 +5,16 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
class AzureImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -52,6 +56,25 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
deployment_name=credentials.deployment_name,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Azure GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -59,13 +82,44 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation
deployment = self._deployment_name or model
model_name = f"azure/{deployment}"
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model_name,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(
prompt=prompt,
model=model_name,

View File

@@ -5,12 +5,16 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
class OpenAIImageGenerationProvider(ImageGenerationProvider):
_GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
_DALL_E_2_MODEL_NAME = "dall-e-2"
def __init__(
self,
api_key: str,
@@ -38,6 +42,25 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
api_base=credentials.api_base,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# GPT image models support up to 16 input images for edits.
return 16
def _normalize_model_name(self, model: str) -> str:
return model.rsplit("/", 1)[-1]
def _model_supports_image_edits(self, model: str) -> bool:
normalized_model = self._normalize_model_name(model)
return (
normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
or normalized_model == self._DALL_E_2_MODEL_NAME
)
def generate_image(
self,
prompt: str,
@@ -45,8 +68,38 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
if not self._model_supports_image_edits(model):
raise ValueError(
f"Model '{model}' does not support image edits with reference images."
)
normalized_model = self._normalize_model_name(model)
if (
normalized_model == self._DALL_E_2_MODEL_NAME
and len(reference_images) > 1
):
raise ValueError(
"Model 'dall-e-2' only supports a single reference image for edits."
)
from litellm import image_edit
return image_edit(
image=[image.data for image in reference_images],
prompt=prompt,
model=model,
api_key=self._api_key,
api_base=self._api_base,
size=size,
n=n,
quality=quality,
**kwargs,
)
from litellm import image_generation
return image_generation(

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -9,6 +11,7 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -51,6 +54,15 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -58,8 +70,18 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -74,6 +96,99 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -146,6 +146,7 @@ class DocumentIndexingBatchAdapter:
doc_id_to_document_set.get(chunk.source_document.id, [])
),
user_project=[],
personas=[],
boost=(
context.id_to_boost_map[chunk.source_document.id]
if chunk.source_document.id in context.id_to_boost_map

View File

@@ -20,6 +20,7 @@ from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.notification import create_notification
from onyx.db.user_file import fetch_chunk_counts_for_user_files
from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
@@ -119,6 +120,10 @@ class UserFileIndexingAdapter:
user_file_ids=updatable_ids,
db_session=self.db_session,
)
user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
)
user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
user_file_ids=updatable_ids,
db_session=self.db_session,
@@ -182,7 +187,7 @@ class UserFileIndexingAdapter:
user_project=user_file_id_to_project_ids.get(
chunk.source_document.id, []
),
# we are going to index userfiles only once, so we just set the boost to the default
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
boost=DEFAULT_BOOST,
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],

View File

@@ -49,6 +49,7 @@ from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.postgres_sanitization import sanitize_documents_for_postgres
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
@@ -228,6 +229,8 @@ def index_doc_batch_prepare(
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = sanitize_documents_for_postgres(documents)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]

Some files were not shown because too many files have changed in this diff Show More