Compare commits

185 Commits

Author SHA1 Message Date
Raunak Bhagat
bcc59a476b Remove skipOverlay flag 2025-12-26 09:31:52 -07:00
Raunak Bhagat
75cfa49504 Fix 2025-12-26 09:23:06 -07:00
Raunak Bhagat
6ec0b09139 feat: Add small icons + scripts + readme to Opal (#7046) 2025-12-26 08:06:57 -08:00
roshan
53691fc95a chore: refactor search tool renderer (#7044) 2025-12-25 22:04:11 -05:00
Jamison Lahman
3400e2a14d chore(desktop): skip desktop on beta tags (#7043) 2025-12-25 13:41:05 -08:00
roshan
d8cc1f7a2c chore: clean up unused feature flag (#7042) 2025-12-25 16:35:53 -05:00
roshan
2098e910dd chore: clean up search renderer v2 (#7041) 2025-12-25 16:31:26 -05:00
Jamison Lahman
e5491d6f79 revert: "chore(fe): enable reactRemoveProperties" (#7040) 2025-12-25 12:00:52 -08:00
Raunak Bhagat
a8934a083a feat: Add useOnChangeValue hook and update form components (#7036) 2025-12-25 11:40:39 -08:00
Chris Weaver
80e9507e01 fix: google index names (#7038) 2025-12-25 17:56:22 +00:00
Raunak Bhagat
60d3be5fe2 refactor: Improve form hook to handle events directly (#7035) 2025-12-25 02:16:47 -08:00
Raunak Bhagat
b481cc36d0 refactor: Update form field components to use new hook (#7034) 2025-12-25 01:54:07 -08:00
Raunak Bhagat
65c5da8912 feat: Create new InputDatePicker component (#7023) 2025-12-24 23:23:47 -08:00
Jamison Lahman
0a0366e6ca chore(fe): enable reactRemoveProperties (#7030) 2025-12-25 05:12:36 +00:00
Jamison Lahman
84a623e884 chore(fe): remove reliance on data-testid prop (#7031) 2025-12-24 20:44:28 -08:00
roshan
6b91607b17 chore: feature flag for deep research (#7022) 2025-12-24 21:38:34 -05:00
Wenxi
82fb737ad9 fix: conditional tool choice param for anthropic (#7029) 2025-12-25 00:25:19 +00:00
Justin Tahara
eed49e699e fix(docprocessing): Cleaning up Events (#7025) 2025-12-24 12:25:43 -08:00
Justin Tahara
3cc7afd334 fix(chat): Copy functionality (#7027) 2025-12-24 12:22:02 -08:00
Jamison Lahman
bcbfd28234 chore(fe): "Copy code"->"Copy" (#7018) 2025-12-24 11:38:02 -08:00
Rohit V
faa47d9691 chore(docs): update docker compose command in CONTRIBUTING.md (#7020)
Co-authored-by: Rohit V <rohit.v@thoughtspot.com>
2025-12-24 11:18:12 -08:00
Wenxi
6649561bf3 fix: multiple tool calls unit test (#7026) 2025-12-24 18:08:12 +00:00
Wenxi
026cda0468 fix: force tool with openai (#7024) 2025-12-24 09:37:14 -08:00
Raunak Bhagat
64297e5996 feat: add formik field components and helpers (#7017)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-12-24 08:09:24 -08:00
Raunak Bhagat
c517137c0a refactor: Update CSS stylings for SidebarTab component (#7016) 2025-12-23 22:56:06 -08:00
SubashMohan
cbfbe0bbbe fix(onboarding): Azure llm url parsing (#6950) 2025-12-24 12:17:31 +05:30
Raunak Bhagat
13ca4c6650 refactor: remove icon prop from UserFilesModal (#7014) 2025-12-23 22:35:42 -08:00
Raunak Bhagat
e8d9e36d62 refactor: SidebarTab fixes (#7012)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2025-12-24 06:06:06 +00:00
Jamison Lahman
77e4f3c574 fix(fe): right sidebar buttons dont inherit href (#7007)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2025-12-24 04:41:22 +00:00
Chris Weaver
2bdc06201a fix: improve scrollbar for code blocks (#7013) 2025-12-24 03:38:09 +00:00
Yuhong Sun
077ba9624c fix: parallel tool call with openai (#7010) 2025-12-23 19:07:23 -08:00
Raunak Bhagat
81eb1a1c7c fix: Fix import error (#7011) 2025-12-23 19:00:10 -08:00
Yuhong Sun
1a16fef783 feat: DEEP RESEARCH ALPHA HUZZAH (#7001) 2025-12-23 18:45:43 -08:00
Yuhong Sun
027692d5eb chore: bump litellm version (#7009) 2025-12-23 18:09:21 -08:00
Raunak Bhagat
3a889f7069 refactor: Add more comprehensive layout components (#6989) 2025-12-23 17:54:32 -08:00
Raunak Bhagat
20d67bd956 feat: Add new components to refresh-components (#6991)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2025-12-23 17:53:59 -08:00
acaprau
8d6b6accaf feat(new vector db interface): Plug in retrievals for Vespa (#6966) 2025-12-23 23:30:59 +00:00
Chris Weaver
ed76b4eb55 fix: masking (#7003) 2025-12-23 23:23:03 +00:00
Raunak Bhagat
7613c100d1 feat: update icons (#6988) 2025-12-23 15:11:33 -08:00
Raunak Bhagat
c52d3412de refactor: add more helpful utility hooks (#6987) 2025-12-23 14:38:13 -08:00
Jamison Lahman
96b6162b52 chore(desktop): fix windows version (#6999) 2025-12-23 22:21:30 +00:00
Yuhong Sun
502ed8909b chore: Tuning Deep Research (#7000) 2025-12-23 14:19:20 -08:00
roshan
8de75dd033 feat: deep research (#6936)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2025-12-23 21:24:27 +00:00
Wenxi
74e3668e38 chore: cleanup drupal connector nits (#6998) 2025-12-23 21:24:21 +00:00
Justin Tahara
2475a9ef92 fix(gdrive): Investigation Logging (#6996) 2025-12-23 13:26:44 -08:00
rexjohannes
690f54c441 feat: Drupal Wiki connector (#4773)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-12-23 19:28:23 +00:00
Jamison Lahman
71bb0c029e chore(desktop): deployment automation for the desktop app (#6990) 2025-12-23 09:20:59 -08:00
Yuhong Sun
ccf890a129 Small Tuning (#6986) 2025-12-22 20:13:17 -08:00
acaprau
a7bfdebddf feat(new vector db interface): Implement retrievals for Vespa (#6963) 2025-12-23 03:00:38 +00:00
Yuhong Sun
6fc5ca12a3 Fine grained Braintrust tracing (#6985) 2025-12-22 19:08:49 -08:00
Wenxi
8298452522 feat: add open book icon (#6984) 2025-12-22 19:00:31 -08:00
Wenxi
2559327636 fix: allow chat file previewing and fix csv rendering (#6915) 2025-12-23 02:08:42 +00:00
Yuhong Sun
ef185ce2c8 feat: DR Tab for intermediate reports and Index increment for final report section end (#6983) 2025-12-22 18:10:45 -08:00
Wenxi
a04fee5cbd feat: add optional image parsing for docx (#6981) 2025-12-22 17:45:44 -08:00
Justin Tahara
e507378244 fix(vertex-ai): Bump Default Batch Size (#6982) 2025-12-22 17:21:55 -08:00
Justin Tahara
e6be3f85b2 fix(gemini): No Asyncio (#6980) 2025-12-23 01:07:40 +00:00
acaprau
cc96e303ce feat(new vector db interface): Plug in delete for Vespa (#6867)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2025-12-23 00:54:52 +00:00
Nikolas Garza
e0fcb1f860 feat(fe): speed up pre-commit TypeScript type checking with tsgo (#6978) 2025-12-23 00:22:42 +00:00
roshan
f5442c431d feat: add PacketException handling (#6968) 2025-12-23 00:09:51 +00:00
acaprau
652e5848e5 feat(new vector db interface): Implement delete for Vespa (#6866)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2025-12-22 23:58:32 +00:00
Wenxi
3fa1896316 fix: download cloud svg (#6977) 2025-12-22 14:54:33 -08:00
roshan
f855ecab11 feat: add dr loop tracing (#6971) 2025-12-22 21:35:29 +00:00
Jamison Lahman
fd26176e7d revert: "fix(fe): make recent chat sidebar buttons links" (#6967) 2025-12-22 12:12:48 -08:00
Justin Tahara
8986f67779 fix(docprocessing): Reusing Threads (#6916) 2025-12-22 19:03:46 +00:00
Nikolas Garza
42f2d4aca5 feat(teams): Enable Auto Sync Permissions for Teams connector (#6648) 2025-12-22 18:57:01 +00:00
Evan Lohn
7116d24a8c fix: small MCP UI changes (#6862) 2025-12-22 18:09:36 +00:00
Justin Tahara
7f4593be32 fix(vertex): Infinite Embedding (#6917) 2025-12-22 10:43:11 -08:00
Wenxi
f47e25e693 feat(ingestion): restore delete api (#6962) 2025-12-22 10:06:43 -08:00
acaprau
877184ae97 feat(new vector db interface): Plug in update for Vespa (#6792) 2025-12-22 16:25:13 +00:00
acaprau
54961ec8ef fix: test_multi_llm.py::test_multiple_tool_calls callsite fix (#6959) 2025-12-22 08:06:13 -08:00
Raunak Bhagat
e797971ce5 fix: Layout fix + CSR updates (#6958)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-12-22 08:00:39 -08:00
Jamison Lahman
566cca70d8 chore(fe): conditionally render header on chatSession (#6955) 2025-12-22 02:37:01 -08:00
Jamison Lahman
be2d0e2b5d chore(fe): prevent header continuous render (#6954) 2025-12-22 00:46:21 -08:00
Jamison Lahman
692f937ca4 chore(fmt): fix prettier (#6953) 2025-12-22 00:30:21 -08:00
Jamison Lahman
11de1ceb65 chore(ts): typedRoutes = true (#6930) 2025-12-22 00:21:44 -08:00
Jamison Lahman
19993b4679 chore(chat): refactor chat header (#6952)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-12-22 00:20:46 -08:00
Yuhong Sun
9063827782 Enable DR on the backend (#6948) 2025-12-21 18:25:24 -08:00
Yuhong Sun
0cc6fa49d7 DR Minor tweaking (#6947) 2025-12-21 17:23:52 -08:00
roshan
3f3508b668 fix: sanitize postgres to remove nul characters (#6934) 2025-12-22 00:19:25 +00:00
Jamison Lahman
1c3a88daf8 perf(chat): avoid re-rendering chat on ChatInput change (#6945) 2025-12-21 16:15:34 -08:00
Yuhong Sun
92f30bbad9 Fix misalignment in DR failed agents (#6946) 2025-12-21 15:07:45 -08:00
Yuhong Sun
4abf43d85b DR bug fixes (#6944) 2025-12-21 14:56:52 -08:00
Jamison Lahman
b08f9adb23 chore(perf): frontend stats overlay in dev (#6840)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-12-21 22:12:54 +00:00
Yuhong Sun
7a915833bb More correct packet handling (#6943) 2025-12-21 13:48:27 -08:00
Jamison Lahman
9698b700e6 fix(desktop): Linux-specific fixes (#6928) 2025-12-21 20:39:52 +00:00
Jamison Lahman
fd944acc5b fix(fe): chat content links use proper hrefs (#6939) 2025-12-21 12:09:20 -08:00
Yuhong Sun
a1309257f5 Log (#6937) 2025-12-20 23:28:28 -08:00
Yuhong Sun
6266dc816d feat: Deep Research Citation Handling (#6935) 2025-12-20 22:46:20 -08:00
Jamison Lahman
83c011a9e4 chore(deps): upgrade urllib3 2.6.1->2.6.2 (#6932) 2025-12-20 20:21:10 -08:00
Yuhong Sun
8d1ac81d09 Citation Processing (#6933) 2025-12-20 20:08:24 -08:00
Yuhong Sun
d8cd4c9928 feat: DR fix a couple issues with saving (#6931) 2025-12-20 18:28:04 -08:00
Jamison Lahman
5caa4fdaa0 fix(chat): attached images are flush right (#6927) 2025-12-20 07:20:14 -08:00
Jamison Lahman
f22f33564b fix(fe): ensure error messages have padding (#6926) 2025-12-20 07:03:27 -08:00
Jamison Lahman
f86d282a47 chore(fe): ensure chat padding on medium size viewport (#6925) 2025-12-20 06:38:16 -08:00
Jamison Lahman
ece1edb80f fix(fe): make recent chat sidebar buttons links (#6924) 2025-12-20 06:04:59 -08:00
Jamison Lahman
c9c17e19f3 fix(chat): only scroll to bottom on page load (#6923) 2025-12-20 05:01:56 -08:00
Jamison Lahman
40e834e0b8 fix(fe): make "New Session" button a link (#6922) 2025-12-20 04:29:22 -08:00
Jamison Lahman
45bd82d031 fix(style): floating scroll down is z-sticky (#6921) 2025-12-20 04:12:48 -08:00
Yuhong Sun
27c1619c3d feat: hyperparams (#6920) 2025-12-19 20:32:00 -08:00
Yuhong Sun
8cfeb85c43 feat: Deep Research packets streaming done (#6919) 2025-12-19 20:23:02 -08:00
Yuhong Sun
491b550ebc feat: Deep Research more stuff (#6918) 2025-12-19 19:14:22 -08:00
Chris Weaver
1a94dfd113 fix: reasoning width (#6914) 2025-12-20 02:24:46 +00:00
Jamison Lahman
bcd9d7ae41 fix(install): handle non-semver docker-compose versions (#6913) 2025-12-19 18:17:44 -08:00
Vinit
98b4353632 fix: use consistent INSTALL_ROOT instead of pwd for deployment paths (#6680)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2025-12-20 01:25:51 +00:00
Yuhong Sun
f071b280d4 feat: Deep Research packets (#6912) 2025-12-19 17:18:56 -08:00
acaprau
f7ebaa42fc feat(new vector db interface): Implement update for Vespa (#6790) 2025-12-20 00:56:23 +00:00
Justin Tahara
11737c2069 fix(vespa): Handling Rate Limits (#6878) 2025-12-20 00:52:11 +00:00
Jamison Lahman
1712253e5f fix(fe): Set up provider logos are equal size (#6900) 2025-12-20 00:50:31 +00:00
Yuhong Sun
de8f292fce feat: DR packets cont (#6910) 2025-12-19 16:47:03 -08:00
Jamison Lahman
bbe5058131 chore(mypy): "ragas.metrics" [import-not-found] (#6909) 2025-12-19 16:35:45 -08:00
Yuhong Sun
45fc5e3c97 chore: Tool interface (#6908) 2025-12-19 16:12:21 -08:00
Yuhong Sun
5c976815cc Mypy (#6906) 2025-12-19 15:50:30 -08:00
Justin Tahara
3ea4b6e6cc feat(desktop): Make Desktop App (#6690)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2025-12-19 15:49:21 -08:00
Yuhong Sun
7b75c0049b chore: minor refactor (#6905) 2025-12-19 15:37:27 -08:00
Yuhong Sun
04bdce55f4 chore: Placement used in more places (#6904) 2025-12-19 15:07:48 -08:00
Yuhong Sun
2446b1898e chore: Test Manager class (#6903) 2025-12-19 14:58:55 -08:00
Yuhong Sun
6f22a2f656 chore: Update Packet structure to make the positioning info an object (#6899) 2025-12-19 14:12:39 -08:00
Justin Tahara
e307a84863 fix(agents): Fix User File Search (#6895) 2025-12-19 21:42:28 +00:00
Chris Weaver
2dd27f25cb feat: allow cmd+click on connector rows in admin panel (#6894) 2025-12-19 21:39:23 +00:00
Nikolas Garza
e402c0e3b4 fix: fix Icon React Compiler error in LLMPopover when searching models (#6891) 2025-12-19 21:16:41 +00:00
Jamison Lahman
2721c8582a chore(pre-commit): run uv-sync in active venv (#6898) 2025-12-19 13:44:00 -08:00
Yuhong Sun
43c8b7a712 feat: Deep Research substep initial (#6896) 2025-12-19 13:30:25 -08:00
acaprau
f473b85acd feat(new vector db interface): Plug in hybrid_retrieval for Vespa (#6752) 2025-12-19 21:03:19 +00:00
Nikolas Garza
02cd84c39a fix(slack): limit thread context fetch to top N messages by relevance (#6861) 2025-12-19 20:26:30 +00:00
Raunak Bhagat
46d17d6c64 fix: Fix header on AgentsNavigationPage (#6873) 2025-12-19 20:15:44 +00:00
Jamison Lahman
10ad536491 chore(mypy): enable warn-unused-ignores (#6893) 2025-12-19 12:00:30 -08:00
acaprau
ccabc1a7a7 feat(new vector db interface): Implement hybrid_retrieval for Vespa (#6750) 2025-12-19 19:32:48 +00:00
Chris Weaver
8e262e4da8 feat: make first runs be high priority (#6871) 2025-12-19 19:05:15 +00:00
Raunak Bhagat
79dea9d901 Revert "refactor: Consolidate chat and agents contexts" (#6872)
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
2025-12-19 11:11:33 -08:00
Yuhong Sun
2f650bbef8 chore: Matplotlib for mypy (#6892) 2025-12-19 10:47:59 -08:00
Jamison Lahman
021e67ca71 chore(pre-commit): "Check lazy imports" prefers active venv (#6890) 2025-12-19 10:04:02 -08:00
roshan
87ae024280 fix icon button z-index (#6889) 2025-12-19 09:52:47 -08:00
SubashMohan
5092429557 Feat/tests GitHub perm sync (#6882) 2025-12-19 17:26:55 +00:00
Nikolas Garza
dc691199f5 fix: persist user-selected connector sources on follow-up messages (#6865) 2025-12-19 17:26:48 +00:00
Jamison Lahman
1662c391f0 fix(fe): chat attachment alignment regression (#6884) 2025-12-19 07:44:34 -08:00
Jamison Lahman
08aefbc115 fix(style): bottom message padding on small screen (#6883) 2025-12-19 06:50:43 -08:00
Jamison Lahman
fb6342daa9 fix(style): chat page is flush left on small screens (#6881) 2025-12-19 06:37:35 -08:00
Jamison Lahman
4e7adcc9ee chore(devtools): pass debug auth token with server-side requests (#6836) 2025-12-19 04:07:53 -08:00
Wenxi
aa4b3d8a24 fix(tests): add research agent tool to tool seeding test (#6877) 2025-12-18 23:09:18 -08:00
Wenxi
f3bc459b6e fix(anthropic): parse chat history tool calls correctly for anthropic models (#6876) 2025-12-18 22:28:34 -08:00
Yuhong Sun
87cab60b01 feat: Deep Research Tool (#6875) 2025-12-18 20:30:36 -08:00
Yuhong Sun
08ab73caf8 fix: Reasoning (#6874) 2025-12-18 19:00:13 -08:00
Justin Tahara
675761c81e fix(users): Clean up Invited Users who are Active (#6857) 2025-12-19 01:43:32 +00:00
Raunak Bhagat
18e15c6da6 refactor: Consolidate chat and agents contexts (#6834) 2025-12-19 01:31:02 +00:00
Yuhong Sun
e1f77e2e17 feat: Deep Research works till the end (#6870) 2025-12-18 17:18:08 -08:00
Justin Tahara
4ef388b2dc fix(tf): Instance Configurability (#6869) 2025-12-18 17:15:05 -08:00
Justin Tahara
031485232b fix(admin): Sidebar Scroll (#6853) 2025-12-19 00:39:27 +00:00
Wenxi
c0debefaf6 fix(bandaid): admin pages bottom padding (#6856) 2025-12-18 16:49:32 -08:00
Nikolas Garza
bbebe5f201 fix: reset actions popover to main menu on open (#6863) 2025-12-19 00:24:01 +00:00
Yuhong Sun
ac9cb22fee feat: deep research continued (#6864) 2025-12-18 15:51:13 -08:00
Wenxi
5e281ce2e6 refactor: unify mimetype and file extensions (#6849) 2025-12-18 23:08:26 +00:00
Chris Weaver
9ea5b7a424 chore: better cloud metrics (#6851) 2025-12-18 22:47:41 +00:00
Justin Tahara
e0b83fad4c fix(web): Avoiding Bot Detection issues (#6845) 2025-12-18 22:43:38 +00:00
Chris Weaver
7191b9010d fix: handle 401s in attachment fetching (#6858)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-12-18 14:52:05 -08:00
Yuhong Sun
fb3428ed37 feat: deep research more dev stuff (#6854) 2025-12-18 14:09:46 -08:00
Chris Weaver
444ad297da chore: remove fast model (#6841) 2025-12-18 20:38:13 +00:00
roshan
f46df421a7 fix: correct tool response pairing for parallel tool calls in llm_loop (#6846) 2025-12-18 11:46:34 -08:00
Yuhong Sun
98a2e12090 feat: DR continued work (#6848) 2025-12-18 11:36:34 -08:00
Jamison Lahman
36bfa8645e chore(gha): run playwright and jest similar to other tests (#6844) 2025-12-18 18:41:16 +00:00
roshan
56e71d7f6c fix: text view auto focus on button (#6843) 2025-12-18 10:18:43 -08:00
roshan
e0d172615b fix: TextView tooltip z-index (#6842) 2025-12-18 10:11:40 -08:00
Shahar Mazor
bde52b13d4 feat: add file management capabilities (#5623)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Wenxi <wenxi@onyx.app>
2025-12-18 17:40:24 +00:00
SubashMohan
b273d91512 feat(actions): add passthrough auth (#6665) 2025-12-18 10:58:52 +00:00
Jamison Lahman
1fbe76a607 fix(fe): center-align credential update icons (#6837) 2025-12-18 02:43:24 -08:00
Jamison Lahman
6ee7316130 fix(fe): avoid chat message shift on hover (#6835) 2025-12-17 23:44:09 -08:00
Raunak Bhagat
51802f46bb fix: Open sub menu on tool force (#6813) 2025-12-18 05:16:43 +00:00
Jamison Lahman
d430444424 fix(fe): apply z-sticky to ChatInput (#6827) 2025-12-17 21:04:34 -08:00
Yuhong Sun
17fff6c805 fix: reasoning with 5 series (#6833) 2025-12-17 20:16:48 -08:00
Yuhong Sun
a33f6e8416 fix: LLM can hallucinate tool calls (#6832) 2025-12-17 19:45:31 -08:00
Nikolas Garza
d157649069 fix(llm-popover): hide provider group when single provider (#6820) 2025-12-17 19:30:48 -08:00
Wenxi
77bbb9f7a7 fix: decrement litellm and openai broken versions (#6831) 2025-12-17 19:09:06 -08:00
Yuhong Sun
996b5177d9 feat: parallel tool calling (#6779)
Co-authored-by: rohoswagger <rohod04@gmail.com>
2025-12-17 18:59:34 -08:00
acaprau
ab9a3ba970 feat(new vector db interface): Plug in index for Vespa (#6659)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2025-12-18 01:42:08 +00:00
Yuhong Sun
87c1f0ab10 feat: more orchestrator stuff (#6826) 2025-12-17 17:12:22 -08:00
acaprau
dcea1d88e5 feat(new vector db interface): Implement index for Vespa (#6658)
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2025-12-18 00:26:07 +00:00
Nikolas Garza
cc481e20d3 feat: ee license tracking - API Endpoints (#6812) 2025-12-18 00:24:01 +00:00
Nikolas Garza
4d141a8f68 feat: ee license tracking - DB and Cache Operations (#6811) 2025-12-17 23:53:28 +00:00
Wenxi
cb32c81d1b refactor(web search): use refreshed modal, improve ux, add playwright tests (#6791) 2025-12-17 15:24:47 -08:00
Nikolas Garza
64f327fdef feat: ee license tracking - Crypto Verification Utils (#6810) 2025-12-17 22:41:12 +00:00
Yuhong Sun
902d6112c3 feat: Deep Research orchestration start (#6825) 2025-12-17 14:53:25 -08:00
Jamison Lahman
f71e3b9151 chore(devtools): address hatch.version.raw-options review comment (#6823) 2025-12-17 14:52:06 -08:00
Nikolas Garza
dd7e1520c5 feat: ee license tracking - Data Plane Models + Database Schema (#6809) 2025-12-17 21:26:33 +00:00
Jamison Lahman
97553de299 chore(devtools): go onboarding docs + replace hatch-vcs w/ code script (#6819) 2025-12-17 13:27:43 -08:00
Justin Tahara
c80ab8b200 fix(jira): Handle Errors better (#6816) 2025-12-17 21:12:14 +00:00
Jamison Lahman
85c4ddce39 chore(frontend): optionally inject auth cookie into requests (#6794)
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-12-17 20:43:36 +00:00
623 changed files with 38855 additions and 14743 deletions

View File

@@ -6,8 +6,9 @@ on:
- "*"
workflow_dispatch:
permissions:
contents: read
# Set restrictive default permissions for all jobs. Jobs that need more permissions
# should explicitly declare them.
permissions: {}
env:
IS_DRY_RUN: ${{ github.event_name == 'workflow_dispatch' }}
@@ -20,6 +21,7 @@ jobs:
runs-on: ubuntu-slim
timeout-minutes: 90
outputs:
build-desktop: ${{ steps.check.outputs.build-desktop }}
build-web: ${{ steps.check.outputs.build-web }}
build-web-cloud: ${{ steps.check.outputs.build-web-cloud }}
build-backend: ${{ steps.check.outputs.build-backend }}
@@ -30,6 +32,7 @@ jobs:
is-stable-standalone: ${{ steps.check.outputs.is-stable-standalone }}
is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }}
sanitized-tag: ${{ steps.check.outputs.sanitized-tag }}
short-sha: ${{ steps.check.outputs.short-sha }}
steps:
- name: Check which components to build and version info
id: check
@@ -38,6 +41,7 @@ jobs:
# Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility)
SANITIZED_TAG=$(echo "$TAG" | tr '/' '-')
IS_CLOUD=false
BUILD_DESKTOP=false
BUILD_WEB=false
BUILD_WEB_CLOUD=false
BUILD_BACKEND=true
@@ -47,13 +51,6 @@ jobs:
IS_STABLE_STANDALONE=false
IS_BETA_STANDALONE=false
if [[ "$TAG" == *cloud* ]]; then
IS_CLOUD=true
BUILD_WEB_CLOUD=true
else
BUILD_WEB=true
fi
# Version checks (for web - any stable version)
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
IS_STABLE=true
@@ -62,6 +59,17 @@ jobs:
IS_BETA=true
fi
if [[ "$TAG" == *cloud* ]]; then
IS_CLOUD=true
BUILD_WEB_CLOUD=true
else
BUILD_WEB=true
# Skip desktop builds on beta tags
if [[ "$IS_BETA" != "true" ]]; then
BUILD_DESKTOP=true
fi
fi
# Version checks (for backend/model-server - stable version excluding cloud tags)
if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "$TAG" != *cloud* ]]; then
IS_STABLE_STANDALONE=true
@@ -70,7 +78,9 @@ jobs:
IS_BETA_STANDALONE=true
fi
SHORT_SHA="${GITHUB_SHA::7}"
{
echo "build-desktop=$BUILD_DESKTOP"
echo "build-web=$BUILD_WEB"
echo "build-web-cloud=$BUILD_WEB_CLOUD"
echo "build-backend=$BUILD_BACKEND"
@@ -81,6 +91,7 @@ jobs:
echo "is-stable-standalone=$IS_STABLE_STANDALONE"
echo "is-beta-standalone=$IS_BETA_STANDALONE"
echo "sanitized-tag=$SANITIZED_TAG"
echo "short-sha=$SHORT_SHA"
} >> "$GITHUB_OUTPUT"
check-version-tag:
@@ -95,7 +106,7 @@ jobs:
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # ratchet:astral-sh/setup-uv@v7.1.4
uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7.1.5
with:
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
enable-cache: false
@@ -124,6 +135,136 @@ jobs:
title: "🚨 Version Tag Check Failed"
ref-name: ${{ github.ref_name }}
build-desktop:
needs: determine-builds
if: needs.determine-builds.outputs.build-desktop == 'true'
permissions:
contents: write
actions: read
strategy:
fail-fast: false
matrix:
include:
- platform: 'macos-latest' # Build a universal image for macOS.
args: '--target universal-apple-darwin'
- platform: 'ubuntu-24.04'
args: '--bundles deb,rpm'
- platform: 'ubuntu-24.04-arm' # Only available in public repos.
args: '--bundles deb,rpm'
- platform: 'windows-latest'
args: ''
runs-on: ${{ matrix.platform }}
timeout-minutes: 90
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6.0.1
with:
# NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
persist-credentials: true # zizmor: ignore[artipacked]
- name: install dependencies (ubuntu only)
if: startsWith(matrix.platform, 'ubuntu-')
run: |
sudo apt-get update
sudo apt-get install -y \
build-essential \
libglib2.0-dev \
libgirepository1.0-dev \
libgtk-3-dev \
libjavascriptcoregtk-4.1-dev \
libwebkit2gtk-4.1-dev \
libayatana-appindicator3-dev \
gobject-introspection \
pkg-config \
curl \
xdg-utils
- name: setup node
uses: actions/setup-node@395ad3262231945c25e8478fd5baf05154b1d79f # ratchet:actions/setup-node@v6.1.0
with:
node-version: 24
package-manager-cache: false
- name: install Rust stable
uses: dtolnay/rust-toolchain@6d9817901c499d6b02debbb57edb38d33daa680b # zizmor: ignore[impostor-commit]
with:
# Those targets are only used on macos runners so it's in an `if` to slightly speed up windows and linux builds.
targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }}
- name: install frontend dependencies
working-directory: ./desktop
run: npm install
- name: Inject version (Unix)
if: runner.os != 'Windows'
working-directory: ./desktop
env:
SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
EVENT_NAME: ${{ github.event_name }}
run: |
if [ "${EVENT_NAME}" == "workflow_dispatch" ]; then
VERSION="0.0.0-dev+${SHORT_SHA}"
else
VERSION="${GITHUB_REF_NAME#v}"
fi
echo "Injecting version: $VERSION"
# Update Cargo.toml
sed "s/^version = .*/version = \"$VERSION\"/" src-tauri/Cargo.toml > src-tauri/Cargo.toml.tmp
mv src-tauri/Cargo.toml.tmp src-tauri/Cargo.toml
# Update tauri.conf.json
jq --arg v "$VERSION" '.version = $v' src-tauri/tauri.conf.json > src-tauri/tauri.conf.json.tmp
mv src-tauri/tauri.conf.json.tmp src-tauri/tauri.conf.json
# Update package.json
jq --arg v "$VERSION" '.version = $v' package.json > package.json.tmp
mv package.json.tmp package.json
echo "Versions set to: $VERSION"
- name: Inject version (Windows)
if: runner.os == 'Windows'
working-directory: ./desktop
shell: pwsh
run: |
# Windows MSI requires numeric-only build metadata, so we skip the SHA suffix
if ("${{ github.event_name }}" -eq "workflow_dispatch") {
$VERSION = "0.0.0"
} else {
# Strip 'v' prefix and any pre-release suffix (e.g., -beta.13) for MSI compatibility
$VERSION = "$env:GITHUB_REF_NAME" -replace '^v', '' -replace '-.*$', ''
}
Write-Host "Injecting version: $VERSION"
# Update Cargo.toml
$cargo = Get-Content src-tauri/Cargo.toml -Raw
$cargo = $cargo -replace '(?m)^version = .*', "version = `"$VERSION`""
Set-Content src-tauri/Cargo.toml $cargo -NoNewline
# Update tauri.conf.json
$json = Get-Content src-tauri/tauri.conf.json | ConvertFrom-Json
$json.version = $VERSION
$json | ConvertTo-Json -Depth 100 | Set-Content src-tauri/tauri.conf.json
# Update package.json
$pkg = Get-Content package.json | ConvertFrom-Json
$pkg.version = $VERSION
$pkg | ConvertTo-Json -Depth 100 | Set-Content package.json
Write-Host "Versions set to: $VERSION"
- uses: tauri-apps/tauri-action@19b93bb55601e3e373a93cfb6eb4242e45f5af20 # ratchet:tauri-apps/tauri-action@action-v0.6.0
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tagName: ${{ github.event_name != 'workflow_dispatch' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseName: ${{ github.event_name != 'workflow_dispatch' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
releaseBody: 'See the assets to download this version and install.'
releaseDraft: true
prerelease: false
args: ${{ matrix.args }}
build-web-amd64:
needs: determine-builds
if: needs.determine-builds.outputs.build-web == 'true'
@@ -147,7 +288,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -205,7 +346,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -267,7 +408,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -313,7 +454,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -379,7 +520,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -449,7 +590,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -492,7 +633,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -549,7 +690,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -610,7 +751,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -657,7 +798,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -721,7 +862,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -788,7 +929,7 @@ jobs:
- name: Docker meta
id: meta
uses: docker/metadata-action@318604b99e75e41977312d83839a89be02ca4893 # ratchet:docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # ratchet:docker/metadata-action@v5
with:
images: ${{ github.event_name == 'workflow_dispatch' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
flavor: |
@@ -980,6 +1121,7 @@ jobs:
notify-slack-on-failure:
needs:
- build-desktop
- build-web-amd64
- build-web-arm64
- merge-web
@@ -992,7 +1134,7 @@ jobs:
- build-model-server-amd64
- build-model-server-arm64
- merge-model-server
if: always() && (needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
if: always() && (needs.build-desktop.result == 'failure' || needs.build-web-amd64.result == 'failure' || needs.build-web-arm64.result == 'failure' || needs.merge-web.result == 'failure' || needs.build-web-cloud-amd64.result == 'failure' || needs.build-web-cloud-arm64.result == 'failure' || needs.merge-web-cloud.result == 'failure' || needs.build-backend-amd64.result == 'failure' || needs.build-backend-arm64.result == 'failure' || needs.merge-backend.result == 'failure' || needs.build-model-server-amd64.result == 'failure' || needs.build-model-server-arm64.result == 'failure' || needs.merge-model-server.result == 'failure') && github.event_name != 'workflow_dispatch'
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim
timeout-minutes: 90
@@ -1007,6 +1149,9 @@ jobs:
shell: bash
run: |
FAILED_JOBS=""
if [ "${NEEDS_BUILD_DESKTOP_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-desktop\\n"
fi
if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
fi
@@ -1047,6 +1192,7 @@ jobs:
FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
env:
NEEDS_BUILD_DESKTOP_RESULT: ${{ needs.build-desktop.result }}
NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}

View File

@@ -33,6 +33,11 @@ env:
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
EXA_API_KEY: ${{ secrets.EXA_API_KEY }}
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }}
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }}
GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }}
GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }}
GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }}
jobs:
discover-test-dirs:
@@ -399,6 +404,11 @@ jobs:
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \
-e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \
-e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \
-e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \
-e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \

View File

@@ -4,7 +4,14 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read

View File

@@ -4,7 +4,14 @@ concurrency:
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
push:
tags:
- "v*.*.*"
permissions:
contents: read

View File

@@ -8,30 +8,66 @@ repos:
# From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c
rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c
hooks:
- id: uv-run
name: Check lazy imports
args: ["--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
- id: uv-sync
args: ["--locked", "--all-extras"]
args: ["--active", "--locked", "--all-extras"]
- id: uv-lock
files: ^pyproject\.toml$
- id: uv-export
name: uv-export default.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt"]
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"backend",
"-o",
"backend/requirements/default.txt",
]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export dev.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt"]
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"dev",
"-o",
"backend/requirements/dev.txt",
]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export ee.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt"]
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"ee",
"-o",
"backend/requirements/ee.txt",
]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-export
name: uv-export model_server.txt
args: ["--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt"]
args:
[
"--no-emit-project",
"--no-default-groups",
"--no-hashes",
"--extra",
"model_server",
"-o",
"backend/requirements/model_server.txt",
]
files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$
- id: uv-run
name: Check lazy imports
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
files: ^backend/(?!\.venv/).*\.py$
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
# - id: uv-run
# name: mypy
@@ -40,68 +76,73 @@ repos:
# files: ^backend/.*\.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
hooks:
- id: check-yaml
files: ^.github/
- repo: https://github.com/rhysd/actionlint
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9
hooks:
- id: actionlint
- repo: https://github.com/psf/black
rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0
hooks:
- id: black
language_version: python3.11
- id: black
language_version: python3.11
# this is a fork which keeps compatibility with black
- repo: https://github.com/wimglenn/reorder-python-imports-black
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0
hooks:
- id: reorder-python-imports
args: ['--py311-plus', '--application-directories=backend/']
# need to ignore alembic files, since reorder-python-imports gets confused
# and thinks that alembic is a local package since there is a folder
# in the backend directory called `alembic`
exclude: ^backend/alembic/
- id: reorder-python-imports
args: ["--py311-plus", "--application-directories=backend/"]
# need to ignore alembic files, since reorder-python-imports gets confused
# and thinks that alembic is a local package since there is a folder
# in the backend directory called `alembic`
exclude: ^backend/alembic/
# These settings will remove unused imports with side effects
# Note: The repo currently does not and should not have imports with side effects
- repo: https://github.com/PyCQA/autoflake
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1
hooks:
- id: autoflake
args: [ '--remove-all-unused-imports', '--remove-unused-variables', '--in-place' , '--recursive']
args:
[
"--remove-all-unused-imports",
"--remove-unused-variables",
"--in-place",
"--recursive",
]
- repo: https://github.com/golangci/golangci-lint
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
hooks:
- id: golangci-lint
entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4
hooks:
- id: ruff
- repo: https://github.com/pre-commit/mirrors-prettier
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0
hooks:
- id: prettier
types_or: [html, css, javascript, ts, tsx]
language_version: system
- id: prettier
types_or: [html, css, javascript, ts, tsx]
language_version: system
- repo: https://github.com/sirwart/ripsecrets
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11
hooks:
- id: ripsecrets
args:
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$
- --additional-pattern
- ^sk-[A-Za-z0-9_\-]{20,}$
- repo: local
hooks:
@@ -112,9 +153,13 @@ repos:
pass_filenames: false
files: \.tf$
# Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking.
# This is a preview package - if it breaks:
# 1. Try updating: cd web && npm update @typescript/native-preview
# 2. Or fallback to tsc: replace 'tsgo' with 'tsc' below
- id: typescript-check
name: TypeScript type check
entry: bash -c 'cd web && npm run types:check'
entry: bash -c 'cd web && npx tsgo --noEmit --project tsconfig.types.json'
language: system
pass_filenames: false
files: ^web/.*\.(ts|tsx)$

View File

@@ -161,7 +161,7 @@ You will need Docker installed to run these containers.
First navigate to `onyx/deployment/docker_compose`, then start up Postgres/Vespa/Redis/MinIO with:
```bash
docker compose up -d index relational_db cache minio
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
```
(index refers to Vespa, relational_db refers to Postgres, and cache refers to Redis)
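For context, passing multiple `-f` flags makes Compose merge the files in order, with later files overriding matching keys from earlier ones; a minimal sketch of the same invocation, assuming the stock `docker-compose.yml` and `docker-compose.dev.yml` in that directory:
```bash
cd onyx/deployment/docker_compose
# docker-compose.dev.yml is layered on top of docker-compose.yml,
# so any dev-only overrides in it take precedence.
docker compose -f docker-compose.yml -f docker-compose.dev.yml \
  up -d index relational_db cache minio
```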

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "23957775e5f5"
down_revision = "bc9771dccadf"
branch_labels = None # type: ignore
depends_on = None # type: ignore
branch_labels = None
depends_on = None
def upgrade() -> None:

View File

@@ -0,0 +1,27 @@
"""add last refreshed at mcp server
Revision ID: 2a391f840e85
Revises: 4cebcbc9b2ae
Create Date: 2025-12-06 15:19:59.766066
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2a391f840e85"
down_revision = "4cebcbc9b2ae"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"mcp_server",
sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("mcp_server", "last_refreshed_at")

View File

@@ -0,0 +1,27 @@
"""add tab_index to tool_call
Revision ID: 4cebcbc9b2ae
Revises: a1b2c3d4e5f6
Create Date: 2025-12-16
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4cebcbc9b2ae"
down_revision = "a1b2c3d4e5f6"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"tool_call",
sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"),
)
def downgrade() -> None:
op.drop_column("tool_call", "tab_index")

View File

@@ -42,13 +42,13 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
server_default=sa.text("now()"),
nullable=False,
),
)
@@ -63,13 +63,13 @@ def upgrade() -> None:
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
server_default=sa.text("now()"),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"), # type: ignore
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(

View File

@@ -0,0 +1,49 @@
"""add license table
Revision ID: a1b2c3d4e5f6
Revises: a01bf2971c5d
Create Date: 2025-12-04 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "a01bf2971c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"license",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("license_data", sa.Text(), nullable=False),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
nullable=False,
),
)
# Singleton pattern - only ever one row in this table
op.create_index(
"idx_license_singleton",
"license",
[sa.text("(true)")],
unique=True,
)
def downgrade() -> None:
op.drop_index("idx_license_singleton", table_name="license")
op.drop_table("license")

View File

@@ -0,0 +1,27 @@
"""Remove fast_default_model_name from llm_provider
Revision ID: a2b3c4d5e6f7
Revises: 2a391f840e85
Create Date: 2025-12-17
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a2b3c4d5e6f7"
down_revision = "2a391f840e85"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("llm_provider", "fast_default_model_name")
def downgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("fast_default_model_name", sa.String(), nullable=True),
)

View File

@@ -0,0 +1,46 @@
"""Drop milestone table
Revision ID: b8c9d0e1f2a3
Revises: a2b3c4d5e6f7
Create Date: 2025-12-18
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b8c9d0e1f2a3"
down_revision = "a2b3c4d5e6f7"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("milestone")
def downgrade() -> None:
op.create_table(
"milestone",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("tenant_id", sa.String(), nullable=True),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("event_type", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
)

View File

@@ -0,0 +1,52 @@
"""add_deep_research_tool
Revision ID: c1d2e3f4a5b6
Revises: b8c9d0e1f2a3
Create Date: 2025-12-18 16:00:00.000000
"""
from alembic import op
from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_DB_NAME
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c1d2e3f4a5b6"
down_revision = "b8c9d0e1f2a3"
branch_labels = None
depends_on = None
DEEP_RESEARCH_TOOL = {
"name": RESEARCH_AGENT_DB_NAME,
"display_name": "Research Agent",
"description": "The Research Agent is a sub-agent that conducts research on a specific topic.",
"in_code_tool_id": "ResearchAgent",
}
def upgrade() -> None:
conn = op.get_bind()
conn.execute(
sa.text(
"""
INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
VALUES (:name, :display_name, :description, :in_code_tool_id, false)
"""
),
DEEP_RESEARCH_TOOL,
)
def downgrade() -> None:
conn = op.get_bind()
conn.execute(
sa.text(
"""
DELETE FROM tool
WHERE in_code_tool_id = :in_code_tool_id
"""
),
{"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]},
)

View File

@@ -257,8 +257,8 @@ def _migrate_files_to_external_storage() -> None:
print(f"File {file_id} not found in PostgreSQL storage.")
continue
lobj_id = cast(int, file_record.lobj_oid) # type: ignore
file_metadata = cast(Any, file_record.file_metadata) # type: ignore
lobj_id = cast(int, file_record.lobj_oid)
file_metadata = cast(Any, file_record.file_metadata)
# Read file content from PostgreSQL
try:
@@ -280,7 +280,7 @@ def _migrate_files_to_external_storage() -> None:
else:
# Convert other types to dict if possible, otherwise None
try:
file_metadata = dict(file_record.file_metadata) # type: ignore
file_metadata = dict(file_record.file_metadata)
except (TypeError, ValueError):
file_metadata = None

View File

@@ -11,8 +11,8 @@ import sqlalchemy as sa
revision = "e209dc5a8156"
down_revision = "48d14957fe80"
branch_labels = None # type: ignore
depends_on = None # type: ignore
branch_labels = None
depends_on = None
def upgrade() -> None:

View File

@@ -8,7 +8,7 @@ Create Date: 2025-11-28 11:15:37.667340
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import ( # type: ignore[import-untyped]
from onyx.db.enums import (
MCPTransport,
MCPAuthenticationType,
MCPAuthenticationPerformer,

View File

@@ -82,9 +82,9 @@ def run_migrations_offline() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata, # type: ignore[arg-type]
include_object=include_object,
) # type: ignore
)
with context.begin_transaction():
context.run_migrations()

View File

@@ -118,6 +118,6 @@ def fetch_document_sets(
.all()
)
document_set_with_cc_pairs.append((document_set, cc_pairs)) # type: ignore
document_set_with_cc_pairs.append((document_set, cc_pairs))
return document_set_with_cc_pairs

View File

@@ -0,0 +1,278 @@
"""Database and cache operations for the license table."""
from datetime import datetime
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------
def get_license(db_session: Session) -> License | None:
"""
Get the current license (singleton pattern - only one row).
Args:
db_session: Database session
Returns:
License object if exists, None otherwise
"""
return db_session.execute(select(License)).scalars().first()
def upsert_license(db_session: Session, license_data: str) -> License:
"""
Insert or update the license (singleton pattern).
Args:
db_session: Database session
license_data: Base64-encoded signed license blob
Returns:
The created or updated License object
"""
existing = get_license(db_session)
if existing:
existing.license_data = license_data
db_session.commit()
db_session.refresh(existing)
logger.info("License updated")
return existing
new_license = License(license_data=license_data)
db_session.add(new_license)
db_session.commit()
db_session.refresh(new_license)
logger.info("License created")
return new_license
def delete_license(db_session: Session) -> bool:
"""
Delete the current license.
Args:
db_session: Database session
Returns:
True if deleted, False if no license existed
"""
existing = get_license(db_session)
if existing:
db_session.delete(existing)
db_session.commit()
logger.info("License deleted")
return True
return False
# -----------------------------------------------------------------------------
# Seat Counting
# -----------------------------------------------------------------------------
def get_used_seats(tenant_id: str | None = None) -> int:
"""
Get current seat usage.
For multi-tenant: counts users in UserTenantMapping for this tenant.
For self-hosted: counts all active users (includes both Onyx UI users
and Slack users who have been converted to Onyx users).
"""
if MULTI_TENANT:
from ee.onyx.server.tenants.user_mapping import get_tenant_count
return get_tenant_count(tenant_id or get_current_tenant_id())
else:
# Self-hosted: count all active users (Onyx + converted Slack users)
from onyx.db.engine.sql_engine import get_session_with_current_tenant
with get_session_with_current_tenant() as db_session:
result = db_session.execute(
select(func.count()).select_from(User).where(User.is_active) # type: ignore
)
return result.scalar() or 0
# -----------------------------------------------------------------------------
# Redis Cache Operations
# -----------------------------------------------------------------------------
def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
"""
Get license metadata from Redis cache.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if cached, None otherwise
"""
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_replica_client(tenant_id=tenant)
cached = redis_client.get(LICENSE_METADATA_KEY)
if cached:
try:
cached_str: str
if isinstance(cached, bytes):
cached_str = cached.decode("utf-8")
else:
cached_str = str(cached)
return LicenseMetadata.model_validate_json(cached_str)
except Exception as e:
logger.warning(f"Failed to parse cached license metadata: {e}")
return None
return None
def invalidate_license_cache(tenant_id: str | None = None) -> None:
"""
Invalidate the license metadata cache (not the license itself).
This deletes the cached LicenseMetadata from Redis. The actual license
in the database is not affected. Redis delete is idempotent - if the
key doesn't exist, this is a no-op.
Args:
tenant_id: Tenant ID (for multi-tenant deployments)
"""
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
redis_client.delete(LICENSE_METADATA_KEY)
logger.info("License cache invalidated")
def update_license_cache(
payload: LicensePayload,
source: LicenseSource | None = None,
grace_period_end: datetime | None = None,
tenant_id: str | None = None,
) -> LicenseMetadata:
"""
Update the Redis cache with license metadata.
We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
1. Frontend needs status to show appropriate UI/banners
2. Caching avoids repeated DB + crypto verification on every request
3. Status enforcement happens at the feature level, not here
Args:
payload: Verified license payload
source: How the license was obtained
grace_period_end: Optional grace period end time
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
The cached LicenseMetadata
"""
from ee.onyx.utils.license import get_license_status
tenant = tenant_id or get_current_tenant_id()
redis_client = get_redis_client(tenant_id=tenant)
used_seats = get_used_seats(tenant)
status = get_license_status(payload, grace_period_end)
metadata = LicenseMetadata(
tenant_id=payload.tenant_id,
organization_name=payload.organization_name,
seats=payload.seats,
used_seats=used_seats,
plan_type=payload.plan_type,
issued_at=payload.issued_at,
expires_at=payload.expires_at,
grace_period_end=grace_period_end,
status=status,
source=source,
stripe_subscription_id=payload.stripe_subscription_id,
)
redis_client.setex(
LICENSE_METADATA_KEY,
LICENSE_CACHE_TTL_SECONDS,
metadata.model_dump_json(),
)
logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
return metadata
def refresh_license_cache(
db_session: Session,
tenant_id: str | None = None,
) -> LicenseMetadata | None:
"""
Refresh the license cache from the database.
Args:
db_session: Database session
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if license exists, None otherwise
"""
from ee.onyx.utils.license import verify_license_signature
license_record = get_license(db_session)
if not license_record:
invalidate_license_cache(tenant_id)
return None
try:
payload = verify_license_signature(license_record.license_data)
return update_license_cache(
payload,
source=LicenseSource.AUTO_FETCH,
tenant_id=tenant_id,
)
except ValueError as e:
logger.error(f"Failed to verify license during cache refresh: {e}")
invalidate_license_cache(tenant_id)
return None
def get_license_metadata(
db_session: Session,
tenant_id: str | None = None,
) -> LicenseMetadata | None:
"""
Get license metadata, using cache if available.
Args:
db_session: Database session
tenant_id: Tenant ID (for multi-tenant deployments)
Returns:
LicenseMetadata if license exists, None otherwise
"""
# Try cache first
cached = get_cached_license_metadata(tenant_id)
if cached:
return cached
# Refresh from database
return refresh_license_cache(db_session, tenant_id)
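Not part of the diff: a minimal usage sketch of the read-through cache above. The session factory import path is an assumption; everything else mirrors the functions defined in this file.
# Illustrative only; assumed session-factory import path.
from onyx.db.engine.sql_engine import get_session_with_current_tenant  # assumed helper
from ee.onyx.db.license import (
    get_license_metadata,
    invalidate_license_cache,
    refresh_license_cache,
)
def example_license_lookup() -> None:
    with get_session_with_current_tenant() as db_session:
        # Cache hit returns immediately; a miss falls through to the DB + signature check.
        metadata = get_license_metadata(db_session)
        if metadata is None:
            print("No license installed")
            return
        print(f"{metadata.seats} seats, status={metadata.status}")
        # Force re-verification from the DB (e.g. after a manual DB edit)...
        refresh_license_cache(db_session)
        # ...or drop the cached copy so the next read re-populates it.
        invalidate_license_cache()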

View File

@@ -14,6 +14,7 @@ from ee.onyx.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import (
add_api_server_tenant_id_middleware,
@@ -139,6 +140,8 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, enterprise_settings_router)
include_router_with_global_prefix_prepended(application, usage_export_router)
# License management
include_router_with_global_prefix_prepended(application, license_router)
if MULTI_TENANT:
# Tenant management

View File

@@ -0,0 +1,246 @@
"""License API endpoints."""
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import HTTPException
from fastapi import UploadFile
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_admin_user
from ee.onyx.db.license import delete_license as db_delete_license
from ee.onyx.db.license import get_license_metadata
from ee.onyx.db.license import invalidate_license_cache
from ee.onyx.db.license import refresh_license_cache
from ee.onyx.db.license import update_license_cache
from ee.onyx.db.license import upsert_license
from ee.onyx.server.license.models import LicenseResponse
from ee.onyx.server.license.models import LicenseSource
from ee.onyx.server.license.models import LicenseStatusResponse
from ee.onyx.server.license.models import LicenseUploadResponse
from ee.onyx.server.license.models import SeatUsageResponse
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.utils.license import verify_license_signature
from onyx.auth.users import User
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.db.engine.sql_engine import get_session
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/license")
@router.get("")
async def get_license_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""Get current license status and seat usage."""
metadata = get_license_metadata(db_session)
if not metadata:
return LicenseStatusResponse(has_license=False)
return LicenseStatusResponse(
has_license=True,
seats=metadata.seats,
used_seats=metadata.used_seats,
plan_type=metadata.plan_type,
issued_at=metadata.issued_at,
expires_at=metadata.expires_at,
grace_period_end=metadata.grace_period_end,
status=metadata.status,
source=metadata.source,
)
@router.get("/seats")
async def get_seat_usage(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> SeatUsageResponse:
"""Get detailed seat usage information."""
metadata = get_license_metadata(db_session)
if not metadata:
return SeatUsageResponse(
total_seats=0,
used_seats=0,
available_seats=0,
)
return SeatUsageResponse(
total_seats=metadata.seats,
used_seats=metadata.used_seats,
available_seats=max(0, metadata.seats - metadata.used_seats),
)
@router.post("/fetch")
async def fetch_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseResponse:
"""
Fetch license from control plane.
Used after Stripe checkout completion to retrieve the new license.
"""
tenant_id = get_current_tenant_id()
try:
token = generate_data_plane_token()
except ValueError as e:
logger.error(f"Failed to generate data plane token: {e}")
raise HTTPException(
status_code=500, detail="Authentication configuration error"
)
try:
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/license/{tenant_id}"
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
data = response.json()
if not isinstance(data, dict) or "license" not in data:
raise HTTPException(
status_code=502, detail="Invalid response from control plane"
)
license_data = data["license"]
if not license_data:
raise HTTPException(status_code=404, detail="No license found")
# Verify signature before persisting
payload = verify_license_signature(license_data)
# Verify the fetched license is for this tenant
if payload.tenant_id != tenant_id:
logger.error(
f"License tenant mismatch: expected {tenant_id}, got {payload.tenant_id}"
)
raise HTTPException(
status_code=400,
detail="License tenant ID mismatch - control plane returned wrong license",
)
# Persist to DB, then update the cache (best-effort - a cache failure is tolerated below)
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.AUTO_FETCH)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseResponse(success=True, license=payload)
except requests.HTTPError as e:
status_code = e.response.status_code if e.response is not None else 502
logger.error(f"Control plane returned error: {status_code}")
raise HTTPException(
status_code=status_code,
detail="Failed to fetch license from control plane",
)
except ValueError as e:
logger.error(f"License verification failed: {type(e).__name__}")
raise HTTPException(status_code=400, detail=str(e))
except requests.RequestException:
logger.exception("Failed to fetch license from control plane")
raise HTTPException(
status_code=502, detail="Failed to connect to control plane"
)
@router.post("/upload")
async def upload_license(
license_file: UploadFile = File(...),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseUploadResponse:
"""
Upload a license file manually.
Used for air-gapped deployments where control plane is not accessible.
"""
try:
content = await license_file.read()
license_data = content.decode("utf-8").strip()
except UnicodeDecodeError:
raise HTTPException(status_code=400, detail="Invalid license file format")
try:
payload = verify_license_signature(license_data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
tenant_id = get_current_tenant_id()
if payload.tenant_id != tenant_id:
raise HTTPException(
status_code=400,
detail=f"License tenant ID mismatch. Expected {tenant_id}, got {payload.tenant_id}",
)
# Persist to DB and update cache
upsert_license(db_session, license_data)
try:
update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD)
except Exception as cache_error:
# Log but don't fail - DB is source of truth, cache will refresh on next read
logger.warning(f"Failed to update license cache: {cache_error}")
return LicenseUploadResponse(
success=True,
message=f"License uploaded successfully. {payload.seats} seats, expires {payload.expires_at.date()}",
)
@router.post("/refresh")
async def refresh_license_cache_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> LicenseStatusResponse:
"""
Force refresh the license cache from the database.
Useful after manual database changes or to verify license validity.
"""
metadata = refresh_license_cache(db_session)
if not metadata:
return LicenseStatusResponse(has_license=False)
return LicenseStatusResponse(
has_license=True,
seats=metadata.seats,
used_seats=metadata.used_seats,
plan_type=metadata.plan_type,
issued_at=metadata.issued_at,
expires_at=metadata.expires_at,
grace_period_end=metadata.grace_period_end,
status=metadata.status,
source=metadata.source,
)
@router.delete("")
async def delete_license(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> dict[str, bool]:
"""
Delete the current license.
Admin only - removes license and invalidates cache.
"""
# Invalidate cache first - if DB delete fails, stale cache is worse than no cache
try:
invalidate_license_cache()
except Exception as cache_error:
logger.warning(f"Failed to invalidate license cache: {cache_error}")
deleted = db_delete_license(db_session)
return {"deleted": deleted}

View File

@@ -0,0 +1,92 @@
from datetime import datetime
from enum import Enum
from pydantic import BaseModel
from onyx.server.settings.models import ApplicationStatus
class PlanType(str, Enum):
MONTHLY = "monthly"
ANNUAL = "annual"
class LicenseSource(str, Enum):
AUTO_FETCH = "auto_fetch"
MANUAL_UPLOAD = "manual_upload"
class LicensePayload(BaseModel):
"""The payload portion of a signed license."""
version: str
tenant_id: str
organization_name: str | None = None
issued_at: datetime
expires_at: datetime
seats: int
plan_type: PlanType
billing_cycle: str | None = None
grace_period_days: int = 30
stripe_subscription_id: str | None = None
stripe_customer_id: str | None = None
class LicenseData(BaseModel):
"""Full signed license structure."""
payload: LicensePayload
signature: str
class LicenseMetadata(BaseModel):
"""Cached license metadata stored in Redis."""
tenant_id: str
organization_name: str | None = None
seats: int
used_seats: int
plan_type: PlanType
issued_at: datetime
expires_at: datetime
grace_period_end: datetime | None = None
status: ApplicationStatus
source: LicenseSource | None = None
stripe_subscription_id: str | None = None
class LicenseStatusResponse(BaseModel):
"""Response for license status API."""
has_license: bool
seats: int = 0
used_seats: int = 0
plan_type: PlanType | None = None
issued_at: datetime | None = None
expires_at: datetime | None = None
grace_period_end: datetime | None = None
status: ApplicationStatus | None = None
source: LicenseSource | None = None
class LicenseResponse(BaseModel):
"""Response after license fetch/upload."""
success: bool
message: str | None = None
license: LicensePayload | None = None
class LicenseUploadResponse(BaseModel):
"""Response after license upload."""
success: bool
message: str | None = None
class SeatUsageResponse(BaseModel):
"""Response for seat usage API."""
total_seats: int
used_seats: int
available_seats: int
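Not part of the diff: a small sketch showing how the models above round-trip with Pydantic v2; the field values are made up.
# Illustrative values only.
from datetime import datetime, timedelta, timezone
payload = LicensePayload(
    version="1",
    tenant_id="tenant_abc",
    issued_at=datetime.now(timezone.utc),
    expires_at=datetime.now(timezone.utc) + timedelta(days=365),
    seats=50,
    plan_type=PlanType.ANNUAL,
)
signed = LicenseData(payload=payload, signature="<base64-signature>")
# model_dump_json()/model_validate_json() are what the Redis cache layer relies on.
raw = signed.model_dump_json()
assert LicenseData.model_validate_json(raw).payload.seats == 50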

View File

@@ -20,7 +20,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_llm_for_persona
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.utils.logger import setup_logger
@@ -100,7 +100,6 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
@@ -158,7 +157,7 @@ def handle_send_message_simple_with_history(
persona_id=req.persona_id,
)
llm, _ = get_llms_for_persona(persona=chat_session.persona, user=user)
llm = get_llm_for_persona(persona=chat_session.persona, user=user)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -205,7 +204,6 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(

View File

@@ -54,9 +54,6 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def validate_chat_session_or_persona(self) -> "BasicCreateChatMessageRequest":
if self.chat_session_id is None and self.persona_id is None:
@@ -76,8 +73,6 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):

View File

@@ -45,7 +45,7 @@ from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRe
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import mt_cloud_telemetry
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
@@ -269,7 +269,6 @@ def configure_default_api_keys(db_session: Session) -> None:
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_configurations=[
ModelConfigurationUpsertRequest(
name=name,
@@ -296,7 +295,6 @@ def configure_default_api_keys(db_session: Session) -> None:
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,
@@ -562,17 +560,11 @@ async def assign_tenant_to_user(
try:
add_users_to_tenant([email], tenant_id)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=email,
event=MilestoneRecordType.TENANT_CREATED,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")

View File

@@ -249,6 +249,17 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
)
raise
# Remove from invited users list since they've accepted
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
invited_users = get_invited_users()
if email in invited_users:
invited_users.remove(email)
write_invited_users(invited_users)
logger.info(f"Removed {email} from invited users list after acceptance")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
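Aside, not part of the diff: the added hunk above sets the tenant contextvar with a token and restores it in a finally block so the surrounding request keeps its original tenant. A standalone sketch of that token pattern (the contextvar here is illustrative, not the real one from shared_configs):
# Illustrative set/reset-token pattern.
from contextvars import ContextVar
CURRENT_TENANT_ID_CONTEXTVAR: ContextVar[str | None] = ContextVar("current_tenant_id", default=None)
def run_in_tenant(tenant_id: str) -> None:
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        # Anything called here (e.g. get_invited_users/write_invited_users above)
        # resolves the tenant from the contextvar.
        ...
    finally:
        # Always restore the previous value, even if the body raises.
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)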
def deny_user_invite(email: str, tenant_id: str) -> None:
"""

View File

@@ -0,0 +1,126 @@
"""RSA-4096 license signature verification utilities."""
import base64
import json
import os
from datetime import datetime
from datetime import timezone
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
from ee.onyx.server.license.models import LicenseData
from ee.onyx.server.license.models import LicensePayload
from onyx.server.settings.models import ApplicationStatus
from onyx.utils.logger import setup_logger
logger = setup_logger()
# RSA-4096 Public Key for license verification
# Load from environment variable - key is generated on the control plane
# In production, inject via Kubernetes secrets or secrets manager
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
def _get_public_key() -> RSAPublicKey:
"""Load the public key from environment variable."""
if not LICENSE_PUBLIC_KEY_PEM:
raise ValueError(
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
"License verification requires the control plane public key."
)
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
if not isinstance(key, RSAPublicKey):
raise ValueError("Expected RSA public key")
return key
def verify_license_signature(license_data: str) -> LicensePayload:
"""
Verify RSA-4096 signature and return payload if valid.
Args:
license_data: Base64-encoded JSON containing payload and signature
Returns:
LicensePayload if signature is valid
Raises:
ValueError: If license data is invalid or signature verification fails
"""
try:
# Decode the license data
decoded = json.loads(base64.b64decode(license_data))
license_obj = LicenseData(**decoded)
payload_json = json.dumps(
license_obj.payload.model_dump(mode="json"), sort_keys=True
)
signature_bytes = base64.b64decode(license_obj.signature)
# Verify signature using PSS padding (modern standard)
public_key = _get_public_key()
public_key.verify(
signature_bytes,
payload_json.encode(),
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
return license_obj.payload
except InvalidSignature:
logger.error("License signature verification failed")
raise ValueError("Invalid license signature")
except json.JSONDecodeError:
logger.error("Failed to decode license JSON")
raise ValueError("Invalid license format: not valid JSON")
except (ValueError, KeyError, TypeError) as e:
logger.error(f"License data validation error: {type(e).__name__}")
raise ValueError(f"Invalid license format: {type(e).__name__}")
except Exception:
logger.exception("Unexpected error during license verification")
raise ValueError("License verification failed: unexpected error")
def get_license_status(
payload: LicensePayload,
grace_period_end: datetime | None = None,
) -> ApplicationStatus:
"""
Determine current license status based on expiry.
Args:
payload: The verified license payload
grace_period_end: Optional grace period end datetime
Returns:
ApplicationStatus indicating current license state
"""
now = datetime.now(timezone.utc)
# Check if grace period has expired
if grace_period_end and now > grace_period_end:
return ApplicationStatus.GATED_ACCESS
# Check if license has expired
if now > payload.expires_at:
if grace_period_end and now <= grace_period_end:
return ApplicationStatus.GRACE_PERIOD
return ApplicationStatus.GATED_ACCESS
# License is valid
return ApplicationStatus.ACTIVE
def is_license_valid(payload: LicensePayload) -> bool:
"""Check if a license is currently valid (not expired)."""
now = datetime.now(timezone.utc)
return now <= payload.expires_at
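For context, not part of the diff: a hypothetical control-plane-side signer sketch, inferred from verify_license_signature() above. It signs the same canonical JSON the verifier reconstructs (model_dump(mode="json"), sort_keys=True) with RSA-4096 PSS/SHA-256 and base64-encodes the {payload, signature} envelope.
# Hypothetical signer sketch; mirrors the verification logic above.
import base64
import json
from datetime import datetime, timedelta, timezone
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from ee.onyx.server.license.models import LicensePayload, PlanType
private_key = rsa.generate_private_key(public_exponent=65537, key_size=4096)
public_pem = private_key.public_key().public_bytes(
    serialization.Encoding.PEM, serialization.PublicFormat.SubjectPublicKeyInfo
).decode()
# LICENSE_PUBLIC_KEY_PEM on the data plane would need to hold public_pem
# (set before ee.onyx.utils.license is imported, since it is read at import time).
payload = LicensePayload(
    version="1",
    tenant_id="tenant_abc",  # illustrative values
    issued_at=datetime.now(timezone.utc),
    expires_at=datetime.now(timezone.utc) + timedelta(days=365),
    seats=50,
    plan_type=PlanType.ANNUAL,
)
# Sign exactly the JSON form the verifier rebuilds: model_dump(mode="json"), sort_keys=True.
payload_json = json.dumps(payload.model_dump(mode="json"), sort_keys=True)
signature = private_key.sign(
    payload_json.encode(),
    padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH),
    hashes.SHA256(),
)
license_data = base64.b64encode(
    json.dumps(
        {
            "payload": payload.model_dump(mode="json"),
            "signature": base64.b64encode(signature).decode(),
        }
    ).encode()
).decode()
# verify_license_signature(license_data) should then return an equivalent LicensePayload,
# and get_license_status(payload) would report ACTIVE until expires_at passes.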

View File

@@ -6,7 +6,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from huggingface_hub import snapshot_download
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
@@ -36,8 +36,8 @@ from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
if TYPE_CHECKING:
from setfit import SetFitModel # type: ignore
from transformers import PreTrainedTokenizer, BatchEncoding # type: ignore
from setfit import SetFitModel # type: ignore[import-untyped]
from transformers import PreTrainedTokenizer, BatchEncoding
logger = setup_logger()

View File

@@ -42,7 +42,7 @@ def get_embedding_model(
Loads or returns a cached SentenceTransformer, sets max_seq_length, pins device,
pre-warms rotary caches once, and wraps encode() with a lock to avoid cache races.
"""
from sentence_transformers import SentenceTransformer # type: ignore
from sentence_transformers import SentenceTransformer
def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None:
"""
@@ -91,7 +91,7 @@ def get_local_reranking_model(
model_name: str,
) -> "CrossEncoder":
global _RERANK_MODEL
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import CrossEncoder
if _RERANK_MODEL is None:
logger.notice(f"Loading {model_name}")
@@ -195,7 +195,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
)

View File

@@ -12,7 +12,7 @@ from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from transformers import logging as transformer_logging
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model

View File

@@ -8,7 +8,7 @@ import torch.nn as nn
if TYPE_CHECKING:
from transformers import DistilBertConfig # type: ignore
from transformers import DistilBertConfig
class HybridClassifier(nn.Module):
@@ -34,7 +34,7 @@ class HybridClassifier(nn.Module):
query_ids: torch.Tensor,
query_mask: torch.Tensor,
) -> dict[str, torch.Tensor]:
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask) # type: ignore
outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
sequence_output = outputs.last_hidden_state
# Intent classification on the CLS token
@@ -102,7 +102,7 @@ class ConnectorClassifier(nn.Module):
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.distilbert( # type: ignore
hidden_states = self.distilbert(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state

View File

@@ -43,7 +43,7 @@ def get_access_for_document(
versioned_get_access_for_document_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_access_for_document"
)
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
return versioned_get_access_for_document_fn(document_id, db_session)
def get_null_document_access() -> DocumentAccess:
@@ -93,9 +93,7 @@ def get_access_for_documents(
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(
document_ids, db_session
) # type: ignore
return versioned_get_access_for_documents_fn(document_ids, db_session)
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
@@ -113,7 +111,7 @@ def get_acl_for_user(user: User | None, db_session: Session | None = None) -> se
versioned_acl_for_user_fn = fetch_versioned_implementation(
"onyx.access.access", "_get_acl_for_user"
)
return versioned_acl_for_user_fn(user, db_session) # type: ignore
return versioned_acl_for_user_fn(user, db_session)
def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:

View File

@@ -117,7 +117,7 @@ from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.timing import log_function_time
@@ -338,9 +338,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_created = False
try:
user = await super().create(
user_create, safe=safe, request=request
) # type: ignore
user = await super().create(user_create, safe=safe, request=request)
user_created = True
except IntegrityError as error:
# Race condition: another request created the same user after the
@@ -604,10 +602,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
if (
user.oidc_expiry is not None # type: ignore
and not TRACK_EXTERNAL_IDP_EXPIRY
):
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
remove_user_from_invited_users(user.email)
@@ -653,19 +648,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
event_type = (
MilestoneRecordType.USER_SIGNED_UP
if user_count == 1
else MilestoneRecordType.MULTIPLE_USERS
)
create_milestone_and_report(
user=user,
distinct_id=user.email,
event_type=event_type,
properties=None,
db_session=db_session,
)
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=user.email,
event=MilestoneRecordType.USER_SIGNED_UP,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -1186,7 +1173,7 @@ async def _sync_jwt_oidc_expiry(
return
await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
user.oidc_expiry = oidc_expiry # type: ignore
user.oidc_expiry = oidc_expiry
return
if user.oidc_expiry is not None:

View File

@@ -0,0 +1,135 @@
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
index_attempt_id = None
try:
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# Use higher priority for first-time indexing to ensure new connectors
# get processed before re-indexing of existing connectors
has_successful_attempt = cc_pair.last_successful_index_time is not None
priority = (
OnyxCeleryPriority.MEDIUM
if has_successful_attempt
else OnyxCeleryPriority.HIGH
)
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=queue,
task_id=custom_task_id,
priority=priority,
)
if not result:
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id
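Not part of the diff: a trimmed-down sketch of the coordination pattern above. A short-lived Redis lock serializes concurrent triggers, while the database call (IndexingCoordination.try_create_index_attempt) is the gate that actually prevents duplicate attempts; the names below are illustrative.
# Illustrative "Redis lock + DB gate" sketch.
from collections.abc import Callable
from redis import Redis
from redis.lock import Lock as RedisLock
LOCK_TIMEOUT = 30
def create_once(r: Redis, create_in_db: Callable[[], int | None]) -> int | None:
    """Serialize concurrent triggers; the DB call is what actually prevents duplicates."""
    lock: RedisLock = r.lock("example:try_creating_indexing_task", timeout=LOCK_TIMEOUT)
    if not lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2):
        return None  # another trigger is in flight; let it win
    try:
        # e.g. IndexingCoordination.try_create_index_attempt(...) above; returns None
        # when an attempt is already running for this cc_pair/search_settings.
        return create_in_db()
    finally:
        if lock.owned():
            lock.release()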

View File

@@ -25,14 +25,14 @@ from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.docfetching.task_creation_utils import (
try_creating_docfetching_task,
)
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat
from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback
from onyx.background.celery.tasks.docprocessing.utils import is_in_repeated_error_state
from onyx.background.celery.tasks.docprocessing.utils import should_index
from onyx.background.celery.tasks.docprocessing.utils import (
try_creating_docfetching_task,
)
from onyx.background.celery.tasks.models import DocProcessingContext
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
from onyx.background.indexing.checkpointing_utils import (
@@ -45,6 +45,7 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -108,6 +109,7 @@ from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -547,6 +549,12 @@ def check_indexing_completion(
)
db_session.commit()
mt_cloud_telemetry(
tenant_id=tenant_id,
distinct_id=tenant_id,
event=MilestoneRecordType.CONNECTOR_SUCCEEDED,
)
# Clear repeated error state on success
if cc_pair.in_repeated_error_state:
cc_pair.in_repeated_error_state = False
@@ -1404,8 +1412,13 @@ def _docprocessing_task(
)
# Process documents through indexing pipeline
connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
task_logger.info(
f"Processing {len(documents)} documents through indexing pipeline"
f"Processing {len(documents)} documents through indexing pipeline: "
f"cc_pair_id={cc_pair_id}, source={connector_source}, "
f"batch_num={batch_num}"
)
adapter = DocumentIndexingBatchAdapter(
@@ -1495,6 +1508,8 @@ def _docprocessing_task(
# FIX: Explicitly clear document batch from memory and force garbage collection
# This helps prevent memory accumulation across multiple batches
# NOTE: Thread-local event loops in embedding threads are cleaned up automatically
# via the _cleanup_thread_local decorator in search_nlp_models.py
del documents
gc.collect()
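The two lines above are the whole pattern: drop the only strong reference to the batch, then collect. A tiny standalone sketch of the same idea in a batch loop (names are illustrative, not from the diff):
# Illustrative only: release each batch before fetching the next to keep peak memory flat.
import gc
from collections.abc import Callable, Iterator
def process_batches(
    fetch_batches: Callable[[], Iterator[list]],
    handle_batch: Callable[[list], None],
) -> None:
    for documents in fetch_batches():
        handle_batch(documents)
        # Drop the only strong reference to the (potentially large) batch,
        # then collect so memory doesn't accumulate across batches.
        del documents
        gc.collect()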

View File

@@ -1,22 +1,15 @@
import time
from datetime import datetime
from datetime import timezone
from uuid import uuid4
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -24,8 +17,6 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
from onyx.db.index_attempt import get_recent_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -298,112 +289,3 @@ def should_index(
return False
return True
def try_creating_docfetching_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
Now uses database-based coordination instead of Redis fencing.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
index_attempt_id = None
try:
# Basic status checks
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# Generate custom task ID for tracking
custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}"
# Try to create a new index attempt using database coordination
# This replaces the Redis fencing mechanism
index_attempt_id = IndexingCoordination.try_create_index_attempt(
db_session=db_session,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
celery_task_id=custom_task_id,
from_beginning=reindex,
)
if index_attempt_id is None:
# Another indexing attempt is already running
return None
# Determine which queue to use based on whether this is a user file
# TODO: at the moment the indexing pipeline is
# shared between user files and connectors
queue = (
OnyxCeleryQueues.USER_FILES_INDEXING
if cc_pair.is_user_file
else OnyxCeleryQueues.CONNECTOR_DOC_FETCHING
)
# Send the task to Celery
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=queue,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_doc_fetching_task failed.")
task_logger.info(
f"Created docfetching task: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id} "
f"attempt_id={index_attempt_id} "
f"celery_task_id={custom_task_id}"
)
return index_attempt_id
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
# Clean up on failure
if index_attempt_id is not None:
mark_attempt_failed(index_attempt_id, db_session)
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id

View File

@@ -127,12 +127,6 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
f"available, setting to first model in list: {available_models[0]}"
)
default_provider.default_model_name = available_models[0]
if default_provider.fast_default_model_name not in available_models:
task_logger.info(
f"Fast default model {default_provider.fast_default_model_name} "
f"not available, setting to first model in list: {available_models[0]}"
)
default_provider.fast_default_model_name = available_models[0]
db_session.commit()
if added_models or removed_models:

View File

@@ -55,8 +55,8 @@ class RetryDocumentIndex:
chunk_count: int | None,
fields: VespaDocumentFields | None,
user_fields: VespaDocumentUserFields | None,
) -> int:
return self.index.update_single(
) -> None:
self.index.update_single(
doc_id,
tenant_id=tenant_id,
chunk_count=chunk_count,

View File

@@ -95,7 +95,6 @@ def document_by_cc_pair_cleanup_task(
try:
with get_session_with_current_tenant() as db_session:
action = "skip"
chunks_affected = 0
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
@@ -114,7 +113,7 @@ def document_by_cc_pair_cleanup_task(
chunk_count = fetch_chunk_count_for_document(document_id, db_session)
chunks_affected = retry_index.delete_single(
_ = retry_index.delete_single(
document_id,
tenant_id=tenant_id,
chunk_count=chunk_count,
@@ -157,7 +156,7 @@ def document_by_cc_pair_cleanup_task(
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
@@ -187,7 +186,6 @@ def document_by_cc_pair_cleanup_task(
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:

View File

@@ -597,7 +597,7 @@ def process_single_user_file_project_sync(
return None
project_ids = [project.id for project in user_file.projects]
chunks_affected = retry_index.update_single(
retry_index.update_single(
doc_id=str(user_file.id),
tenant_id=tenant_id,
chunk_count=user_file.chunk_count,
@@ -606,7 +606,7 @@ def process_single_user_file_project_sync(
)
task_logger.info(
f"process_single_user_file_project_sync - Chunks affected id={user_file_id} chunks={chunks_affected}"
f"process_single_user_file_project_sync - User file id={user_file_id}"
)
user_file.needs_project_sync = False
@@ -874,7 +874,10 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
)
# Now update Vespa chunks with the found chunk count using retry_index
updated_chunks = retry_index.update_single(
# WARNING: In the future this will error; we no longer want
# to support changing document ID.
# TODO(andrei): Delete soon.
retry_index.update_single(
doc_id=str(normalized_doc_id),
tenant_id=tenant_id,
chunk_count=chunk_count,
@@ -883,7 +886,7 @@ def user_file_docid_migration_task(self: Task, *, tenant_id: str) -> bool:
user_projects=user_project_ids
),
)
user_file.chunk_count = updated_chunks
user_file.chunk_count = chunk_count
# Update the SearchDocs
actual_doc_id = str(user_file.document_id)

View File

@@ -501,7 +501,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
@@ -515,10 +515,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
f"doc={document_id} " f"action=sync " f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except SoftTimeLimitExceeded:

View File

@@ -1,7 +1,6 @@
import sys
import time
import traceback
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -21,7 +20,6 @@ from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -32,11 +30,8 @@ from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocExtractionContext
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
@@ -49,34 +44,16 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.document_index.factory import get_default_document_index
from onyx.file_store.document_batch_storage import DocumentBatchStorage
from onyx.file_store.document_batch_storage import get_document_batch_storage
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.adapters.document_indexing_adapter import (
DocumentIndexingBatchAdapter,
)
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
@@ -272,583 +249,6 @@ def _check_failure_threshold(
)
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >2 months, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
from onyx.indexing.indexing_pipeline import run_indexing_pipeline
start_time = time.monotonic()  # just used for logging
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(
db_session_temp,
index_attempt_id,
eager_load_cc_pair=True,
eager_load_search_settings=True,
)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
)
if index_attempt_start.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
db_connector = index_attempt_start.connector_credential_pair.connector
db_credential = index_attempt_start.connector_credential_pair.credential
is_primary = (
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
)
from_beginning = index_attempt_start.from_beginning
has_successful_attempt = (
index_attempt_start.connector_credential_pair.last_successful_index_time
is not None
)
ctx = DocExtractionContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
credential_id=db_credential.id,
source=db_connector.source,
earliest_index_time=(
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
),
from_beginning=from_beginning,
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary=is_primary,
should_fetch_permissions_during_indexing=(
index_attempt_start.connector_credential_pair.access_type
== AccessType.SYNC
and source_should_fetch_permissions_during_indexing(db_connector.source)
and is_primary
# if we've already successfully indexed, let the doc_sync job
# take care of doc-level permissions
and (from_beginning or not has_successful_attempt)
),
search_settings_status=index_attempt_start.search_settings.status,
doc_extraction_complete_batch_num=None,
)
last_successful_index_poll_range_end = (
ctx.earliest_index_time
if ctx.from_beginning
else get_last_successful_attempt_poll_range_end(
cc_pair_id=ctx.cc_pair_id,
earliest_index=ctx.earliest_index_time,
search_settings=index_attempt_start.search_settings,
db_session=db_session_temp,
)
)
if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_poll_range_end, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt_start.search_settings_id,
db_session=db_session_temp,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# add start/end now that they have been set
index_attempt_start.poll_range_start = window_start
index_attempt_start.poll_range_end = window_end
db_session_temp.add(index_attempt_start)
db_session_temp.commit()
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
callback=callback,
)
information_content_classification_model = InformationContentClassificationModel()
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
httpx_client=HttpxPool.get("vespa"),
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt_md = IndexAttemptMetadata(
attempt_id=index_attempt_id,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
total_failures = 0
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(
db_session_temp, index_attempt_id, eager_load_cc_pair=True
)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
include_permissions=ctx.should_fetch_permissions_during_indexing,
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling
# OR
# if the last attempt was successful
if index_attempt.from_beginning or (
most_recent_attempt and most_recent_attempt.status.is_successful()
):
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint, _ = get_latest_valid_checkpoint(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
# save the initial checkpoint to have a proper record of the
# "last used checkpoint"
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
db_session=db_session_temp,
)
doc_id_to_unresolved_errors: dict[str, list[IndexAttemptError]] = (
defaultdict(list)
)
for error in unresolved_errors:
if error.document_id:
doc_id_to_unresolved_errors[error.document_id].append(error)
entity_based_unresolved_errors = [
error for error in unresolved_errors if error.entity_id
]
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if the connector was disabled mid-run and stop if so, unless it's the secondary
# index being built. We want to populate the secondary index even for paused connectors;
# paused connectors are often sources that aren't updated frequently, but their
# contents still need to be pulled initially.
if callback:
if callback.should_stop():
raise ConnectorStopSignal("Connector stop signal detected")
# NOTE: this progress callback runs on every loop. We've seen cases
# where we loop many times with no new documents and eventually time
# out, so only doing the callback after indexing isn't sufficient.
callback.progress("_run_indexing", 0)
# TODO: should we move this into the above callback instead?
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp,
ctx.cc_pair_id,
ctx.search_settings_status,
index_attempt_id,
)
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# save the new checkpoint (if one is provided)
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing logic, so if no batch we can just continue
if document_batch is None:
continue
batch_description = []
# Generate an ID that can be used to correlate activity between here
# and the embedding model server
doc_batch_cleaned = strip_null_characters(document_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.request_id = make_randomized_onyx_request_id("CIX")
index_attempt_md.structured_id = (
f"{tenant_id}:{ctx.cc_pair_id}:{index_attempt_id}:{batch_num}"
)
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
adapter = DocumentIndexingBatchAdapter(
db_session=db_session,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
tenant_id=tenant_id,
index_attempt_metadata=index_attempt_md,
)
index_pipeline_result = run_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
document_batch=doc_batch_cleaned,
request_id=index_attempt_md.request_id,
adapter=adapter,
)
batch_num += 1
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# resolve errors for documents that were successfully indexed
failed_document_ids = [
failure.failed_document.document_id
for failure in index_pipeline_result.failures
if failure.failed_document
]
successful_document_ids = [
document.id
for document in document_batch
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_current_tenant() as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
)
for error in doc_id_to_unresolved_errors[document_id]:
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_current_tenant() as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures,
document_count,
batch_num,
index_pipeline_result.failures[-1],
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_current_tenant() as db_session_temp:
# NOTE: Postgres uses the start of the transaction when computing `NOW()`,
# so we need to either commit() or use a new session
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"current_docs_indexed": document_count,
"current_chunks_indexed": chunk_count,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# make sure the checkpoints aren't getting too large at some regular interval
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_current_tenant() as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
)
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to validation error."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if ctx.is_primary:
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
logger.exception(
f"Marking attempt {index_attempt_id} as canceled due to stop signal."
)
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
else:
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
memory_tracer.stop()
# we know the index attempt is successful (at least partially) at this point,
# all other cases have been short-circuited
elapsed_time = time.monotonic() - start_time
with get_session_with_current_tenant() as db_session_temp:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
if total_failures == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session_temp,
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"failures={total_failures} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=window_end,
)
if ctx.should_fetch_permissions_during_indexing:
mark_cc_pair_as_permissions_synced(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
start_time=window_end,
)
def run_docfetching_entrypoint(
app: Celery,
index_attempt_id: int,
@@ -968,11 +368,19 @@ def connector_document_extraction(
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
from_beginning = index_attempt.from_beginning
has_successful_attempt = (
index_attempt.connector_credential_pair.last_successful_index_time
is not None
)
# Use higher priority for first-time indexing to ensure new connectors
# get processed before re-indexing of existing connectors
docprocessing_priority = (
OnyxCeleryPriority.MEDIUM
if has_successful_attempt
else OnyxCeleryPriority.HIGH
)
earliest_index_time = (
db_connector.indexing_start.timestamp()
@@ -1095,6 +503,7 @@ def connector_document_extraction(
tenant_id,
app,
most_recent_attempt,
docprocessing_priority,
)
last_batch_num = reissued_batch_count + completed_batches
index_attempt.completed_batches = completed_batches
@@ -1207,7 +616,7 @@ def connector_document_extraction(
OnyxCeleryTask.DOCPROCESSING_TASK,
kwargs=processing_batch_data,
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
priority=docprocessing_priority,
)
batch_num += 1
@@ -1358,6 +767,7 @@ def reissue_old_batches(
tenant_id: str,
app: Celery,
most_recent_attempt: IndexAttempt | None,
priority: OnyxCeleryPriority,
) -> tuple[int, int]:
# When loading from a checkpoint, we need to start new docprocessing tasks
# tied to the new index attempt for any batches left over in the file store
@@ -1385,7 +795,7 @@ def reissue_old_batches(
"batch_num": path_info.batch_num, # use same batch num as previously
},
queue=OnyxCeleryQueues.DOCPROCESSING,
priority=OnyxCeleryPriority.MEDIUM,
priority=priority,
)
recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0
# resume from the batch num of the last attempt. This should be one more

View File

@@ -1,64 +0,0 @@
"""
Module for handling chat-related milestone tracking and telemetry.
"""
from sqlalchemy.orm import Session
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.milestone import update_user_assistant_milestone
from onyx.db.models import User
from onyx.utils.telemetry import mt_cloud_telemetry
def process_multi_assistant_milestone(
user: User | None,
assistant_id: int,
tenant_id: str,
db_session: Session,
) -> None:
"""
Process the multi-assistant milestone for a user.
This function:
1. Creates or retrieves the multi-assistant milestone
2. Updates the milestone with the current assistant usage
3. Checks if the milestone was just achieved
4. Sends telemetry if the milestone was just hit
Args:
user: The user for whom to process the milestone (can be None for anonymous users)
assistant_id: The ID of the assistant being used
tenant_id: The current tenant ID
db_session: Database session for queries
"""
# Create or retrieve the multi-assistant milestone
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
user=user,
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
db_session=db_session,
)
# Update the milestone with the current assistant usage
update_user_assistant_milestone(
milestone=multi_assistant_milestone,
user_id=str(user.id) if user else NO_AUTH_USER_ID,
assistant_id=assistant_id,
db_session=db_session,
)
# Check if the milestone was just achieved
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
milestone=multi_assistant_milestone,
db_session=db_session,
)
# Send telemetry if the milestone was just hit
if just_hit_multi_assistant_milestone:
mt_cloud_telemetry(
distinct_id=tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
properties=None,
)

View File

@@ -1,10 +1,12 @@
import threading
from collections.abc import Callable
from collections.abc import Generator
from queue import Empty
from typing import Any
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.emitter import Emitter
from onyx.context.search.models import SearchDoc
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
@@ -18,39 +20,77 @@ class ChatStateContainer:
This container holds the partial state that can be saved to the database
if the generation is stopped by the user or completes normally.
Thread-safe: All write operations are protected by a lock to ensure safe
concurrent access from multiple threads. For thread-safe reads, use the
getter methods. Direct attribute access is not thread-safe.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
# These are collected at the end after the entire tool call is completed
self.tool_calls: list[ToolCallInfo] = []
# This is accumulated during the streaming
self.reasoning_tokens: str | None = None
# This is accumulated during the streaming of the answer
self.answer_tokens: str | None = None
# Store citation mapping for building citation_docs_info during partial saves
self.citation_to_doc: dict[int, SearchDoc] = {}
self.citation_to_doc: CitationMapping = {}
# True if this turn is a clarification question (deep research flow)
self.is_clarification: bool = False
def add_tool_call(self, tool_call: ToolCallInfo) -> None:
"""Add a tool call to the accumulated state."""
self.tool_calls.append(tool_call)
with self._lock:
self.tool_calls.append(tool_call)
def set_reasoning_tokens(self, reasoning: str | None) -> None:
"""Set the reasoning tokens from the final answer generation."""
self.reasoning_tokens = reasoning
with self._lock:
self.reasoning_tokens = reasoning
def set_answer_tokens(self, answer: str | None) -> None:
"""Set the answer tokens from the final answer generation."""
self.answer_tokens = answer
with self._lock:
self.answer_tokens = answer
def set_citation_mapping(self, citation_to_doc: dict[int, Any]) -> None:
def set_citation_mapping(self, citation_to_doc: CitationMapping) -> None:
"""Set the citation mapping from citation processor."""
self.citation_to_doc = citation_to_doc
with self._lock:
self.citation_to_doc = citation_to_doc
def set_is_clarification(self, is_clarification: bool) -> None:
"""Set whether this turn is a clarification question."""
self.is_clarification = is_clarification
with self._lock:
self.is_clarification = is_clarification
def get_answer_tokens(self) -> str | None:
"""Thread-safe getter for answer_tokens."""
with self._lock:
return self.answer_tokens
def get_reasoning_tokens(self) -> str | None:
"""Thread-safe getter for reasoning_tokens."""
with self._lock:
return self.reasoning_tokens
def get_tool_calls(self) -> list[ToolCallInfo]:
"""Thread-safe getter for tool_calls (returns a copy)."""
with self._lock:
return self.tool_calls.copy()
def get_citation_to_doc(self) -> CitationMapping:
"""Thread-safe getter for citation_to_doc (returns a copy)."""
with self._lock:
return self.citation_to_doc.copy()
def get_is_clarification(self) -> bool:
"""Thread-safe getter for is_clarification."""
with self._lock:
return self.is_clarification
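# --- Illustrative usage sketch (not part of the diff above) ---
# Shows the intended write-via-setters / read-via-getters pattern for
# ChatStateContainer across threads. `fake_tool_call` is an assumed,
# pre-built ToolCallInfo; everything else comes from this module.
def _example_state_container_usage(fake_tool_call: ToolCallInfo) -> None:
    state = ChatStateContainer()

    def worker() -> None:
        # Writers go through the setters, which take the internal lock.
        state.set_answer_tokens("partial answer...")
        state.add_tool_call(fake_tool_call)

    t = threading.Thread(target=worker)
    t.start()
    t.join()

    # Readers use the getters; get_tool_calls() returns a copy, so the
    # result can be iterated without racing a writer thread.
    assert state.get_answer_tokens() == "partial answer..."
    assert len(state.get_tool_calls()) == 1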
def run_chat_llm_with_state_containers(
def run_chat_loop_with_state_containers(
func: Callable[..., None],
is_connected: Callable[[], bool],
emitter: Emitter,
@@ -74,7 +114,7 @@ def run_chat_llm_with_state_containers(
**kwargs: Additional keyword arguments for func
Usage:
packets = run_chat_llm_with_state_containers(
packets = run_chat_loop_with_state_containers(
my_func,
emitter=emitter,
state_container=state_container,
@@ -95,7 +135,7 @@ def run_chat_llm_with_state_containers(
# If execution fails, emit an exception packet
emitter.emit(
Packet(
turn_index=0,
placement=Placement(turn_index=0),
obj=PacketException(type="error", exception=e),
)
)

View File

@@ -38,6 +38,7 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.search_settings import get_current_search_settings
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
@@ -49,8 +50,10 @@ from onyx.llm.override_models import LLMOverride
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT
from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE
from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.tools.models import ToolCallKickoff
from onyx.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
@@ -71,7 +74,6 @@ def prepare_chat_message_request(
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
llm_override: LLMOverride | None = None,
allowed_tool_ids: list[int] | None = None,
@@ -98,7 +100,6 @@ def prepare_chat_message_request(
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
llm_override=llm_override,
allowed_tool_ids=allowed_tool_ids,
@@ -483,10 +484,14 @@ def load_chat_file(
if file_type.is_text_file():
try:
content_text = content.decode("utf-8")
except UnicodeDecodeError:
content_text = extract_file_text(
file=file_io,
file_name=file_descriptor.get("name") or "",
break_on_unprocessable=False,
)
except Exception as e:
logger.warning(
f"Failed to decode text content for file {file_descriptor['id']}"
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"
)
# Get token count from UserFile if available
@@ -581,9 +586,16 @@ def convert_chat_history(
# Add text files as separate messages before the user message
for text_file in text_files:
file_text = text_file.content_text or ""
filename = text_file.filename
message = (
f"File: {filename}\n{file_text}\nEnd of File"
if filename
else file_text
)
simple_messages.append(
ChatMessageSimple(
message=text_file.content_text or "",
message=message,
token_count=text_file.token_count,
message_type=MessageType.USER,
image_files=None,
@@ -729,3 +741,38 @@ def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) ->
if message.message_type == MessageType.ASSISTANT:
return message.is_clarification
return False
def create_tool_call_failure_messages(
tool_call: ToolCallKickoff, token_counter: Callable[[str], int]
) -> list[ChatMessageSimple]:
"""Create ChatMessageSimple objects for a failed tool call.
Creates two messages:
1. The tool call message itself
2. A failure response message indicating the tool call failed
Args:
tool_call: The ToolCallKickoff object representing the failed tool call
token_counter: Function to count tokens in a message string
Returns:
List containing two ChatMessageSimple objects: tool call message and failure response
"""
tool_call_msg = ChatMessageSimple(
message=tool_call.to_msg_str(),
token_count=token_counter(tool_call.to_msg_str()),
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
failure_response_msg = ChatMessageSimple(
message=TOOL_CALL_FAILURE_PROMPT,
token_count=token_counter(TOOL_CALL_FAILURE_PROMPT),
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
return [tool_call_msg, failure_response_msg]
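# --- Illustrative usage sketch (not part of the diff above) ---
# Shows how a caller might splice the failure pair into the simple chat
# history when a tool call produced no responses. `failed_kickoff`,
# `simple_chat_history`, and `token_counter` are assumed inputs; the helper
# itself is defined above.
def _example_tool_failure_recovery(
    failed_kickoff: ToolCallKickoff,
    simple_chat_history: list[ChatMessageSimple],
    token_counter: Callable[[str], int],
) -> None:
    failure_messages = create_tool_call_failure_messages(failed_kickoff, token_counter)
    # Two messages sharing the same tool_call_id: the tool call itself, then
    # the TOOL_CALL_FAILURE_PROMPT response, so the LLM can pair them up.
    simple_chat_history.extend(failure_messages)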

View File

@@ -4,13 +4,15 @@ Dynamic Citation Processor for LLM Responses
This module provides a citation processor that can:
- Accept citation number to SearchDoc mappings dynamically
- Process token streams from LLMs to extract citations
- Remove citation markers from output text
- Emit CitationInfo objects for detected citations
- Optionally replace citation markers with formatted markdown links
- Emit CitationInfo objects for detected citations (when replacing)
- Track all seen citations regardless of replacement mode
- Maintain a list of cited documents in order of first citation
"""
import re
from collections.abc import Generator
from typing import TypeAlias
from onyx.configs.chat_configs import STOP_STREAM_PAT
from onyx.context.search.models import SearchDoc
@@ -21,8 +23,11 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
CitationMapping: TypeAlias = dict[int, SearchDoc]
# ============================================================================
# Utility functions (copied for self-containment)
# Utility functions
# ============================================================================
@@ -43,19 +48,29 @@ class DynamicCitationProcessor:
This processor is designed for multi-turn conversations where the citation
number to document mapping is provided externally. It processes streaming
tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and:
tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and based
on the `replace_citation_tokens` setting:
1. Removes citation markers from the output text
2. Emits CitationInfo objects for tracking
3. Maintains the order in which documents were first cited
When replace_citation_tokens=True (default):
1. Replaces citation markers with formatted markdown links (e.g., [[1]](url))
2. Emits CitationInfo objects for tracking
3. Maintains the order in which documents were first cited
When replace_citation_tokens=False:
1. Preserves original citation markers in the output text
2. Does NOT emit CitationInfo objects
3. Still tracks all seen citations via get_seen_citations()
Features:
- Accepts citation number → SearchDoc mapping via update_citation_mapping()
- Processes tokens from LLM and removes citation markers
- Holds back tokens that might be partial citations
- Maintains list of cited SearchDocs in order of first citation
- Accepts citation number → SearchDoc mapping via update_citation_mapping()
- Configurable citation replacement behavior at initialization
- Always tracks seen citations regardless of replacement mode
- Holds back tokens that might be partial citations
- Maintains list of cited SearchDocs in order of first citation
- Handles unicode bracket variants (【】, ［］)
- Skips citation processing inside code blocks
Example:
Example (with citation replacement - default):
processor = DynamicCitationProcessor()
# Set up citation mapping
@@ -65,37 +80,55 @@ class DynamicCitationProcessor:
for token in llm_stream:
for result in processor.process_token(token):
if isinstance(result, str):
print(result) # Display text (citations removed)
print(result) # Display text with [[1]](url) format
elif isinstance(result, CitationInfo):
handle_citation(result) # Track citation
# Update mapping with more documents
processor.update_citation_mapping({3: search_doc3, 4: search_doc4})
# Continue processing...
# Get cited documents at the end
cited_docs = processor.get_cited_documents()
Example (without citation replacement):
processor = DynamicCitationProcessor(replace_citation_tokens=False)
processor.update_citation_mapping({1: search_doc1, 2: search_doc2})
# Process tokens from LLM
for token in llm_stream:
for result in processor.process_token(token):
# Only strings are yielded, no CitationInfo objects
print(result) # Display text with original [1] format preserved
# Get all seen citations after processing
seen_citations = processor.get_seen_citations() # {1: search_doc1, ...}
"""
def __init__(
self,
replace_citation_tokens: bool = True,
stop_stream: str | None = STOP_STREAM_PAT,
):
"""
Initialize the citation processor.
Args:
stop_stream: Optional stop token to halt processing early
replace_citation_tokens: If True (default), citations like [1] are replaced
with formatted markdown links like [[1]](url) and CitationInfo objects
are emitted. If False, original citation text is preserved in output
and no CitationInfo objects are emitted. Regardless of this setting,
all seen citations are tracked and available via get_seen_citations().
stop_stream: Optional stop token pattern to halt processing early.
When this pattern is detected in the token stream, processing stops.
Defaults to STOP_STREAM_PAT from chat configs.
"""
# Citation mapping from citation number to SearchDoc
self.citation_to_doc: dict[int, SearchDoc] = {}
self.citation_to_doc: CitationMapping = {}
self.seen_citations: CitationMapping = {} # citation num -> SearchDoc
# Token processing state
self.llm_out = "" # entire output so far
self.curr_segment = "" # tokens held for citation processing
self.hold = "" # tokens held for stop token processing
self.stop_stream = stop_stream
self.replace_citation_tokens = replace_citation_tokens
# Citation tracking
self.cited_documents_in_order: list[SearchDoc] = (
@@ -119,7 +152,11 @@ class DynamicCitationProcessor:
r"([\[【[]{2}\d+[\]】]]{2})|([\[【[]\d+(?:, ?\d+)*[\]】]])"
)
def update_citation_mapping(self, citation_mapping: dict[int, SearchDoc]) -> None:
def update_citation_mapping(
self,
citation_mapping: CitationMapping,
update_duplicate_keys: bool = False,
) -> None:
"""
Update the citation number to SearchDoc mapping.
@@ -128,15 +165,25 @@ class DynamicCitationProcessor:
Args:
citation_mapping: Dictionary mapping citation numbers (1, 2, 3, ...) to SearchDoc objects
update_duplicate_keys: If True, update existing mappings with new values when keys overlap.
If False (default), filter out duplicate keys and only add non-duplicates.
The default behavior is useful when OpenURL may have the same citation number as a
Web Search result - in those cases, we keep the web search citation and snippet etc.
"""
# Filter out duplicate keys and only add non-duplicates
# Reason for this is that OpenURL may have the same citation number as a Web Search result
# For those, we should just keep the web search citation and snippet etc.
duplicate_keys = set(citation_mapping.keys()) & set(self.citation_to_doc.keys())
non_duplicate_mapping = {
k: v for k, v in citation_mapping.items() if k not in duplicate_keys
}
self.citation_to_doc.update(non_duplicate_mapping)
if update_duplicate_keys:
# Update all mappings, including duplicates
self.citation_to_doc.update(citation_mapping)
else:
# Filter out duplicate keys and only add non-duplicates
# Reason for this is that OpenURL may have the same citation number as a Web Search result
# For those, we should just keep the web search citation and snippet etc.
duplicate_keys = set(citation_mapping.keys()) & set(
self.citation_to_doc.keys()
)
non_duplicate_mapping = {
k: v for k, v in citation_mapping.items() if k not in duplicate_keys
}
self.citation_to_doc.update(non_duplicate_mapping)
def process_token(
self, token: str | None
@@ -147,17 +194,24 @@ class DynamicCitationProcessor:
This method:
1. Accumulates tokens until a complete citation or non-citation is found
2. Holds back potential partial citations (e.g., "[", "[1")
3. Yields text chunks when they're safe to display (with citations REMOVED)
4. Yields CitationInfo when citations are detected
5. Handles code blocks (avoids processing citations inside code)
6. Handles stop tokens
3. Yields text chunks when they're safe to display
4. Handles code blocks (avoids processing citations inside code)
5. Handles stop tokens
6. Always tracks seen citations in self.seen_citations
Behavior depends on the `replace_citation_tokens` setting from __init__:
- If True: Citations are replaced with [[n]](url) format and CitationInfo
objects are yielded before each formatted citation
- If False: Original citation text (e.g., [1]) is preserved in output
and no CitationInfo objects are yielded
Args:
token: The next token from the LLM stream, or None to signal end of stream
token: The next token from the LLM stream, or None to signal end of stream.
Pass None to flush any remaining buffered text at end of stream.
Yields:
- str: Text chunks to display (citations removed)
- CitationInfo: Citation metadata when a citation is detected
str: Text chunks to display. Citation format depends on replace_citation_tokens.
CitationInfo: Citation metadata (only when replace_citation_tokens=True)
"""
# None -> end of stream, flush remaining segment
if token is None:
@@ -250,17 +304,24 @@ class DynamicCitationProcessor:
yield intermatch_str
# Process the citation (returns formatted citation text and CitationInfo objects)
# Always tracks seen citations regardless of the replace_citation_tokens flag
citation_text, citation_info_list = self._process_citation(
match, has_leading_space
match, has_leading_space, self.replace_citation_tokens
)
# Yield CitationInfo objects BEFORE the citation text
# This allows the frontend to receive citation metadata before the token
# that contains [[n]](link), enabling immediate rendering
for citation in citation_info_list:
yield citation
# Then yield the formatted citation text
if citation_text:
yield citation_text
if self.replace_citation_tokens:
# Yield CitationInfo objects BEFORE the citation text
# This allows the frontend to receive citation metadata before the token
# that contains [[n]](link), enabling immediate rendering
for citation in citation_info_list:
yield citation
# Then yield the formatted citation text
if citation_text:
yield citation_text
else:
# When not replacing citation tokens, yield the original citation text unchanged
yield match.group()
self.non_citation_count = 0
# Leftover text could be part of next citation
@@ -277,27 +338,42 @@ class DynamicCitationProcessor:
yield result
def _process_citation(
self, match: re.Match, has_leading_space: bool
self, match: re.Match, has_leading_space: bool, replace_tokens: bool = True
) -> tuple[str, list[CitationInfo]]:
"""
Process a single citation match and return formatted citation text and citation info objects.
The match string can look like '[1]', '[1, 13, 6]', '[[4]]', '【1】', etc.
This is an internal method called by process_token(). The match string can be
in various formats: '[1]', '[1, 13, 6]', '[[4]]', '【1】', '［1］', etc.
This method:
This method always:
1. Extracts citation numbers from the match
2. Looks up the corresponding SearchDoc from the mapping
3. Skips duplicate citations if they were recently cited
4. Creates formatted citation text like [n](link) for each citation
3. Tracks seen citations in self.seen_citations (regardless of replace_tokens)
When replace_tokens=True (controlled by self.replace_citation_tokens):
4. Creates formatted citation text as [[n]](url)
5. Creates CitationInfo objects for new citations
6. Handles deduplication of recently cited documents
When replace_tokens=False:
4. Returns empty string and empty list (caller yields original match text)
Args:
match: Regex match object containing the citation
has_leading_space: Whether the text before the citation has a leading space
match: Regex match object containing the citation pattern
has_leading_space: Whether the text immediately before this citation
ends with whitespace. Used to determine if a leading space should
be added to the formatted output.
replace_tokens: If True, return formatted text and CitationInfo objects.
If False, only track seen citations and return empty results.
This is passed from self.replace_citation_tokens by the caller.
Returns:
Tuple of (formatted_citation_text, list[CitationInfo])
- formatted_citation_text: Markdown-formatted citation text like [1](link) [2](link)
- citation_info_list: List of CitationInfo objects
Tuple of (formatted_citation_text, citation_info_list):
- formatted_citation_text: Markdown-formatted citation text like
"[[1]](https://example.com)" or empty string if replace_tokens=False
- citation_info_list: List of CitationInfo objects for newly cited
documents, or empty list if replace_tokens=False
"""
citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】'
formatted = (
@@ -335,7 +411,14 @@ class DynamicCitationProcessor:
doc_id = search_doc.document_id
link = search_doc.link or ""
# Always format the citation text as [[n]](link)
# Always track seen citations regardless of replace_tokens setting
self.seen_citations[num] = search_doc
# When not replacing citation tokens, skip the rest of the processing
if not replace_tokens:
continue
# Format the citation text as [[n]](link)
formatted_citation_parts.append(f"[[{num}]]({link})")
# Skip creating CitationInfo for citations of the same work if cited recently (deduplication)
@@ -367,8 +450,14 @@ class DynamicCitationProcessor:
"""
Get the list of cited SearchDoc objects in the order they were first cited.
Note: This list is only populated when `replace_citation_tokens=True`.
When `replace_citation_tokens=False`, this will return an empty list.
Use get_seen_citations() instead if you need to track citations without
replacing them.
Returns:
List of SearchDoc objects
List of SearchDoc objects in the order they were first cited.
Empty list if replace_citation_tokens=False.
"""
return self.cited_documents_in_order
@@ -376,34 +465,89 @@ class DynamicCitationProcessor:
"""
Get the list of cited document IDs in the order they were first cited.
Note: This list is only populated when `replace_citation_tokens=True`.
When `replace_citation_tokens=False`, this will return an empty list.
Use get_seen_citations() instead if you need to track citations without
replacing them.
Returns:
List of document IDs (strings)
List of document IDs (strings) in the order they were first cited.
Empty list if replace_citation_tokens=False.
"""
return [doc.document_id for doc in self.cited_documents_in_order]
def get_seen_citations(self) -> CitationMapping:
"""
Get all seen citations as a mapping from citation number to SearchDoc.
This returns all citations that have been encountered during processing,
regardless of the `replace_citation_tokens` setting. Citations are tracked
whenever they are parsed, making this useful for cases where you need to
know which citations appeared in the text without replacing them.
This is particularly useful when `replace_citation_tokens=False`, as
get_cited_documents() will be empty in that case, but get_seen_citations()
will still contain all the citations that were found.
Returns:
Dictionary mapping citation numbers (int) to SearchDoc objects.
The dictionary is keyed by the citation number as it appeared in
the text (e.g., {1: SearchDoc(...), 3: SearchDoc(...)}).
"""
return self.seen_citations
@property
def num_cited_documents(self) -> int:
"""Get the number of documents that have been cited."""
"""
Get the number of unique documents that have been cited.
Note: This count is only updated when `replace_citation_tokens=True`.
When `replace_citation_tokens=False`, this will always return 0.
Use len(get_seen_citations()) instead if you need to count citations
without replacing them.
Returns:
Number of unique documents cited. 0 if replace_citation_tokens=False.
"""
return len(self.cited_document_ids)
def reset_recent_citations(self) -> None:
"""
Reset the recent citations tracker.
This can be called to allow previously cited documents to be cited again
without being filtered out by the deduplication logic.
The processor tracks "recently cited" documents to avoid emitting duplicate
CitationInfo objects for the same document when it's cited multiple times
in close succession. This method clears that tracker.
This is primarily useful when `replace_citation_tokens=True` to allow
previously cited documents to emit CitationInfo objects again. Has no
effect when `replace_citation_tokens=False`.
The recent citation tracker is also automatically cleared when more than
5 non-citation characters are processed between citations.
"""
self.recent_cited_documents.clear()
def get_next_citation_number(self) -> int:
"""
Get the next available citation number.
Get the next available citation number for adding new documents to the mapping.
This method returns the next citation number that should be used for new documents.
If no citations exist yet, it returns 1. Otherwise, it returns max + 1.
This method returns the next citation number that should be used when adding
new documents via update_citation_mapping(). Useful when dynamically adding
citations during processing (e.g., from tool results like web search).
If no citations exist yet in the mapping, returns 1.
Otherwise, returns max(existing_citation_numbers) + 1.
Returns:
The next available citation number (1-indexed)
The next available citation number (1-indexed integer).
Example:
# After adding citations 1, 2, 3
processor.get_next_citation_number() # Returns 4
# With non-sequential citations 1, 5, 10
processor.get_next_citation_number() # Returns 11
"""
if not self.citation_to_doc:
return 1

View File

@@ -0,0 +1,177 @@
import re
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.context.search.models import SearchDocsResponse
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.models import ToolResponse
def update_citation_processor_from_tool_response(
tool_response: ToolResponse,
citation_processor: DynamicCitationProcessor,
) -> None:
"""Update citation processor if this was a citeable tool with a SearchDocsResponse.
Checks if the tool call is citeable and if the response contains a SearchDocsResponse,
then creates a mapping from citation numbers to SearchDoc objects and updates the
citation processor.
Args:
tool_response: The response from the tool execution (must have tool_call set)
citation_processor: The DynamicCitationProcessor to update
"""
# Early return if tool_call is not set
if tool_response.tool_call is None:
return
# Update citation processor if this was a search tool
if tool_response.tool_call.tool_name in CITEABLE_TOOLS_NAMES:
# Check if the rich_response is a SearchDocsResponse
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_response = tool_response.rich_response
# Create mapping from citation number to SearchDoc
citation_to_doc: CitationMapping = {}
for (
citation_num,
doc_id,
) in search_response.citation_mapping.items():
# Find the SearchDoc with this doc_id
matching_doc = next(
(
doc
for doc in search_response.search_docs
if doc.document_id == doc_id
),
None,
)
if matching_doc:
citation_to_doc[citation_num] = matching_doc
# Update the citation processor
citation_processor.update_citation_mapping(citation_to_doc)
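# --- Illustrative usage sketch (not part of the diff above) ---
# Assumes `search_tool_response` is a ToolResponse from a citeable tool, with
# tool_call set and rich_response being a SearchDocsResponse. After the update,
# citation markers emitted by the LLM (e.g. "[3]") resolve to SearchDocs.
def _example_processor_update(search_tool_response: ToolResponse) -> None:
    processor = DynamicCitationProcessor()
    update_citation_processor_from_tool_response(search_tool_response, processor)
    # New documents added later should be numbered after the existing mapping.
    print(f"next citation number: {processor.get_next_citation_number()}")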
def collapse_citations(
answer_text: str,
existing_citation_mapping: CitationMapping,
new_citation_mapping: CitationMapping,
) -> tuple[str, CitationMapping]:
"""Collapse the citations in the text to use the smallest possible numbers.
This function takes citations in the text (like [25], [30], etc.) and replaces them
with the smallest possible numbers. It starts numbering from the next available
integer after the existing citation mapping. If a citation refers to a document
that already exists in the existing citation mapping (matched by document_id),
it uses the existing citation number instead of assigning a new one.
Args:
answer_text: The text containing citations to collapse (e.g., "See [25] and [30]")
existing_citation_mapping: Citations already processed/displayed. These mappings
are preserved unchanged in the output.
new_citation_mapping: Citations from the current text that need to be collapsed.
The keys are the citation numbers as they appear in answer_text.
Returns:
A tuple of (updated_text, combined_mapping) where:
- updated_text: The text with citations replaced with collapsed numbers
- combined_mapping: All values from existing_citation_mapping plus the new
mappings with their (possibly renumbered) keys
"""
# Build a reverse lookup: document_id -> existing citation number
doc_id_to_existing_citation: dict[str, int] = {
doc.document_id: citation_num
for citation_num, doc in existing_citation_mapping.items()
}
# Determine the next available citation number
if existing_citation_mapping:
next_citation_num = max(existing_citation_mapping.keys()) + 1
else:
next_citation_num = 1
# Build the mapping from old citation numbers (in new_citation_mapping) to new numbers
old_to_new: dict[int, int] = {}
additional_mappings: CitationMapping = {}
for old_num, search_doc in new_citation_mapping.items():
doc_id = search_doc.document_id
# Check if this document already exists in existing citations
if doc_id in doc_id_to_existing_citation:
# Use the existing citation number
old_to_new[old_num] = doc_id_to_existing_citation[doc_id]
else:
# Check if we've already assigned a new number to this document
# (handles case where same doc appears with different old numbers)
existing_new_num = None
for mapped_old, mapped_new in old_to_new.items():
if (
mapped_old in new_citation_mapping
and new_citation_mapping[mapped_old].document_id == doc_id
):
existing_new_num = mapped_new
break
if existing_new_num is not None:
old_to_new[old_num] = existing_new_num
else:
# Assign the next available number
old_to_new[old_num] = next_citation_num
additional_mappings[next_citation_num] = search_doc
next_citation_num += 1
# Pattern to match citations like [25], [1, 2, 3], [[25]], etc.
# Also matches unicode bracket variants: 【】, ［］
citation_pattern = re.compile(
r"([\[【［]{2}\d+[\]】］]{2})|([\[【［]\d+(?:, ?\d+)*[\]】］])"
)
def replace_citation(match: re.Match) -> str:
"""Replace citation numbers in a match with their new collapsed values."""
citation_str = match.group()
# Determine bracket style
if (
citation_str.startswith("[[")
or citation_str.startswith("【【")
or citation_str.startswith("")
):
open_bracket = citation_str[:2]
close_bracket = citation_str[-2:]
content = citation_str[2:-2]
else:
open_bracket = citation_str[0]
close_bracket = citation_str[-1]
content = citation_str[1:-1]
# Parse and replace citation numbers
new_nums = []
for num_str in content.split(","):
num_str = num_str.strip()
if not num_str:
continue
try:
num = int(num_str)
# Only replace if we have a mapping for this number
if num in old_to_new:
new_nums.append(str(old_to_new[num]))
else:
# Keep original if not in our mapping
new_nums.append(num_str)
except ValueError:
new_nums.append(num_str)
# Reconstruct the citation with original bracket style
new_content = ", ".join(new_nums)
return f"{open_bracket}{new_content}{close_bracket}"
# Replace all citations in the text
updated_text = citation_pattern.sub(replace_citation, answer_text)
# Build the combined mapping
combined_mapping: CitationMapping = dict(existing_citation_mapping)
combined_mapping.update(additional_mappings)
return updated_text, combined_mapping
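# --- Illustrative usage sketch (not part of the diff above) ---
# Worked example of the renumbering. `doc_a`, `doc_b`, `doc_c` are assumed
# SearchDoc objects with distinct document_id values; they are left
# unannotated because this module does not import SearchDoc directly.
def _example_collapse(doc_a, doc_b, doc_c) -> None:
    existing: CitationMapping = {1: doc_a}
    new: CitationMapping = {25: doc_a, 30: doc_b, 31: doc_c}
    text, combined = collapse_citations("See [25], [30] and [31].", existing, new)
    # [25] maps back to the existing [1] (same document_id); [30] and [31]
    # take the next free numbers, 2 and 3.
    assert text == "See [1], [2] and [3]."
    assert set(combined.keys()) == {1, 2, 3}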

View File

@@ -1,15 +1,14 @@
import json
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_utils import create_tool_call_failure_messages
from onyx.chat.citation_processor import CitationMapping
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.citation_utils import update_citation_processor_from_tool_response
from onyx.chat.emitter import Emitter
from onyx.chat.llm_step import run_llm_step
from onyx.chat.llm_step import TOOL_CALL_MSG_ARGUMENTS
from onyx.chat.llm_step import TOOL_CALL_MSG_FUNC_NAME
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import ExtractedProjectFiles
from onyx.chat.models import LlmStepResult
@@ -30,18 +29,18 @@ from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_needs_formatting_reenabled
from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER
from onyx.prompts.chat_prompts import OPEN_URL_REMINDER
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import TopLevelBranching
from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES
from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES
from onyx.tools.interface import Tool
from onyx.tools.models import ToolCallInfo
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.tools.tool_runner import run_tool_calls
@@ -64,7 +63,7 @@ MAX_LLM_CYCLES = 6
def _build_project_file_citation_mapping(
project_file_metadata: list[ProjectFileMetadata],
starting_citation_num: int = 1,
) -> dict[int, SearchDoc]:
) -> CitationMapping:
"""Build citation mapping for project files.
Converts project file metadata into SearchDoc objects that can be cited.
@@ -77,7 +76,7 @@ def _build_project_file_citation_mapping(
Returns:
Dictionary mapping citation numbers to SearchDoc objects
"""
citation_mapping: dict[int, SearchDoc] = {}
citation_mapping: CitationMapping = {}
for idx, file_meta in enumerate(project_file_metadata, start=starting_citation_num):
# Create a SearchDoc for each project file
@@ -293,8 +292,16 @@ def run_llm_loop(
db_session: Session,
forced_tool_id: int | None = None,
user_identity: LLMUserIdentity | None = None,
chat_session_id: str | None = None,
) -> None:
with trace("run_llm_loop", metadata={"tenant_id": get_current_tenant_id()}):
with trace(
"run_llm_loop",
group_id=chat_session_id,
metadata={
"tenant_id": get_current_tenant_id(),
"chat_session_id": chat_session_id,
},
):
# Fix some LiteLLM issues
from onyx.llm.litellm_singleton.config import (
initialize_litellm,
@@ -302,18 +309,11 @@ def run_llm_loop(
initialize_litellm()
stopping_tools_names: list[str] = [ImageGenerationTool.NAME]
citeable_tools_names: list[str] = [
SearchTool.NAME,
WebSearchTool.NAME,
OpenURLTool.NAME,
]
# Initialize citation processor for handling citations dynamically
citation_processor = DynamicCitationProcessor()
# Add project file citation mappings if project files are present
project_citation_mapping: dict[int, SearchDoc] = {}
project_citation_mapping: CitationMapping = {}
if project_files.project_file_metadata:
project_citation_mapping = _build_project_file_citation_mapping(
project_files.project_file_metadata
@@ -325,7 +325,6 @@ def run_llm_loop(
# Pass the total budget to construct_message_history, which will handle token allocation
available_tokens = llm.config.max_input_tokens
tool_choice: ToolChoiceOptions = ToolChoiceOptions.AUTO
collected_tool_calls: list[ToolCallInfo] = []
# Initialize gathered_documents with project files if present
gathered_documents: list[SearchDoc] | None = (
list(project_citation_mapping.values())
@@ -343,12 +342,8 @@ def run_llm_loop(
has_called_search_tool: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
current_tool_call_index = (
0 # TODO: just use the cycle count after parallel tool calls are supported
)
reasoning_cycles = 0
for llm_cycle_count in range(MAX_LLM_CYCLES):
if forced_tool_id:
# Needs to be a single tool because "required" tool choice doesn't name a specific tool; it's just a binary on/off
final_tools = [tool for tool in tools if tool.id == forced_tool_id]
@@ -445,12 +440,13 @@ def run_llm_loop(
# This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result
# It also pre-processes the tool calls in preparation for running them
step_generator = run_llm_step(
llm_step_result, has_reasoned = run_llm_step(
emitter=emitter,
history=truncated_message_history,
tool_definitions=[tool.tool_definition() for tool in final_tools],
tool_choice=tool_choice,
llm=llm,
turn_index=current_tool_call_index,
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
citation_processor=citation_processor,
state_container=state_container,
# The rich docs representation is passed in so that when yielding the answer, it can also
@@ -459,18 +455,8 @@ def run_llm_loop(
final_documents=gathered_documents,
user_identity=user_identity,
)
# Consume the generator, emitting packets and capturing the final result
while True:
try:
packet = next(step_generator)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, current_tool_call_index = e.value
break
# Type narrowing: generator always returns a result, so this can't be None
llm_step_result = cast(LlmStepResult, llm_step_result)
if has_reasoned:
reasoning_cycles += 1
# Save citation mapping after each LLM step for incremental state updates
state_container.set_citation_mapping(citation_processor.citation_to_doc)
@@ -480,21 +466,50 @@ def run_llm_loop(
tool_responses: list[ToolResponse] = []
tool_calls = llm_step_result.tool_calls or []
just_ran_web_search = False
for tool_call in tool_calls:
# TODO replace the [tool_call] with the list of tool calls once parallel tool calls are supported
tool_responses, citation_mapping = run_tool_calls(
tool_calls=[tool_call],
tools=final_tools,
turn_index=current_tool_call_index,
message_history=truncated_message_history,
memories=memories,
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
citation_processor=citation_processor,
skip_search_query_expansion=has_called_search_tool,
if len(tool_calls) > 1:
emitter.emit(
Packet(
placement=Placement(
turn_index=tool_calls[0].placement.turn_index
),
obj=TopLevelBranching(num_parallel_branches=len(tool_calls)),
)
)
# Quick note on why both citation_mapping and citation_processor are needed:
# 1. Tools return lightweight string mappings, not SearchDoc objects
# 2. The SearchDoc resolution is deliberately deferred to llm_loop.py
# 3. The citation_processor operates on SearchDoc objects and can't provide a complete reverse URL lookup for
# in-flight citations
# This could be cleaned up, but it's neither trivial nor worthwhile right now
just_ran_web_search = False
tool_responses, citation_mapping = run_tool_calls(
tool_calls=tool_calls,
tools=final_tools,
message_history=truncated_message_history,
memories=memories,
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
next_citation_num=citation_processor.get_next_citation_number(),
skip_search_query_expansion=has_called_search_tool,
)
# Failure case, give something reasonable to the LLM to try again
if tool_calls and not tool_responses:
failure_messages = create_tool_call_failure_messages(
tool_calls[0], token_counter
)
simple_chat_history.extend(failure_messages)
continue
for tool_response in tool_responses:
# Extract tool_call from the response (set by run_tool_calls)
if tool_response.tool_call is None:
raise ValueError("Tool response missing tool_call reference")
tool_call = tool_response.tool_call
tab_index = tool_call.placement.tab_index
# Track if search tool was called (for skipping query expansion on subsequent calls)
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
@@ -502,110 +517,81 @@ def run_llm_loop(
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}
# Add the results to the chat history, note that even if the tools were run in parallel, this isn't supported
# as all the LLM APIs require linear history, so these will just be included sequentially
for tool_call, tool_response in zip([tool_call], tool_responses):
# Get the tool object to retrieve tool_id
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
raise ValueError(
f"Tool '{tool_call.tool_name}' not found in tools list"
)
# Extract search_docs if this is a search tool response
search_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_docs = tool_response.rich_response.search_docs
if gathered_documents:
gathered_documents.extend(search_docs)
else:
gathered_documents = search_docs
# This is used for the Open URL reminder in the next cycle
# only do this if the web search tool yielded results
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
just_ran_web_search = True
# Extract generated_images if this is an image generation tool response
generated_images = None
if isinstance(
tool_response.rich_response, FinalImageGenerationResponse
):
generated_images = tool_response.rich_response.generated_images
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
turn_index=current_tool_call_index,
tool_name=tool_call.tool_name,
tool_call_id=tool_call.tool_call_id,
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
search_docs=search_docs,
generated_images=generated_images,
# Add the results to the chat history. Even though tools may run in parallel,
# LLM APIs require linear history, so results are added sequentially.
# Get the tool object to retrieve tool_id
tool = tools_by_name.get(tool_call.tool_name)
if not tool:
raise ValueError(
f"Tool '{tool_call.tool_name}' not found in tools list"
)
collected_tool_calls.append(tool_call_info)
# Add to state container for partial save support
state_container.add_tool_call(tool_call_info)
# Store tool call with function name and arguments in separate layers
tool_call_data = {
TOOL_CALL_MSG_FUNC_NAME: tool_call.tool_name,
TOOL_CALL_MSG_ARGUMENTS: tool_call.tool_args,
}
tool_call_message = json.dumps(tool_call_data)
tool_call_token_count = token_counter(tool_call_message)
# Extract search_docs if this is a search tool response
search_docs = None
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_docs = tool_response.rich_response.search_docs
if gathered_documents:
gathered_documents.extend(search_docs)
else:
gathered_documents = search_docs
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
# This is used for the Open URL reminder in the next cycle
# only do this if the web search tool yielded results
if search_docs and tool_call.tool_name == WebSearchTool.NAME:
just_ran_web_search = True
tool_response_message = tool_response.llm_facing_response
tool_response_token_count = token_counter(tool_response_message)
# Extract generated_images if this is an image generation tool response
generated_images = None
if isinstance(
tool_response.rich_response, FinalImageGenerationResponse
):
generated_images = tool_response.rich_response.generated_images
tool_response_msg = ChatMessageSimple(
message=tool_response_message,
token_count=tool_response_token_count,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_response_msg)
tool_call_info = ToolCallInfo(
parent_tool_call_id=None, # Top-level tool calls are attached to the chat message
turn_index=llm_cycle_count + reasoning_cycles,
tab_index=tab_index,
tool_name=tool_call.tool_name,
tool_call_id=tool_call.tool_call_id,
tool_id=tool.id,
reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning
tool_call_arguments=tool_call.tool_args,
tool_call_response=tool_response.llm_facing_response,
search_docs=search_docs,
generated_images=generated_images,
)
# Add to state container for partial save support
state_container.add_tool_call(tool_call_info)
# Update citation processor if this was a search tool
if tool_call.tool_name in citeable_tools_names:
# Check if the rich_response is a SearchDocsResponse
if isinstance(tool_response.rich_response, SearchDocsResponse):
search_response = tool_response.rich_response
# Store tool call with function name and arguments in separate layers
tool_call_message = tool_call.to_msg_str()
tool_call_token_count = token_counter(tool_call_message)
# Create mapping from citation number to SearchDoc
citation_to_doc: dict[int, SearchDoc] = {}
for (
citation_num,
doc_id,
) in search_response.citation_mapping.items():
# Find the SearchDoc with this doc_id
matching_doc = next(
(
doc
for doc in search_response.search_docs
if doc.document_id == doc_id
),
None,
)
if matching_doc:
citation_to_doc[citation_num] = matching_doc
tool_call_msg = ChatMessageSimple(
message=tool_call_message,
token_count=tool_call_token_count,
message_type=MessageType.TOOL_CALL,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_call_msg)
# Update the citation processor
citation_processor.update_citation_mapping(citation_to_doc)
tool_response_message = tool_response.llm_facing_response
tool_response_token_count = token_counter(tool_response_message)
current_tool_call_index += 1
tool_response_msg = ChatMessageSimple(
message=tool_response_message,
token_count=tool_response_token_count,
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
)
simple_chat_history.append(tool_response_msg)
# Update citation processor if this was a search tool
update_citation_processor_from_tool_response(
tool_response, citation_processor
)
# If no tool calls, then it must have answered, wrap up
if not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0:
@@ -613,13 +599,13 @@ def run_llm_loop(
# Certain tools do not allow further actions, force the LLM wrap up on the next cycle
if any(
tool.tool_name in stopping_tools_names
tool.tool_name in STOPPING_TOOLS_NAMES
for tool in llm_step_result.tool_calls
):
ran_image_gen = True
if llm_step_result.tool_calls and any(
tool.tool_name in citeable_tools_names
tool.tool_name in CITEABLE_TOOLS_NAMES
for tool in llm_step_result.tool_calls
):
# As long as 1 tool with citeable documents is called at any point, we ask the LLM to try to cite
@@ -629,5 +615,8 @@ def run_llm_loop(
raise RuntimeError("LLM did not return an answer.")
emitter.emit(
Packet(turn_index=current_tool_call_index, obj=OverallStop(type="stop"))
Packet(
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
obj=OverallStop(type="stop"),
)
)

View File

@@ -1,4 +1,5 @@
import json
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Mapping
from collections.abc import Sequence
@@ -7,6 +8,7 @@ from typing import cast
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
@@ -17,16 +19,19 @@ from onyx.llm.interfaces import LanguageModelInput
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.model_response import Delta
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import FunctionCall
from onyx.llm.models import ImageContentPart
from onyx.llm.models import ImageUrlDetail
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import SystemMessage
from onyx.llm.models import TextContentPart
from onyx.llm.models import ToolCall
from onyx.llm.models import ToolMessage
from onyx.llm.models import UserMessage
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
from onyx.server.query_and_chat.streaming_models import AgentResponseStart
from onyx.server.query_and_chat.streaming_models import CitationInfo
@@ -34,6 +39,8 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningDone
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import TOOL_CALL_MSG_ARGUMENTS
from onyx.tools.models import TOOL_CALL_MSG_FUNC_NAME
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
@@ -43,8 +50,77 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
TOOL_CALL_MSG_FUNC_NAME = "function_name"
TOOL_CALL_MSG_ARGUMENTS = "arguments"
def _try_parse_json_string(value: Any) -> Any:
"""Attempt to parse a JSON string value into its Python equivalent.
If value is a string that looks like a JSON array or object, parse it.
Otherwise return the value unchanged.
This handles the case where the LLM returns arguments like:
- queries: '["query1", "query2"]' instead of ["query1", "query2"]
"""
if not isinstance(value, str):
return value
stripped = value.strip()
# Only attempt to parse if it looks like a JSON array or object
if not (
(stripped.startswith("[") and stripped.endswith("]"))
or (stripped.startswith("{") and stripped.endswith("}"))
):
return value
try:
return json.loads(stripped)
except json.JSONDecodeError:
return value
def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]:
"""Parse tool arguments into a dict.
Normal case:
- raw_args == '{"queries":[...]}' -> dict via json.loads
Defensive case (JSON string literal of an object):
- raw_args == '"{\\"queries\\":[...]}"' -> json.loads -> str -> json.loads -> dict
Also handles the case where argument values are JSON strings that need parsing:
- {"queries": '["q1", "q2"]'} -> {"queries": ["q1", "q2"]}
Anything else returns {}.
"""
if raw_args is None:
return {}
if isinstance(raw_args, dict):
# Parse any string values that look like JSON arrays/objects
return {k: _try_parse_json_string(v) for k, v in raw_args.items()}
if not isinstance(raw_args, str):
return {}
try:
parsed1: Any = json.loads(raw_args)
except json.JSONDecodeError:
return {}
if isinstance(parsed1, dict):
# Parse any string values that look like JSON arrays/objects
return {k: _try_parse_json_string(v) for k, v in parsed1.items()}
if isinstance(parsed1, str):
try:
parsed2: Any = json.loads(parsed1)
except json.JSONDecodeError:
return {}
if isinstance(parsed2, dict):
# Parse any string values that look like JSON arrays/objects
return {k: _try_parse_json_string(v) for k, v in parsed2.items()}
return {}
return {}
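# Illustrative sketch (not part of this change) of the parsing behavior described above;
# the argument payloads are made-up examples:
#
#   _parse_tool_args_to_dict('{"queries": ["q1", "q2"]}')
#       -> {"queries": ["q1", "q2"]}
#   _parse_tool_args_to_dict('"{\\"queries\\": [\\"q1\\"]}"')   # JSON string literal of an object
#       -> {"queries": ["q1"]}
#   _parse_tool_args_to_dict({"queries": '["q1", "q2"]'})       # stringified value inside a dict
#       -> {"queries": ["q1", "q2"]}
#   _parse_tool_args_to_dict(None)
#       -> {}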
def _format_message_history_for_logging(
@@ -153,21 +229,27 @@ def _update_tool_call_with_delta(
def _extract_tool_call_kickoffs(
id_to_tool_call_map: dict[int, dict[str, Any]],
turn_index: int,
tab_index: int | None = None,
sub_turn_index: int | None = None,
) -> list[ToolCallKickoff]:
"""Extract ToolCallKickoff objects from the tool call map.
Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name).
Each tool call is assigned the given turn_index and a tab_index based on its order.
Args:
id_to_tool_call_map: Map of tool call index to tool call data
turn_index: The turn index for this set of tool calls
tab_index: If provided, use this tab_index for all tool calls (otherwise auto-increment)
sub_turn_index: The sub-turn index for nested tool calls
"""
tool_calls: list[ToolCallKickoff] = []
tab_index_calculated = 0
for tool_call_data in id_to_tool_call_map.values():
if tool_call_data.get("id") and tool_call_data.get("name"):
try:
# Parse arguments JSON string to dict
tool_args = (
json.loads(tool_call_data["arguments"])
if tool_call_data["arguments"]
else {}
)
tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments"))
except json.JSONDecodeError:
# If parsing fails, fall back to an empty dict; most tools would fail with it though
logger.error(
@@ -180,8 +262,16 @@ def _extract_tool_call_kickoffs(
tool_call_id=tool_call_data["id"],
tool_name=tool_call_data["name"],
tool_args=tool_args,
placement=Placement(
turn_index=turn_index,
tab_index=(
tab_index_calculated if tab_index is None else tab_index
),
sub_turn_index=sub_turn_index,
),
)
)
tab_index_calculated += 1
return tool_calls
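# Illustrative sketch (not part of this change), with hypothetical tool names: given two
# buffered tool calls and turn_index=2, the resulting kickoffs get
# Placement(turn_index=2, tab_index=0) and Placement(turn_index=2, tab_index=1);
# passing tab_index explicitly pins both kickoffs to that value instead.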
@@ -272,13 +362,19 @@ def translate_history_to_llm_format(
function_name = tool_call_data.get(
TOOL_CALL_MSG_FUNC_NAME, "unknown"
)
tool_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
raw_args = tool_call_data.get(TOOL_CALL_MSG_ARGUMENTS, {})
else:
function_name = "unknown"
tool_args = (
raw_args = (
tool_call_data if isinstance(tool_call_data, dict) else {}
)
# IMPORTANT: `FunctionCall.arguments` must be a JSON object string.
# If `raw_args` is accidentally a JSON string literal of an object
# (e.g. '"{\\"queries\\":[...]}"'), calling `json.dumps(raw_args)`
# would produce a quoted JSON literal and break Anthropic tool parsing.
tool_args = _parse_tool_args_to_dict(raw_args)
# NOTE: if the model is trained on a different tool call format, this may slightly interfere
# with the future tool calls, if it doesn't look like this. Almost certainly not a big deal.
tool_call = ToolCall(
@@ -324,20 +420,87 @@ def translate_history_to_llm_format(
return messages
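# Illustrative sketch (not part of this change) of the double-encoding hazard noted above:
#
#   import json
#   obj = {"queries": ["q1"]}
#   json.dumps(obj)              # '{"queries": ["q1"]}'           -> valid tool arguments
#   json.dumps(json.dumps(obj))  # '"{\\"queries\\": [\\"q1\\"]}"' -> quoted JSON literal,
#                                #    which breaks Anthropic-style tool parsing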
def run_llm_step(
def _increment_turns(
turn_index: int, sub_turn_index: int | None
) -> tuple[int, int | None]:
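# e.g. (3, None) -> (4, None): advance the top-level turn;
#      (3, 2)    -> (3, 3):    within a sub-turn, advance only the sub-turn index.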
if sub_turn_index is None:
return turn_index + 1, None
else:
return turn_index, sub_turn_index + 1
def run_llm_step_pkt_generator(
history: list[ChatMessageSimple],
tool_definitions: list[dict],
tool_choice: ToolChoiceOptions,
llm: LLM,
turn_index: int,
citation_processor: DynamicCitationProcessor,
state_container: ChatStateContainer,
placement: Placement,
state_container: ChatStateContainer | None,
citation_processor: DynamicCitationProcessor | None,
reasoning_effort: ReasoningEffort | None = None,
final_documents: list[SearchDoc] | None = None,
user_identity: LLMUserIdentity | None = None,
) -> Generator[Packet, None, tuple[LlmStepResult, int]]:
# The second return value is for the turn index because reasoning counts on the frontend as a turn
# TODO this is maybe ok but does not align too well with the backend logic
custom_token_processor: (
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
) = None,
max_tokens: int | None = None,
# TODO: Temporary handling of nested tool calls with agents, figure out a better way to handle this
use_existing_tab_index: bool = False,
is_deep_research: bool = False,
) -> Generator[Packet, None, tuple[LlmStepResult, bool]]:
"""Run an LLM step and stream the response as packets.
NOTE: DO NOT TOUCH THIS FUNCTION BEFORE ASKING YUHONG, this is very finicky and
delicate logic that is core to the app's main functionality.
This generator function streams LLM responses, processing reasoning content,
answer content, tool calls, and citations. It yields Packet objects for
real-time streaming to clients and accumulates the final result.
Args:
history: List of chat messages in the conversation history.
tool_definitions: List of tool definitions available to the LLM.
tool_choice: Tool choice configuration (e.g., "auto", "required", "none").
llm: Language model interface to use for generation.
placement: Placement carrying the turn index, tab index, and optional sub-turn index
used for the packets emitted by this step.
state_container: Container for storing chat state (reasoning, answers).
citation_processor: Optional processor for extracting and formatting citations
from the response. If provided, processes tokens to identify citations.
reasoning_effort: Optional reasoning effort configuration for models that
support reasoning (e.g., o1 models).
final_documents: Optional list of search documents to include in the response
start packet.
user_identity: Optional user identity information for the LLM.
custom_token_processor: Optional callable that processes each token delta
before yielding. Receives (delta, processor_state) and returns
(modified_delta, new_processor_state). Can return None for delta to skip.
max_tokens: Optional cap on the number of tokens generated for this step.
use_existing_tab_index: If True, all tool calls reuse the placement's tab_index
instead of auto-incrementing (temporary handling for nested agent tool calls).
is_deep_research: If True and tool_choice is REQUIRED, content streamed before
tool calls is treated as reasoning rather than as an answer.
Yields:
Packet: Streaming packets containing:
- ReasoningStart/ReasoningDelta/ReasoningDone for reasoning content
- AgentResponseStart/AgentResponseDelta for answer content
- CitationInfo for extracted citations
- ToolCallKickoff for tool calls (extracted at the end)
Returns:
tuple[LlmStepResult, bool]: A tuple containing:
- LlmStepResult: The final result with accumulated reasoning, answer,
and tool calls (if any).
- bool: Whether reasoning occurred during this step. This should be used to
increment the turn index or sub-turn index for the rest of the LLM loop.
Note:
The function handles incremental state updates, saving reasoning and answer
tokens to the state container as they are generated. Tool calls are extracted
and yielded only after the stream completes.
"""
turn_index = placement.turn_index
tab_index = placement.tab_index
sub_turn_index = placement.sub_turn_index
llm_msg_history = translate_history_to_llm_format(history)
has_reasoned = 0
# Set LOG_ONYX_MODEL_INTERACTIONS to log the entire message history to the console
if LOG_ONYX_MODEL_INTERACTIONS:
@@ -351,6 +514,8 @@ def run_llm_step(
accumulated_reasoning = ""
accumulated_answer = ""
processor_state: Any = None
with generation_span(
model=llm.config.model_name,
model_config={
@@ -366,7 +531,8 @@ def run_llm_step(
tools=tool_definitions,
tool_choice=tool_choice,
structured_response_format=None, # TODO
# reasoning_effort=ReasoningEffort.OFF, # Can set this for dev/testing.
max_tokens=max_tokens,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
):
if packet.usage:
@@ -379,69 +545,173 @@ def run_llm_step(
}
delta = packet.choice.delta
if custom_token_processor:
# The custom token processor can modify the deltas for specific custom logic
# It can also return a state so that it can handle aggregated delta logic etc.
# Loosely typed so the function can be flexible
modified_delta, processor_state = custom_token_processor(
delta, processor_state
)
if modified_delta is None:
continue
delta = modified_delta
# Should only happen once, frontend does not expect multiple
# ReasoningStart or ReasoningDone packets.
if delta.reasoning_content:
accumulated_reasoning += delta.reasoning_content
# Save reasoning incrementally to state container
state_container.set_reasoning_tokens(accumulated_reasoning)
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
turn_index=turn_index,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
yield Packet(
turn_index=turn_index,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.reasoning_content),
)
reasoning_start = True
if delta.content:
if reasoning_start:
yield Packet(
turn_index=turn_index,
obj=ReasoningDone(),
)
turn_index += 1
reasoning_start = False
if not answer_start:
yield Packet(
turn_index=turn_index,
obj=AgentResponseStart(
final_documents=final_documents,
),
)
answer_start = True
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
# about which tool to call, not an actual answer to the user.
# Treat this content as reasoning instead of answer.
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
# Treat content as reasoning when we know tool calls are coming
accumulated_reasoning += delta.content
if state_container:
state_container.set_reasoning_tokens(accumulated_reasoning)
if not reasoning_start:
yield Packet(
turn_index=turn_index,
obj=AgentResponseDelta(content=result),
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningStart(),
)
elif isinstance(result, CitationInfo):
yield Packet(
yield Packet(
placement=Placement(
turn_index=turn_index,
obj=result,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDelta(reasoning=delta.content),
)
reasoning_start = True
else:
# Normal flow for AUTO or NONE tool choice
if reasoning_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
if not answer_start:
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseStart(
final_documents=final_documents,
),
)
answer_start = True
if citation_processor:
for result in citation_processor.process_token(delta.content):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(
accumulated_answer
)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
else:
# When citation_processor is None, use delta.content directly without modification
accumulated_answer += delta.content
# Save answer incrementally to state container
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=delta.content),
)
if delta.tool_calls:
if reasoning_start:
yield Packet(
turn_index=turn_index,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
turn_index += 1
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(
turn_index, sub_turn_index
)
reasoning_start = False
for tool_call_delta in delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
tool_calls = _extract_tool_call_kickoffs(id_to_tool_call_map)
# Flush custom token processor to get any final tool calls
if custom_token_processor:
flush_delta, processor_state = custom_token_processor(None, processor_state)
if flush_delta and flush_delta.tool_calls:
for tool_call_delta in flush_delta.tool_calls:
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
tool_calls = _extract_tool_call_kickoffs(
id_to_tool_call_map=id_to_tool_call_map,
turn_index=turn_index,
tab_index=tab_index if use_existing_tab_index else None,
sub_turn_index=sub_turn_index,
)
if tool_calls:
tool_calls_list: list[ToolCall] = [
ToolCall(
@@ -468,28 +738,48 @@ def run_llm_step(
tool_calls=None,
)
span_generation.span_data.output = [assistant_msg_no_tools.model_dump()]
# Close reasoning block if still open (stream ended with reasoning content)
# This may happen if the custom token processor converts other packets into reasoning,
# in which case nothing else necessarily comes after the reasoning tokens
if reasoning_start:
yield Packet(
turn_index=turn_index,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=ReasoningDone(),
)
turn_index += 1
has_reasoned = 1
turn_index, sub_turn_index = _increment_turns(turn_index, sub_turn_index)
reasoning_start = False
# Flush any remaining content from citation processor
# Reasoning is always first so this should use the post-incremented value of turn_index
# Note that this doesn't need to handle any sub-turns as those docs will not have citations
# as clickable items and will be stripped out instead.
if citation_processor:
for result in citation_processor.process_token(None):
if isinstance(result, str):
accumulated_answer += result
# Save answer incrementally to state container
state_container.set_answer_tokens(accumulated_answer)
if state_container:
state_container.set_answer_tokens(accumulated_answer)
yield Packet(
turn_index=turn_index,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=AgentResponseDelta(content=result),
)
elif isinstance(result, CitationInfo):
yield Packet(
turn_index=turn_index,
placement=Placement(
turn_index=turn_index,
tab_index=tab_index,
sub_turn_index=sub_turn_index,
),
obj=result,
)
@@ -514,5 +804,55 @@ def run_llm_step(
answer=accumulated_answer if accumulated_answer else None,
tool_calls=tool_calls if tool_calls else None,
),
turn_index,
bool(has_reasoned),
)
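# Minimal sketch (not part of this change) of the custom_token_processor contract
# documented above: a pass-through processor that keeps a running count of content
# deltas in its state. Field names follow how Delta is used in this file.
def _example_token_processor(
    delta: Delta | None, state: Any
) -> tuple[Delta | None, Any]:
    count = state or 0
    if delta is None:
        # Flush call at the end of the stream; nothing extra to emit.
        return None, count
    if delta.content:
        count += 1
    return delta, count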
def run_llm_step(
emitter: Emitter,
history: list[ChatMessageSimple],
tool_definitions: list[dict],
tool_choice: ToolChoiceOptions,
llm: LLM,
placement: Placement,
state_container: ChatStateContainer | None,
citation_processor: DynamicCitationProcessor | None,
reasoning_effort: ReasoningEffort | None = None,
final_documents: list[SearchDoc] | None = None,
user_identity: LLMUserIdentity | None = None,
custom_token_processor: (
Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None
) = None,
max_tokens: int | None = None,
use_existing_tab_index: bool = False,
is_deep_research: bool = False,
) -> tuple[LlmStepResult, bool]:
"""Wrapper around run_llm_step_pkt_generator that consumes packets and emits them.
Returns:
tuple[LlmStepResult, bool]: The LLM step result and whether reasoning occurred.
"""
step_generator = run_llm_step_pkt_generator(
history=history,
tool_definitions=tool_definitions,
tool_choice=tool_choice,
llm=llm,
placement=placement,
state_container=state_container,
citation_processor=citation_processor,
reasoning_effort=reasoning_effort,
final_documents=final_documents,
user_identity=user_identity,
custom_token_processor=custom_token_processor,
max_tokens=max_tokens,
use_existing_tab_index=use_existing_tab_index,
is_deep_research=is_deep_research,
)
while True:
try:
packet = next(step_generator)
emitter.emit(packet)
except StopIteration as e:
llm_step_result, has_reasoned = e.value
return llm_step_result, bool(has_reasoned)
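# Background note (illustrative): a generator's `return` value is delivered via
# StopIteration.value, which is what the loop above relies on. For example:
#
#   def _demo() -> Generator[int, None, str]:
#       yield 1
#       return "done"
#
#   gen = _demo()
#   next(gen)            # 1
#   try:
#       next(gen)
#   except StopIteration as stop:
#       assert stop.value == "done"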

View File

@@ -1,4 +1,3 @@
import os
import re
import traceback
from collections.abc import Callable
@@ -7,9 +6,8 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.chat.chat_milestones import process_multi_assistant_milestone
from onyx.chat.chat_state import ChatStateContainer
from onyx.chat.chat_state import run_chat_llm_with_state_containers
from onyx.chat.chat_state import run_chat_loop_with_state_containers
from onyx.chat.chat_utils import convert_chat_history
from onyx.chat.chat_utils import create_chat_history_chain
from onyx.chat.chat_utils import get_custom_agent_prompt
@@ -32,6 +30,7 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.context.search.models import CitationDocInfo
from onyx.context.search.models import SearchDoc
from onyx.db.chat import create_new_chat_message
@@ -51,8 +50,8 @@ from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
from onyx.llm.utils import litellm_exception_to_error_msg
@@ -65,13 +64,14 @@ from onyx.server.query_and_chat.streaming_models import CitationInfo
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.utils import get_json_line
from onyx.tools.constants import SEARCH_TOOL_ID
from onyx.tools.tool import Tool
from onyx.tools.interface import Tool
from onyx.tools.tool_constructor import construct_tools
from onyx.tools.tool_constructor import CustomToolConfig
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_constructor import SearchToolUsage
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
@@ -367,11 +367,10 @@ def stream_chat_message_objects(
)
# Milestone tracking, most devs using the API don't need to understand this
process_multi_assistant_milestone(
user=user,
assistant_id=persona.id,
mt_cloud_telemetry(
tenant_id=tenant_id,
db_session=db_session,
distinct_id=user.email if user else tenant_id,
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
)
if reference_doc_ids is None and retrieval_options is None:
@@ -379,7 +378,7 @@ def stream_chat_message_objects(
"Must specify a set of documents for chat or specify search options"
)
llm, fast_llm = get_llms_for_persona(
llm = get_llm_for_persona(
persona=persona,
user=user,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
@@ -475,7 +474,6 @@ def stream_chat_message_objects(
emitter=emitter,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
user_selected_filters=user_selected_filters,
project_id=(
@@ -546,7 +544,7 @@ def stream_chat_message_objects(
# for stop signals. run_llm_loop itself doesn't know about stopping.
# Note: DB session is not thread safe but nothing else uses it and the
# reference is passed directly so it's ok.
if os.environ.get("ENABLE_DEEP_RESEARCH_LOOP"): # Dev only feature flag for now
if new_msg_req.deep_research:
if chat_session.project_id:
raise RuntimeError("Deep research is not supported for projects")
@@ -554,7 +552,7 @@ def stream_chat_message_objects(
# (user has already responded to a clarification question)
skip_clarification = is_last_assistant_message_clarification(chat_history)
yield from run_chat_llm_with_state_containers(
yield from run_chat_loop_with_state_containers(
run_deep_research_llm_loop,
is_connected=check_is_connected,
emitter=emitter,
@@ -567,9 +565,10 @@ def stream_chat_message_objects(
db_session=db_session,
skip_clarification=skip_clarification,
user_identity=user_identity,
chat_session_id=str(chat_session_id),
)
else:
yield from run_chat_llm_with_state_containers(
yield from run_chat_loop_with_state_containers(
run_llm_loop,
is_connected=check_is_connected, # Not passed through to run_llm_loop
emitter=emitter,
@@ -589,6 +588,7 @@ def stream_chat_message_objects(
else None
),
user_identity=user_identity,
chat_session_id=str(chat_session_id),
)
# Determine if stopped by user

View File

@@ -22,7 +22,7 @@ from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.tools.tool import Tool
from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
@@ -156,7 +156,7 @@ def build_system_prompt(
system_prompt += company_context
if memories:
system_prompt += "\n".join(
memory.strip() for memory in memories if memory.strip()
"- " + memory.strip() for memory in memories if memory.strip()
)
# Append citation guidance after company context if placeholder was not present

View File

@@ -102,6 +102,7 @@ def _create_and_link_tool_calls(
if tool_call_info.generated_images
else None
),
tab_index=tool_call_info.tab_index,
add_only=True,
)
@@ -219,8 +220,8 @@ def save_chat_turn(
search_doc_key_to_id[search_doc_key] = db_search_doc.id
search_doc_ids_for_tool.append(db_search_doc.id)
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = (
search_doc_ids_for_tool
tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list(
set(search_doc_ids_for_tool)
)
# 3. Collect all unique SearchDoc IDs from all tool calls to link to ChatMessage

View File

@@ -541,6 +541,11 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Default size threshold for Drupal Wiki attachments (10MB)
DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Default size threshold for SharePoint files (20MB)
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
@@ -583,6 +588,16 @@ LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))
# Slack federated search thread context settings
# Batch size for fetching thread context (controls concurrent API calls per batch)
SLACK_THREAD_CONTEXT_BATCH_SIZE = int(
os.environ.get("SLACK_THREAD_CONTEXT_BATCH_SIZE", "5")
)
# Maximum messages to fetch thread context for (top N by relevance get full context)
MAX_SLACK_THREAD_CONTEXT_MESSAGES = int(
os.environ.get("MAX_SLACK_THREAD_CONTEXT_MESSAGES", "5")
)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -698,6 +713,15 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
# The intent was to have this be configurable per query, but I don't think any
# codepath was actually configuring this, so for the migrated Vespa interface
# we'll just use the default value, but also have it be configurable by env var.
RECENCY_BIAS_MULTIPLIER = float(os.environ.get("RECENCY_BIAS_MULTIPLIER") or 1.0)
# Should match the rerank-count value set in
# backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja.
RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)
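# Note (illustrative): the `int(os.environ.get(...) or default)` pattern above treats an
# empty env var as unset, e.g. RERANK_COUNT="" still yields 1000, whereas
# int(os.environ.get("RERANK_COUNT", "1000")) would raise ValueError on an empty string.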
#####
# Tool Configs

View File

@@ -209,6 +209,7 @@ class DocumentSource(str, Enum):
EGNYTE = "egnyte"
AIRTABLE = "airtable"
HIGHSPOT = "highspot"
DRUPAL_WIKI = "drupal_wiki"
IMAP = "imap"
BITBUCKET = "bitbucket"
@@ -332,7 +333,6 @@ class FileType(str, Enum):
class MilestoneRecordType(str, Enum):
TENANT_CREATED = "tenant_created"
USER_SIGNED_UP = "user_signed_up"
MULTIPLE_USERS = "multiple_users"
VISITED_ADMIN_PAGE = "visited_admin_page"
CREATED_CONNECTOR = "created_connector"
CONNECTOR_SUCCEEDED = "connector_succeeded"
@@ -564,7 +564,7 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60
class OnyxCallTypes(str, Enum):
@@ -629,6 +629,7 @@ project management, and collaboration tools into a single, customizable platform
DocumentSource.EGNYTE: "egnyte - files",
DocumentSource.AIRTABLE: "airtable - database",
DocumentSource.HIGHSPOT: "highspot - CRM data",
DocumentSource.DRUPAL_WIKI: "drupal wiki - knowledge base content (pages, spaces, attachments)",
DocumentSource.IMAP: "imap - email data",
DocumentSource.TESTRAIL: "testrail - test case management tool for QA processes",
}

View File

@@ -64,12 +64,12 @@ _BASE_EMBEDDING_MODELS = [
_BaseEmbeddingModel(
name="google/gemini-embedding-001",
dim=3072,
index_name="danswer_chunk_google_gemini_embedding_001",
index_name="danswer_chunk_gemini_embedding_001",
),
_BaseEmbeddingModel(
name="google/text-embedding-005",
dim=768,
index_name="danswer_chunk_google_text_embedding_005",
index_name="danswer_chunk_text_embedding_005",
),
_BaseEmbeddingModel(
name="voyage/voyage-large-2-instruct",

View File

@@ -51,10 +51,9 @@ CROSS_ENCODER_RANGE_MIN = 0
# Generative AI Model Configs
#####
# NOTE: the 3 below should only be used for dev.
# NOTE: the 2 below should only be used for dev.
GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY")
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION")
FAST_GEN_AI_MODEL_VERSION = os.environ.get("FAST_GEN_AI_MODEL_VERSION")
# Override the auto-detection of LLM max context length
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None

View File

@@ -38,7 +38,7 @@ class AsanaAPI:
def __init__(
self, api_token: str, workspace_gid: str, team_gid: str | None
) -> None:
self._user = None # type: ignore
self._user = None
self.workspace_gid = workspace_gid
self.team_gid = team_gid

View File

@@ -9,14 +9,14 @@ from typing import Any
from typing import Optional
from urllib.parse import quote
import boto3 # type: ignore
from botocore.client import Config # type: ignore
import boto3
from botocore.client import Config
from botocore.credentials import RefreshableCredentials
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from botocore.session import get_session
from mypy_boto3_s3 import S3Client # type: ignore
from mypy_boto3_s3 import S3Client
from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -40,8 +40,7 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
@@ -410,7 +409,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
continue
# Handle image files
if is_accepted_file_ext(file_ext, OnyxExtensionType.Multimedia):
if file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
if not self._allow_images:
logger.debug(
f"Skipping image file: {key} (image processing not enabled)"

View File

@@ -84,6 +84,12 @@ ONE_DAY = ONE_HOUR * 24
MAX_CACHED_IDS = 100
def _get_page_id(page: dict[str, Any], allow_missing: bool = False) -> str:
if allow_missing and "id" not in page:
return "unknown"
return str(page["id"])
class ConfluenceCheckpoint(ConnectorCheckpoint):
next_page_url: str | None
@@ -299,7 +305,7 @@ class ConfluenceConnector(
page_id = page_url = ""
try:
# Extract basic page information
page_id = page["id"]
page_id = _get_page_id(page)
page_title = page["title"]
logger.info(f"Converting page {page_title} to document")
page_url = build_confluence_document_id(
@@ -382,7 +388,9 @@ class ConfluenceConnector(
this function. The returned documents/connectorfailures are for non-inline attachments
and those at the end of the page.
"""
attachment_query = self._construct_attachment_query(page["id"], start, end)
attachment_query = self._construct_attachment_query(
_get_page_id(page), start, end
)
attachment_failures: list[ConnectorFailure] = []
attachment_docs: list[Document] = []
page_url = ""
@@ -430,7 +438,7 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
page_id=_get_page_id(page),
allow_images=self.allow_images,
)
if response is None:
@@ -515,14 +523,21 @@ class ConfluenceConnector(
except HTTPError as e:
# If we get a 403 after all retries, the user likely doesn't have permission
# to access attachments on this page. Log and skip rather than failing the whole job.
if e.response and e.response.status_code == 403:
page_title = page.get("title", "unknown")
page_id = page.get("id", "unknown")
logger.warning(
f"Permission denied (403) when fetching attachments for page '{page_title}' "
page_id = _get_page_id(page, allow_missing=True)
page_title = page.get("title", "unknown")
if e.response and e.response.status_code in [401, 403]:
failure_message_prefix = (
"Invalid credentials (401)"
if e.response.status_code == 401
else "Permission denied (403)"
)
failure_message = (
f"{failure_message_prefix} when fetching attachments for page '{page_title}' "
f"(ID: {page_id}). The user may not have permission to query attachments on this page. "
"Skipping attachments for this page."
)
logger.warning(failure_message)
# Build the page URL for the failure record
try:
page_url = build_confluence_document_id(
@@ -537,7 +552,7 @@ class ConfluenceConnector(
document_id=page_id,
document_link=page_url,
),
failure_message=f"Permission denied (403) when fetching attachments for page '{page_title}'",
failure_message=failure_message,
exception=e,
)
]
@@ -708,7 +723,7 @@ class ConfluenceConnector(
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
page_id = page["id"]
page_id = _get_page_id(page)
page_restrictions = page.get("restrictions") or {}
page_space_key = page.get("space", {}).get("key")
page_ancestors = page.get("ancestors", [])
@@ -728,7 +743,7 @@ class ConfluenceConnector(
)
# Query attachments for each page
attachment_query = self._construct_attachment_query(page["id"])
attachment_query = self._construct_attachment_query(_get_page_id(page))
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,

View File

@@ -24,9 +24,9 @@ from onyx.configs.app_configs import (
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
@@ -56,15 +56,13 @@ def validate_attachment_filetype(
"""
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return is_valid_image_type(media_type)
return media_type in OnyxMimeTypes.IMAGE_MIME_TYPES
# For non-image files, check if we support the extension
title = attachment.get("title", "")
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
extension = get_file_ext(title)
return is_accepted_file_ext(
"." + extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
)
return extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS
class AttachmentProcessingResult(BaseModel):

View File

@@ -71,6 +71,13 @@ def time_str_to_utc(datetime_str: str) -> datetime:
raise ValueError(f"Unable to parse datetime string: {datetime_str}")
# TODO: use this function in other connectors
def datetime_from_utc_timestamp(timestamp: int) -> datetime:
"""Convert a Unix timestamp to a datetime object in UTC"""
return datetime.fromtimestamp(timestamp, tz=timezone.utc)
def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
if info.first_name and info.last_name:
return f"{info.first_name} {info.middle_initial} {info.last_name}"

View File

@@ -2,11 +2,11 @@ from datetime import timezone
from io import BytesIO
from typing import Any
from dropbox import Dropbox # type: ignore
from dropbox.exceptions import ApiError # type:ignore
from dropbox.exceptions import AuthError # type:ignore
from dropbox.files import FileMetadata # type:ignore
from dropbox.files import FolderMetadata # type:ignore
from dropbox import Dropbox # type: ignore[import-untyped]
from dropbox.exceptions import ApiError # type: ignore[import-untyped]
from dropbox.exceptions import AuthError
from dropbox.files import FileMetadata # type: ignore[import-untyped]
from dropbox.files import FolderMetadata
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource

View File

@@ -0,0 +1,907 @@
import mimetypes
from io import BytesIO
from typing import Any
import requests
from typing_extensions import override
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
datetime_from_utc_timestamp,
)
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests
from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint
from onyx.connectors.drupal_wiki.models import DrupalWikiPage
from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse
from onyx.connectors.drupal_wiki.models import DrupalWikiSpaceResponse
from onyx.connectors.drupal_wiki.utils import build_drupal_wiki_document_id
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
MAX_API_PAGE_SIZE = 2000 # max allowed by API
DRUPAL_WIKI_SPACE_KEY = "space"
rate_limited_get = retry_builder()(
rate_limit_builder(max_calls=10, period=1)(rl_requests.get)
)
class DrupalWikiConnector(
CheckpointedConnector[DrupalWikiCheckpoint],
SlimConnector,
):
def __init__(
self,
base_url: str,
spaces: list[str] | None = None,
pages: list[str] | None = None,
include_all_spaces: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
drupal_wiki_scope: str | None = None,
include_attachments: bool = False,
allow_images: bool = False,
) -> None:
"""
Initialize the Drupal Wiki connector.
Args:
base_url: The base URL of the Drupal Wiki instance (e.g., https://help.drupal-wiki.com)
spaces: List of space IDs to index. If None and include_all_spaces is False, no spaces will be indexed.
pages: List of page IDs to index. If provided, only these specific pages will be indexed.
include_all_spaces: If True, all spaces will be indexed regardless of the spaces parameter.
batch_size: Number of documents to process in a batch.
continue_on_failure: If True, continue indexing even if some documents fail.
drupal_wiki_scope: The selected tab value from the frontend. If "all_spaces", all spaces will be indexed.
include_attachments: If True, enable processing of page attachments including images and documents.
allow_images: If True, enable processing of image attachments.
"""
self.base_url = base_url.rstrip("/")
self.spaces = spaces or []
self.pages = pages or []
# Determine whether to include all spaces based on the selected tab
# If drupal_wiki_scope is "all_spaces", we should index all spaces
# If it's "specific_spaces", we should only index the specified spaces
# If it's None, we use the include_all_spaces parameter
if drupal_wiki_scope is not None:
logger.debug(f"drupal_wiki_scope is set to {drupal_wiki_scope}")
self.include_all_spaces = drupal_wiki_scope == "all_spaces"
# If scope is specific_spaces, include_all_spaces correctly defaults to False
else:
logger.debug(
f"drupal_wiki_scope is not set, using include_all_spaces={include_all_spaces}"
)
self.include_all_spaces = include_all_spaces
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
# Attachment processing configuration
self.include_attachments = include_attachments
self.allow_images = allow_images
self.headers: dict[str, str] = {"Accept": "application/json"}
self._api_token: str | None = None # set by load_credentials
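# Illustrative usage only (values are made up, not part of this change):
#
#   connector = DrupalWikiConnector(
#       base_url="https://help.drupal-wiki.com",
#       drupal_wiki_scope="specific_spaces",
#       spaces=["12", "34"],
#       include_attachments=True,
#   )
#   connector.load_credentials({"drupal_wiki_api_token": "<token>"})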
def set_allow_images(self, value: bool) -> None:
logger.info(f"Setting allow_images to {value}.")
self.allow_images = value
def _get_page_attachments(self, page_id: int) -> list[dict[str, Any]]:
"""
Get all attachments for a specific page.
Args:
page_id: ID of the page.
Returns:
List of attachment dictionaries.
"""
url = f"{self.base_url}/api/rest/scope/api/attachment"
params = {"pageId": str(page_id)}
logger.debug(f"Fetching attachments for page {page_id} from {url}")
try:
response = rate_limited_get(url, headers=self.headers, params=params)
response.raise_for_status()
attachments = response.json()
logger.info(f"Found {len(attachments)} attachments for page {page_id}")
return attachments
except Exception as e:
logger.warning(f"Failed to fetch attachments for page {page_id}: {e}")
return []
def _download_attachment(self, attachment_id: int) -> bytes:
"""
Download attachment content.
Args:
attachment_id: ID of the attachment to download.
Returns:
Raw bytes of the attachment.
"""
url = f"{self.base_url}/api/rest/scope/api/attachment/{attachment_id}/download"
logger.info(f"Downloading attachment {attachment_id} from {url}")
# Use headers without Accept for binary downloads
download_headers = {"Authorization": f"Bearer {self._api_token}"}
response = rate_limited_get(url, headers=download_headers)
response.raise_for_status()
return response.content
def _validate_attachment_filetype(self, attachment: dict[str, Any]) -> bool:
"""
Validate if the attachment file type is supported.
Args:
attachment: Attachment dictionary from Drupal Wiki API.
Returns:
True if the file type is supported, False otherwise.
"""
file_name = attachment.get("fileName", "")
if not file_name:
return False
# Get file extension
file_extension = get_file_ext(file_name)
if file_extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
return True
logger.warning(f"Unsupported file type: {file_extension} for {file_name}")
return False
def _get_media_type_from_filename(self, filename: str) -> str:
"""
Get media type from filename using the standard mimetypes library.
Args:
filename: The filename.
Returns:
Media type string.
"""
mime_type, _encoding = mimetypes.guess_type(filename)
return mime_type or "application/octet-stream"
def _process_attachment(
self,
attachment: dict[str, Any],
page_id: int,
download_url: str,
) -> tuple[list[TextSection | ImageSection], str | None]:
"""
Process a single attachment and return generated sections.
Args:
attachment: Attachment dictionary from Drupal Wiki API.
page_id: ID of the parent page.
download_url: Direct download URL for the attachment.
Returns:
Tuple of (sections, error_message). If error_message is not None, the
sections list should be treated as invalid.
"""
sections: list[TextSection | ImageSection] = []
try:
if not self._validate_attachment_filetype(attachment):
return (
[],
f"Unsupported file type: {attachment.get('fileName', 'unknown')}",
)
attachment_id = attachment["id"]
file_name = attachment.get("fileName", f"attachment_{attachment_id}")
file_size = attachment.get("fileSize", 0)
media_type = self._get_media_type_from_filename(file_name)
if file_size > DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD:
return [], f"Attachment too large: {file_size} bytes"
try:
raw_bytes = self._download_attachment(attachment_id)
except Exception as e:
return [], f"Failed to download attachment: {e}"
if media_type.startswith("image/"):
if not self.allow_images:
logger.info(
f"Skipping image attachment {file_name} because allow_images is False",
)
return [], None
try:
image_section, _ = store_image_and_create_section(
image_data=raw_bytes,
file_id=str(attachment_id),
display_name=attachment.get(
"name", attachment.get("fileName", "Unknown")
),
link=download_url,
media_type=media_type,
file_origin=FileOrigin.CONNECTOR,
)
sections.append(image_section)
logger.debug(f"Stored image attachment with file name: {file_name}")
except Exception as e:
return [], f"Image storage failed: {e}"
return sections, None
image_counter = 0
def _store_embedded_image(image_data: bytes, image_name: str) -> None:
nonlocal image_counter
if not self.allow_images:
return
media_for_image = self._get_media_type_from_filename(image_name)
if media_for_image == "application/octet-stream":
try:
media_for_image = get_image_type_from_bytes(image_data)
except ValueError:
logger.warning(
f"Unable to determine media type for embedded image {image_name} on attachment {file_name}"
)
image_counter += 1
display_name = (
image_name
or f"{attachment.get('name', file_name)} - embedded image {image_counter}"
)
try:
image_section, _ = store_image_and_create_section(
image_data=image_data,
file_id=f"{attachment_id}_embedded_{image_counter}",
display_name=display_name,
link=download_url,
media_type=media_for_image,
file_origin=FileOrigin.CONNECTOR,
)
sections.append(image_section)
except Exception as err:
logger.warning(
f"Failed to store embedded image {image_name or image_counter} for attachment {file_name}: {err}"
)
extraction_result = extract_text_and_images(
file=BytesIO(raw_bytes),
file_name=file_name,
content_type=media_type,
image_callback=_store_embedded_image if self.allow_images else None,
)
text_content = extraction_result.text_content.strip()
if text_content:
sections.insert(0, TextSection(text=text_content, link=download_url))
logger.info(
f"Extracted {len(text_content)} characters from {file_name}"
)
elif not sections:
return [], f"No text extracted for {file_name}"
return sections, None
except Exception as e:
logger.error(
f"Failed to process attachment {attachment.get('name', 'unknown')} on page {page_id}: {e}"
)
return [], f"Failed to process attachment: {e}"
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""
Load credentials for the Drupal Wiki connector.
Args:
credentials: Dictionary containing the API token.
Returns:
None
"""
api_token = credentials.get("drupal_wiki_api_token", "").strip()
if not api_token:
raise ConnectorValidationError(
"API token is required for Drupal Wiki connector"
)
self._api_token = api_token
self.headers.update(
{
"Authorization": f"Bearer {api_token}",
}
)
return None
def _get_space_ids(self) -> list[int]:
"""
Get all space IDs from the Drupal Wiki instance.
Returns:
List of space IDs (deduplicated). The list is sorted to be deterministic.
"""
url = f"{self.base_url}/api/rest/scope/api/space"
size = MAX_API_PAGE_SIZE
page = 0
all_space_ids: set[int] = set()
has_more = True
last_num_ids = -1
while has_more and len(all_space_ids) > last_num_ids:
last_num_ids = len(all_space_ids)
params = {"size": size, "page": page}
logger.debug(f"Fetching spaces from {url} (page={page}, size={size})")
response = rate_limited_get(url, headers=self.headers, params=params)
response.raise_for_status()
resp_json = response.json()
space_response = DrupalWikiSpaceResponse.model_validate(resp_json)
logger.info(f"Fetched {len(space_response.content)} spaces from {page}")
# Collect ids into the set to deduplicate
for space in space_response.content:
all_space_ids.add(space.id)
# Continue if we got a full page, indicating there might be more
has_more = len(space_response.content) >= size
page += 1
# Return a deterministic, sorted list of ids
space_id_list = list(sorted(all_space_ids))
logger.debug(f"Total spaces fetched: {len(space_id_list)}")
return space_id_list
def _get_pages_for_space(
self, space_id: int, modified_after: SecondsSinceUnixEpoch | None = None
) -> list[DrupalWikiPage]:
"""
Get all pages for a specific space, optionally filtered by modification time.
Args:
space_id: ID of the space.
modified_after: Only return pages modified after this timestamp (seconds since Unix epoch).
Returns:
List of DrupalWikiPage objects.
"""
url = f"{self.base_url}/api/rest/scope/api/page"
size = MAX_API_PAGE_SIZE
page = 0
all_pages = []
has_more = True
while has_more:
params: dict[str, str | int] = {
DRUPAL_WIKI_SPACE_KEY: str(space_id),
"size": size,
"page": page,
}
# Add modifiedAfter parameter if provided
if modified_after is not None:
params["modifiedAfter"] = int(modified_after)
logger.debug(
f"Fetching pages for space {space_id} from {url} ({page=}, {size=}, {modified_after=})"
)
response = rate_limited_get(url, headers=self.headers, params=params)
response.raise_for_status()
resp_json = response.json()
try:
page_response = DrupalWikiPageResponse.model_validate(resp_json)
except Exception as e:
logger.error(f"Failed to validate Drupal Wiki page response: {e}")
raise ConnectorValidationError(f"Invalid API response format: {e}")
logger.info(
f"Fetched {len(page_response.content)} pages in space {space_id} (page={page})"
)
# Pydantic should automatically parse content items as DrupalWikiPage objects
# If validation fails, it will raise an exception which we should catch
all_pages.extend(page_response.content)
# Continue if we got a full page, indicating there might be more
has_more = len(page_response.content) >= size
page += 1
logger.debug(f"Total pages fetched for space {space_id}: {len(all_pages)}")
return all_pages
def _get_page_content(self, page_id: int) -> DrupalWikiPage:
"""
Get the content of a specific page.
Args:
page_id: ID of the page.
Returns:
DrupalWikiPage object.
"""
url = f"{self.base_url}/api/rest/scope/api/page/{page_id}"
response = rate_limited_get(url, headers=self.headers)
response.raise_for_status()
return DrupalWikiPage.model_validate(response.json())
def _process_page(self, page: DrupalWikiPage) -> Document | ConnectorFailure:
"""
Process a page and convert it to a Document.
Args:
page: DrupalWikiPage object.
Returns:
Document object or ConnectorFailure.
"""
try:
# Extract text from HTML, handle None body
text_content = parse_html_page_basic(page.body or "")
# Ensure text_content is a string, not None
if text_content is None:
text_content = ""
# Create document URL
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
# Create sections with just the page content
sections: list[TextSection | ImageSection] = [
TextSection(text=text_content, link=page_url)
]
# Only process attachments if self.include_attachments is True
if self.include_attachments:
attachments = self._get_page_attachments(page.id)
for attachment in attachments:
logger.info(
f"Processing attachment: {attachment.get('name', 'Unknown')} (ID: {attachment['id']})"
)
# Use downloadUrl from API; fallback to page URL
raw_download = attachment.get("downloadUrl")
if raw_download:
download_url = (
raw_download
if raw_download.startswith("http")
else f"{self.base_url.rstrip('/')}" + raw_download
)
else:
download_url = page_url
# Process the attachment
attachment_sections, error = self._process_attachment(
attachment, page.id, download_url
)
if error:
logger.warning(
f"Error processing attachment {attachment.get('name', 'Unknown')}: {error}"
)
continue
if attachment_sections:
sections.extend(attachment_sections)
logger.debug(
f"Added {len(attachment_sections)} section(s) for attachment {attachment.get('name', 'Unknown')}"
)
# Create metadata
metadata: dict[str, str | list[str]] = {
"space_id": str(page.homeSpace),
"page_id": str(page.id),
"type": page.type,
}
# Create document
return Document(
id=page_url,
sections=sections,
source=DocumentSource.DRUPAL_WIKI,
semantic_identifier=page.title,
metadata=metadata,
doc_updated_at=datetime_from_utc_timestamp(page.lastModified),
)
except Exception as e:
logger.error(f"Error processing page {page.id}: {e}")
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(page.id),
document_link=build_drupal_wiki_document_id(self.base_url, page.id),
),
failure_message=f"Error processing page {page.id}: {e}",
exception=e,
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: DrupalWikiCheckpoint,
) -> CheckpointOutput[DrupalWikiCheckpoint]:
"""
Load documents from a checkpoint.
Args:
start: Start time as seconds since Unix epoch.
end: End time as seconds since Unix epoch.
checkpoint: Checkpoint to resume from.
Returns:
Generator yielding documents and the updated checkpoint.
"""
# Ensure page_ids is not None
if checkpoint.page_ids is None:
checkpoint.page_ids = []
# Initialize page_ids from self.pages if not already set
if not checkpoint.page_ids and self.pages:
logger.info(f"Initializing page_ids from self.pages: {self.pages}")
checkpoint.page_ids = [int(page_id.strip()) for page_id in self.pages]
# Ensure spaces is not None
if checkpoint.spaces is None:
checkpoint.spaces = []
while checkpoint.current_page_id_index < len(checkpoint.page_ids):
page_id = checkpoint.page_ids[checkpoint.current_page_id_index]
logger.debug(f"Processing page ID: {page_id}")
try:
# Get the page content directly
page = self._get_page_content(page_id)
# Skip pages outside the time range
if not self._is_page_in_time_range(page.lastModified, start, end):
logger.info(f"Skipping page {page_id} - outside time range")
checkpoint.current_page_id_index += 1
continue
# Process the page
doc_or_failure = self._process_page(page)
yield doc_or_failure
except Exception as e:
logger.error(f"Error processing page ID {page_id}: {e}")
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(page_id),
document_link=build_drupal_wiki_document_id(
self.base_url, page_id
),
),
failure_message=f"Error processing page ID {page_id}: {e}",
exception=e,
)
# Move to the next page ID
checkpoint.current_page_id_index += 1
# TODO: The main benefit of CheckpointedConnectors is that they can "save their work"
# by storing a checkpoint so transient errors are easy to recover from: simply resume
# from the last checkpoint. The way to get checkpoints saved is to return them somewhere
# in the middle of this function. The guarantee our checkpointing system gives to you,
# the connector implementer, is that when you return a checkpoint, this connector will
# at a later time (generally within a few seconds) call the load_from_checkpoint function
# again with the checkpoint you last returned as long as has_more=True.
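# For illustration only (not implemented here): a connector can "save its work" by
# returning the checkpoint partway through, e.g.
#
#   if pages_processed_this_run >= SOME_BATCH_LIMIT:  # hypothetical counter and limit
#       return checkpoint  # the framework re-invokes load_from_checkpoint with it later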
# Process spaces if include_all_spaces is True or spaces are provided
if self.include_all_spaces or self.spaces:
# If include_all_spaces is True, always fetch all spaces
if self.include_all_spaces:
logger.info("Fetching all spaces")
# Fetch all spaces
all_space_ids = self._get_space_ids()
# checkpoint.spaces expects a list of ints; assign returned list
checkpoint.spaces = all_space_ids
logger.info(f"Found {len(checkpoint.spaces)} spaces to process")
# Otherwise, use provided spaces if checkpoint is empty
elif not checkpoint.spaces:
logger.info(f"Using provided spaces: {self.spaces}")
# Use provided spaces
checkpoint.spaces = [int(space_id.strip()) for space_id in self.spaces]
# Process spaces from the checkpoint
while checkpoint.current_space_index < len(checkpoint.spaces):
space_id = checkpoint.spaces[checkpoint.current_space_index]
logger.debug(f"Processing space ID: {space_id}")
# Get pages for the current space, filtered by start time if provided
pages = self._get_pages_for_space(space_id, modified_after=start)
# Process pages from the checkpoint
while checkpoint.current_page_index < len(pages):
page = pages[checkpoint.current_page_index]
logger.debug(f"Processing page: {page.title} (ID: {page.id})")
# For space-based pages, we already filtered by modifiedAfter in the API call
# Only need to check the end time boundary
if end and page.lastModified >= end:
logger.info(
f"Skipping page {page.id} - outside time range (after end)"
)
checkpoint.current_page_index += 1
continue
# Process the page
doc_or_failure = self._process_page(page)
yield doc_or_failure
# Move to the next page
checkpoint.current_page_index += 1
# Move to the next space
checkpoint.current_space_index += 1
checkpoint.current_page_index = 0
# All spaces and pages processed
logger.info("Finished processing all spaces and pages")
checkpoint.has_more = False
return checkpoint
@override
def build_dummy_checkpoint(self) -> DrupalWikiCheckpoint:
"""
Build a dummy checkpoint.
Returns:
DrupalWikiCheckpoint with default values.
"""
return DrupalWikiCheckpoint(
has_more=True,
current_space_index=0,
current_page_index=0,
current_page_id_index=0,
spaces=[],
page_ids=[],
is_processing_specific_pages=False,
)
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> DrupalWikiCheckpoint:
"""
Validate a checkpoint JSON string.
Args:
checkpoint_json: JSON string representing a checkpoint.
Returns:
Validated DrupalWikiCheckpoint.
"""
return DrupalWikiCheckpoint.model_validate_json(checkpoint_json)
# TODO: unify approach with load_from_checkpoint.
# Ideally slim retrieval shares a lot of the same code with non-slim
# and we pass in a param is_slim to the main helper function
# that does the retrieval.
@override
def retrieve_all_slim_docs(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
Retrieve all slim documents.
Args:
start: Start time as seconds since Unix epoch.
end: End time as seconds since Unix epoch.
callback: Callback for indexing heartbeat.
Returns:
Generator yielding batches of SlimDocument objects.
"""
slim_docs: list[SlimDocument] = []
logger.info(
f"Starting retrieve_all_slim_docs with include_all_spaces={self.include_all_spaces}, spaces={self.spaces}"
)
# Process specific page IDs if provided
if self.pages:
logger.info(f"Processing specific pages: {self.pages}")
for page_id in self.pages:
try:
# Get the page content directly
page_content = self._get_page_content(int(page_id.strip()))
# Skip pages outside the time range
if not self._is_page_in_time_range(
page_content.lastModified, start, end
):
logger.info(f"Skipping page {page_id} - outside time range")
continue
# Create slim document for the page
page_url = build_drupal_wiki_document_id(
self.base_url, page_content.id
)
slim_docs.append(
SlimDocument(
id=page_url,
)
)
logger.debug(f"Added slim document for page {page_content.id}")
# Process attachments for this page
attachments = self._get_page_attachments(page_content.id)
for attachment in attachments:
if self._validate_attachment_filetype(attachment):
attachment_url = f"{page_url}#attachment-{attachment['id']}"
slim_docs.append(
SlimDocument(
id=attachment_url,
)
)
logger.debug(
f"Added slim document for attachment {attachment['id']}"
)
# Yield batch if it reaches the batch size
if len(slim_docs) >= self.batch_size:
logger.debug(
f"Yielding batch of {len(slim_docs)} slim documents"
)
yield slim_docs
slim_docs = []
if callback and callback.should_stop():
return
if callback:
callback.progress("retrieve_all_slim_docs", 1)
except Exception as e:
logger.error(
f"Error processing page ID {page_id} for slim documents: {e}"
)
# Process spaces if include_all_spaces is True or spaces are provided
if self.include_all_spaces or self.spaces:
logger.info("Processing spaces for slim documents")
# Get spaces to process
spaces_to_process = []
if self.include_all_spaces:
logger.info("Fetching all spaces for slim documents")
# Fetch all spaces
all_space_ids = self._get_space_ids()
spaces_to_process = all_space_ids
logger.info(f"Found {len(spaces_to_process)} spaces to process")
else:
logger.info(f"Using provided spaces: {self.spaces}")
# Use provided spaces
spaces_to_process = [int(space_id.strip()) for space_id in self.spaces]
# Process each space
for space_id in spaces_to_process:
logger.info(f"Processing space ID: {space_id}")
# Get pages for the current space, filtered by start time if provided
pages = self._get_pages_for_space(space_id, modified_after=start)
# Process each page
for page in pages:
logger.debug(f"Processing page: {page.title} (ID: {page.id})")
# Skip pages outside the time range
if end and page.lastModified >= end:
logger.info(
f"Skipping page {page.id} - outside time range (after end)"
)
continue
# Create slim document for the page
page_url = build_drupal_wiki_document_id(self.base_url, page.id)
slim_docs.append(
SlimDocument(
id=page_url,
)
)
logger.info(f"Added slim document for page {page.id}")
# Process attachments for this page
attachments = self._get_page_attachments(page.id)
for attachment in attachments:
if self._validate_attachment_filetype(attachment):
attachment_url = f"{page_url}#attachment-{attachment['id']}"
slim_docs.append(
SlimDocument(
id=attachment_url,
)
)
logger.info(
f"Added slim document for attachment {attachment['id']}"
)
# Yield batch if it reaches the batch size
if len(slim_docs) >= self.batch_size:
logger.info(
f"Yielding batch of {len(slim_docs)} slim documents"
)
yield slim_docs
slim_docs = []
if callback and callback.should_stop():
return
if callback:
callback.progress("retrieve_all_slim_docs", 1)
# Yield remaining documents
if slim_docs:
logger.debug(f"Yielding final batch of {len(slim_docs)} slim documents")
yield slim_docs
def validate_connector_settings(self) -> None:
"""
Validate the connector settings.
Raises:
ConnectorValidationError: If the settings are invalid.
"""
if not self.headers:
raise ConnectorMissingCredentialError("Drupal Wiki")
try:
# Try to fetch spaces to validate the connection
# Call the new helper which returns the list of space ids
self._get_space_ids()
except requests.exceptions.RequestException as e:
raise ConnectorValidationError(f"Failed to connect to Drupal Wiki: {e}")
def _is_page_in_time_range(
self,
last_modified: int,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
) -> bool:
"""
Check if a page's last modified timestamp falls within the specified time range.
Args:
last_modified: The page's last modified timestamp.
start: Start time as seconds since Unix epoch (inclusive).
end: End time as seconds since Unix epoch (exclusive).
Returns:
True if the page is within the time range, False otherwise.
"""
return (not start or last_modified >= start) and (
not end or last_modified < end
)
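# Illustrative boundary check (a sketch, not from the diff): with the helper above, the
# start bound is inclusive and the end bound is exclusive, so a page modified exactly at
# `start` is kept while one modified exactly at `end` is skipped. The connector instance
# and timestamps below are assumptions.
#
#   connector._is_page_in_time_range(last_modified=100, start=100, end=200)  # True
#   connector._is_page_in_time_range(last_modified=200, start=100, end=200)  # False
#   connector._is_page_in_time_range(last_modified=150, start=None, end=None)  # True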

View File

@@ -0,0 +1,75 @@
from enum import Enum
from typing import Generic
from typing import List
from typing import Optional
from typing import TypeVar
from pydantic import BaseModel
from onyx.connectors.interfaces import ConnectorCheckpoint
class SpaceAccessStatus(str, Enum):
"""Enum for Drupal Wiki space access status"""
PRIVATE = "PRIVATE"
ANONYMOUS = "ANONYMOUS"
AUTHENTICATED = "AUTHENTICATED"
class DrupalWikiSpace(BaseModel):
"""Model for a Drupal Wiki space"""
id: int
name: str
type: str
description: Optional[str] = None
accessStatus: Optional[SpaceAccessStatus] = None
color: Optional[str] = None
class DrupalWikiPage(BaseModel):
"""Model for a Drupal Wiki page"""
id: int
title: str
homeSpace: int
lastModified: int
type: str
body: Optional[str] = None
T = TypeVar("T")
class DrupalWikiBaseResponse(BaseModel, Generic[T]):
"""Base model for Drupal Wiki API responses"""
totalPages: int
totalElements: int
size: int
content: List[T]
number: int
first: bool
last: bool
numberOfElements: int
empty: bool
class DrupalWikiSpaceResponse(DrupalWikiBaseResponse[DrupalWikiSpace]):
"""Model for the response from the Drupal Wiki spaces API"""
class DrupalWikiPageResponse(DrupalWikiBaseResponse[DrupalWikiPage]):
"""Model for the response from the Drupal Wiki pages API"""
class DrupalWikiCheckpoint(ConnectorCheckpoint):
"""Checkpoint for the Drupal Wiki connector"""
current_space_index: int = 0
current_page_index: int = 0
current_page_id_index: int = 0
spaces: List[int] = []
page_ids: List[int] = []
is_processing_specific_pages: bool = False
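# A minimal sketch (assumptions noted, not the actual Onyx indexing loop) of how the
# checkpoint contract described in load_from_checkpoint can be exercised: drive the
# generator, collect yielded documents/failures, read the returned checkpoint from
# StopIteration, round-trip it through JSON, and call the connector again while
# has_more is True. `connector`, `start`, `end`, and `handle` are placeholders.
#
# checkpoint = connector.build_dummy_checkpoint()
# while checkpoint.has_more:
#     gen = connector.load_from_checkpoint(start, end, checkpoint)
#     while True:
#         try:
#             doc_or_failure = next(gen)  # Document or ConnectorFailure
#             handle(doc_or_failure)
#         except StopIteration as stop:
#             checkpoint = stop.value  # the generator's `return checkpoint`
#             break
#     # Simulate persistence between runs: serialize, then re-validate the checkpoint.
#     checkpoint = connector.validate_checkpoint_json(checkpoint.model_dump_json())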

View File

@@ -0,0 +1,10 @@
from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_drupal_wiki_document_id(base_url: str, page_id: int) -> str:
"""Build a document ID for a Drupal Wiki page using the real URL format"""
# Ensure base_url ends with a slash
base_url = base_url.rstrip("/") + "/"
return f"{base_url}node/{page_id}"

View File

@@ -28,10 +28,8 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -70,14 +68,15 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
# Explicitly excluding image extensions here. TODO: consider allowing images
if extension not in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None
# Extract text content based on file type
if is_text_file_extension(file_name):
# TODO @wenxi-onyx: convert to extract_text_and_images
if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS:
encoding = detect_encoding(file_content)
file_content_raw, file_metadata = read_text_file(
file_content, encoding=encoding, ignore_onyx_metadata=False

View File

@@ -18,8 +18,7 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -90,7 +89,7 @@ def _process_file(
# Get file extension and determine file type
extension = get_file_ext(file_name)
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
if extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
@@ -111,7 +110,7 @@ def _process_file(
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
if extension in LoadConnector.IMAGE_EXTENSIONS:
if extension in OnyxFileExtensions.IMAGE_EXTENSIONS:
# Read the image data
image_data = file.read()
if not image_data:

View File

@@ -5,8 +5,8 @@ from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError # type: ignore
from onyx.access.models import ExternalAccess

View File

@@ -14,9 +14,9 @@ from typing import cast
from typing import Protocol
from urllib.parse import urlparse
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google.auth.exceptions import RefreshError
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
@@ -1006,7 +1006,7 @@ class GoogleDriveConnector(
file.user_email,
)
if file.error is None:
file.error = exc # type: ignore[assignment]
file.error = exc
yield file
continue

View File

@@ -29,14 +29,14 @@ from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_docx_file
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import (
@@ -114,14 +114,6 @@ def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
return urlunparse(parsed_url)
def is_gdrive_image_mime_type(mime_type: str) -> bool:
"""
Return True if the mime_type is a common image type in GDrive.
(e.g. 'image/png', 'image/jpeg')
"""
return is_valid_image_type(mime_type)
def download_request(
service: GoogleDriveService, file_id: str, size_threshold: int
) -> bytes:
@@ -173,7 +165,7 @@ def _download_and_extract_sections_basic(
def response_call() -> bytes:
return download_request(service, file_id, size_threshold)
if is_gdrive_image_mime_type(mime_type):
if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES:
# Skip images if not explicitly enabled
if not allow_images:
return []
@@ -222,7 +214,7 @@ def _download_and_extract_sections_basic(
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
text, _ = read_docx_file(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
@@ -260,7 +252,7 @@ def _download_and_extract_sections_basic(
# Final attempt at extracting text
file_ext = get_file_ext(file.get("name", ""))
if file_ext not in ALL_ACCEPTED_FILE_EXTENSIONS:
if file_ext not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
logger.warning(f"Skipping file {file.get('name')} due to extension.")
return []

View File

@@ -1,9 +1,9 @@
import json
from typing import Any
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

View File

@@ -4,7 +4,7 @@ from urllib.parse import parse_qs
from urllib.parse import ParseResult
from urllib.parse import urlparse
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
@@ -179,7 +179,7 @@ def get_auth_url(credential_id: int, source: DocumentSource) -> str:
get_kv_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
)
return str(auth_url)

View File

@@ -1,11 +1,11 @@
from collections.abc import Callable
from typing import Any
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from google.auth.exceptions import RefreshError
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.discovery import build # type: ignore[import-untyped]
from googleapiclient.discovery import Resource
from onyx.utils.logger import setup_logger

View File

@@ -23,9 +23,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -309,10 +308,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
elif (
is_valid_format
and (
file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
or file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
)
and file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS
and can_download
):
content_response = self.client.get_item_content(item_id)

View File

@@ -27,8 +27,6 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

View File

@@ -1,4 +1,5 @@
import copy
import json
import os
from collections.abc import Callable
from collections.abc import Iterable
@@ -9,6 +10,7 @@ from datetime import timezone
from typing import Any
from jira import JIRA
from jira.exceptions import JIRAError
from jira.resources import Issue
from more_itertools import chunked
from typing_extensions import override
@@ -134,6 +136,80 @@ def _perform_jql_search(
return _perform_jql_search_v2(jira_client, jql, start, max_results, fields)
def _handle_jira_search_error(e: Exception, jql: str) -> None:
"""Handle common Jira search errors and raise appropriate exceptions.
Args:
e: The exception raised by the Jira API
jql: The JQL query that caused the error
Raises:
ConnectorValidationError: For HTTP 400 errors (invalid JQL or project)
CredentialExpiredError: For HTTP 401 errors
InsufficientPermissionsError: For HTTP 403 errors
Exception: Re-raises the original exception for other error types
"""
# Extract error information from the exception
error_text = ""
status_code = None
def _format_error_text(error_payload: Any) -> str:
error_messages = (
error_payload.get("errorMessages", [])
if isinstance(error_payload, dict)
else []
)
if error_messages:
return (
"; ".join(error_messages)
if isinstance(error_messages, list)
else str(error_messages)
)
return str(error_payload)
# Try to get status code and error text from JIRAError or requests response
if hasattr(e, "status_code"):
status_code = e.status_code
raw_text = getattr(e, "text", "")
if isinstance(raw_text, str):
try:
error_text = _format_error_text(json.loads(raw_text))
except Exception:
error_text = raw_text
else:
error_text = str(raw_text)
elif hasattr(e, "response") and e.response is not None:
status_code = e.response.status_code
# Try JSON first, fall back to text
try:
error_json = e.response.json()
error_text = _format_error_text(error_json)
except Exception:
error_text = e.response.text
# Handle specific status codes
if status_code == 400:
if "does not exist for the field 'project'" in error_text:
raise ConnectorValidationError(
f"The specified Jira project does not exist or you don't have access to it. "
f"JQL query: {jql}. Error: {error_text}"
)
raise ConnectorValidationError(
f"Invalid JQL query. JQL: {jql}. Error: {error_text}"
)
elif status_code == 401:
raise CredentialExpiredError(
"Jira credentials are expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
f"Insufficient permissions to execute JQL query. JQL: {jql}"
)
# Re-raise for other error types
raise e
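# Illustrative sketch of the mapping performed by the handler above. The stand-in error
# object is an assumption (not a real JIRAError, which the type hint expects); it only
# needs `status_code` and `text`, since the handler reads both via hasattr/getattr.
#
# from types import SimpleNamespace
#
# bad_jql = SimpleNamespace(
#     status_code=400,
#     text='{"errorMessages": ["Field \'foo\' does not exist."]}',
# )
# _handle_jira_search_error(bad_jql, jql="foo = bar")        # raises ConnectorValidationError
#
# expired = SimpleNamespace(status_code=401, text="")
# _handle_jira_search_error(expired, jql="project = ENG")    # raises CredentialExpiredError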
def enhanced_search_ids(
jira_client: JIRA, jql: str, nextPageToken: str | None = None
) -> tuple[list[str], str | None]:
@@ -149,8 +225,15 @@ def enhanced_search_ids(
"nextPageToken": nextPageToken,
"fields": "id",
}
response = jira_client._session.get(enhanced_search_path, params=params).json()
return [str(issue["id"]) for issue in response["issues"]], response.get(
try:
response = jira_client._session.get(enhanced_search_path, params=params)
response.raise_for_status()
response_json = response.json()
except Exception as e:
_handle_jira_search_error(e, jql)
raise # Explicitly re-raise for type checker, should never reach here
return [str(issue["id"]) for issue in response_json["issues"]], response_json.get(
"nextPageToken"
)
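# Illustrative pagination sketch: enhanced_search_ids returns one page of issue ids plus
# an opaque nextPageToken, so a caller loops until the token runs out. The `jira_client`
# and JQL string are assumptions.
#
# all_ids: list[str] = []
# token: str | None = None
# while True:
#     ids, token = enhanced_search_ids(jira_client, jql="project = ENG", nextPageToken=token)
#     all_ids.extend(ids)
#     if not token:
#         break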
@@ -232,12 +315,16 @@ def _perform_jql_search_v2(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
try:
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
except JIRAError as e:
_handle_jira_search_error(e, jql)
raise # Explicitly re-raise for type checker, should never reach here
for issue in issues:
if isinstance(issue, Issue):

View File

@@ -10,7 +10,7 @@ from urllib.parse import urlparse
from urllib.parse import urlunparse
from pywikibot import family # type: ignore[import-untyped]
from pywikibot import pagegenerators # type: ignore[import-untyped]
from pywikibot import pagegenerators
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]

View File

@@ -10,8 +10,8 @@ from typing import cast
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
from pywikibot import pagegenerators # type: ignore[import-untyped]
from pywikibot import textlib # type: ignore[import-untyped]
from pywikibot import pagegenerators
from pywikibot import textlib
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource

View File

@@ -196,6 +196,10 @@ CONNECTOR_CLASS_MAP = {
module_path="onyx.connectors.highspot.connector",
class_name="HighspotConnector",
),
DocumentSource.DRUPAL_WIKI: ConnectorMapping(
module_path="onyx.connectors.drupal_wiki.connector",
class_name="DrupalWikiConnector",
),
DocumentSource.IMAP: ConnectorMapping(
module_path="onyx.connectors.imap.connector",
class_name="ImapConnector",

View File

@@ -55,12 +55,10 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.file_validation import EXCLUDED_IMAGE_TYPES
from onyx.file_processing.file_types import OnyxFileExtensions
from onyx.file_processing.file_types import OnyxMimeTypes
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
@@ -328,7 +326,7 @@ def _convert_driveitem_to_document_with_permissions(
try:
item_json = driveitem.to_json()
mime_type = item_json.get("file", {}).get("mimeType")
if not mime_type or mime_type in EXCLUDED_IMAGE_TYPES:
if not mime_type or mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
# NOTE: this function should be refactored to look like Drive doc_conversion.py pattern
# for now, this skip must happen before we download the file
# Similar to Google Drive, we'll just semi-silently skip excluded image types
@@ -388,14 +386,14 @@ def _convert_driveitem_to_document_with_permissions(
return None
sections: list[TextSection | ImageSection] = []
file_ext = driveitem.name.split(".")[-1]
file_ext = get_file_ext(driveitem.name)
if not content_bytes:
logger.warning(
f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
)
elif "." + file_ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
# NOTE: this if should use is_valid_image_type instead with mime_type
elif file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
# NOTE: this if should probably check mime_type instead
image_section, _ = store_image_and_create_section(
image_data=content_bytes,
file_id=driveitem.id,
@@ -418,7 +416,7 @@ def _convert_driveitem_to_document_with_permissions(
# The only mime type that would be returned by get_image_type_from_bytes that is in
# EXCLUDED_IMAGE_TYPES is image/gif.
if mime_type in EXCLUDED_IMAGE_TYPES:
if mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
logger.debug(
"Skipping embedded image of excluded type %s for %s",
mime_type,
@@ -1506,7 +1504,7 @@ class SharepointConnector(
)
for driveitem in driveitems:
driveitem_extension = get_file_ext(driveitem.name)
if not is_accepted_file_ext(driveitem_extension, OnyxExtensionType.All):
if driveitem_extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS:
logger.warning(
f"Skipping {driveitem.web_url} as it is not a supported file type"
)
@@ -1514,7 +1512,7 @@ class SharepointConnector(
# Only yield empty documents if they are PDFs or images
should_yield_if_empty = (
driveitem_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS
driveitem_extension in OnyxFileExtensions.IMAGE_EXTENSIONS
or driveitem_extension == ".pdf"
)

View File

@@ -18,7 +18,7 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -50,7 +50,7 @@ class TeamsCheckpoint(ConnectorCheckpoint):
class TeamsConnector(
CheckpointedConnector[TeamsCheckpoint],
CheckpointedConnectorWithPermSync[TeamsCheckpoint],
SlimConnectorWithPermSync,
):
MAX_WORKERS = 10
@@ -247,13 +247,23 @@ class TeamsConnector(
has_more=bool(todos),
)
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: TeamsCheckpoint,
) -> CheckpointOutput[TeamsCheckpoint]:
# Teams already fetches external_access (permissions) for each document
# in _convert_thread_to_document, so we can just delegate to load_from_checkpoint
return self.load_from_checkpoint(start, end, checkpoint)
# impls for SlimConnectorWithPermSync
def retrieve_all_slim_docs_perm_sync(
self,
start: SecondsSinceUnixEpoch | None = None,
_end: SecondsSinceUnixEpoch | None = None,
_callback: IndexingHeartbeatInterface | None = None,
end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
start = start or 0
@@ -302,6 +312,12 @@ class TeamsConnector(
)
if len(slim_doc_buffer) >= _SLIM_DOC_BATCH_SIZE:
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_docs_perm_sync: Stop signal detected"
)
callback.progress("retrieve_all_slim_docs_perm_sync", 1)
yield slim_doc_buffer
slim_doc_buffer = []

View File

@@ -4,9 +4,9 @@ from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from office365.graph_client import GraphClient # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.channels.channel import ConversationMember # type: ignore
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.teams.channels.channel import Channel # type: ignore[import-untyped]
from office365.teams.channels.channel import ConversationMember
from onyx.access.models import ExternalAccess
from onyx.connectors.interfaces import SecondsSinceUnixEpoch

View File

@@ -18,6 +18,7 @@ from oauthlib.oauth2 import BackendApplicationClient
from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from playwright.sync_api import TimeoutError
from requests_oauthlib import OAuth2Session # type:ignore
from urllib3.exceptions import MaxRetryError
@@ -86,6 +87,8 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
# Grace period after page navigation to allow bot-detection challenges to complete
BOT_DETECTION_GRACE_PERIOD_MS = 5000
# Define common headers that mimic a real browser
DEFAULT_USER_AGENT = (
@@ -554,12 +557,17 @@ class WebConnector(LoadConnector):
page = session_ctx.playwright_context.new_page()
try:
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
# Use "commit" instead of "domcontentloaded" to avoid hanging on bot-detection pages
# that may never fire domcontentloaded. "commit" waits only for navigation to be
# committed (response received), then we add a short wait for initial rendering.
page_response = page.goto(
initial_url,
timeout=30000, # 30 seconds
wait_until="domcontentloaded", # Wait for DOM to be ready
wait_until="commit", # Wait for navigation to commit
)
# Give the page a moment to start rendering after navigation commits.
# Allows CloudFlare and other bot-detection challenges to complete.
page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
last_modified = (
page_response.header_value("Last-Modified") if page_response else None
@@ -584,8 +592,15 @@ class WebConnector(LoadConnector):
previous_height = page.evaluate("document.body.scrollHeight")
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
# wait for the content to load if we scrolled
page.wait_for_load_state("networkidle", timeout=30000)
# Wait for content to load, but catch timeout if page never reaches networkidle
# (e.g., CloudFlare protection keeps making requests)
try:
page.wait_for_load_state(
"networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
)
except TimeoutError:
# If networkidle times out, just give it a moment for content to render
time.sleep(1)
time.sleep(0.5) # let javascript run
new_height = page.evaluate("document.body.scrollHeight")

View File

@@ -21,6 +21,13 @@ class OptionalSearchSetting(str, Enum):
class QueryType(str, Enum):
"""
The type of first-pass query to use for hybrid search.
The values of this enum are injected into the ranking profile name which
should match the name in the schema.
"""
KEYWORD = "keyword"
SEMANTIC = "semantic"
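# Hedged illustration: because these enum values are injected into a ranking profile
# name, a caller composes the name from `.value`. The string format below is purely
# hypothetical; the real profile-name pattern is defined by the Vespa schema.
#
# profile = f"hybrid_search_{QueryType.SEMANTIC.value}"  # hypothetical -> "hybrid_search_semantic"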

View File

@@ -13,6 +13,8 @@ from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_SLACK_THREAD_CONTEXT_MESSAGES
from onyx.configs.app_configs import SLACK_THREAD_CONTEXT_BATCH_SIZE
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import TextSection
@@ -39,7 +41,7 @@ from onyx.federated_connectors.slack.models import SlackEntities
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_default_llm
from onyx.onyxbot.slack.models import ChannelType
from onyx.onyxbot.slack.models import SlackContext
from onyx.redis.redis_pool import get_redis_client
@@ -623,33 +625,55 @@ def merge_slack_messages(
return merged_messages, docid_to_message, all_filtered_channels
def get_contextualized_thread_text(
class SlackRateLimitError(Exception):
"""Raised when Slack API returns a rate limit error (429)."""
class ThreadContextResult:
"""Result wrapper for thread context fetch that captures error type."""
__slots__ = ("text", "is_rate_limited", "is_error")
def __init__(
self, text: str, is_rate_limited: bool = False, is_error: bool = False
):
self.text = text
self.is_rate_limited = is_rate_limited
self.is_error = is_error
@classmethod
def success(cls, text: str) -> "ThreadContextResult":
return cls(text)
@classmethod
def rate_limited(cls, original_text: str) -> "ThreadContextResult":
return cls(original_text, is_rate_limited=True)
@classmethod
def error(cls, original_text: str) -> "ThreadContextResult":
return cls(original_text, is_error=True)
def _fetch_thread_context(
message: SlackMessage, access_token: str, team_id: str | None = None
) -> str:
) -> ThreadContextResult:
"""
Retrieves the initial thread message as well as the text following the message
and combines them into a single string. If the slack query fails, returns the
original message text.
Fetch thread context for a message, returning a result object.
The idea is that the message (the one that actually matched the search), the
initial thread message, and the replies to the message are important in answering
the user's query.
Args:
message: The SlackMessage to get context for
access_token: Slack OAuth access token
team_id: Slack team ID for caching user profiles (optional but recommended)
Returns ThreadContextResult with:
- success: enriched thread text
- rate_limited: original text + flag indicating we should stop
- error: original text for other failures (graceful degradation)
"""
channel_id = message.channel_id
thread_id = message.thread_id
message_id = message.message_id
# if it's not a thread, return the message text
# If not a thread, return original text as success
if thread_id is None:
return message.text
return ThreadContextResult.success(message.text)
# get the thread messages
slack_client = WebClient(token=access_token)
slack_client = WebClient(token=access_token, timeout=30)
try:
response = slack_client.conversations_replies(
channel=channel_id,
@@ -658,19 +682,44 @@ def get_contextualized_thread_text(
response.validate()
messages: list[dict[str, Any]] = response.get("messages", [])
except SlackApiError as e:
logger.error(f"Slack API error in get_contextualized_thread_text: {e}")
return message.text
# Check for rate limit error specifically
if e.response and e.response.status_code == 429:
logger.warning(
f"Slack rate limit hit while fetching thread context for {channel_id}/{thread_id}"
)
return ThreadContextResult.rate_limited(message.text)
# For other Slack errors, log and return original text
logger.error(f"Slack API error in thread context fetch: {e}")
return ThreadContextResult.error(message.text)
except Exception as e:
# Network errors, timeouts, etc - treat as recoverable error
logger.error(f"Unexpected error in thread context fetch: {e}")
return ThreadContextResult.error(message.text)
# make sure we didn't get an empty response or a single message (not a thread)
# If empty response or single message (not a thread), return original text
if len(messages) <= 1:
return message.text
return ThreadContextResult.success(message.text)
# add the initial thread message
# Build thread text from thread starter + context window around matched message
thread_text = _build_thread_text(
messages, message_id, thread_id, access_token, team_id, slack_client
)
return ThreadContextResult.success(thread_text)
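# Illustrative use of the result wrapper above: the classmethod constructors make the
# caller's branching explicit without raising exceptions.
#
# ok = ThreadContextResult.success("thread text")          # ok.is_rate_limited is False, ok.is_error is False
# limited = ThreadContextResult.rate_limited("orig text")  # limited.is_rate_limited is True
# failed = ThreadContextResult.error("orig text")          # failed.is_error is True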
def _build_thread_text(
messages: list[dict[str, Any]],
message_id: str,
thread_id: str,
access_token: str,
team_id: str | None,
slack_client: WebClient,
) -> str:
"""Build the thread text from messages."""
msg_text = messages[0].get("text", "")
msg_sender = messages[0].get("user", "")
thread_text = f"<@{msg_sender}>: {msg_text}"
# add the message (unless it's the initial message)
thread_text += "\n\nReplies:"
if thread_id == message_id:
message_id_idx = 0
@@ -681,28 +730,21 @@ def get_contextualized_thread_text(
if not message_id_idx:
return thread_text
# Include a few messages BEFORE the matched message for context
# This helps understand what the matched message is responding to
start_idx = max(
1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW
) # Start after thread starter
start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW)
# Add ellipsis if we're skipping messages between thread starter and context window
if start_idx > 1:
thread_text += "\n..."
# Add context messages before the matched message
for i in range(start_idx, message_id_idx):
msg_text = messages[i].get("text", "")
msg_sender = messages[i].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# Add the matched message itself
msg_text = messages[message_id_idx].get("text", "")
msg_sender = messages[message_id_idx].get("user", "")
thread_text += f"\n\n<@{msg_sender}>: {msg_text}"
# add the following replies to the thread text
# Add following replies
len_replies = 0
for msg in messages[message_id_idx + 1 :]:
msg_text = msg.get("text", "")
@@ -710,22 +752,19 @@ def get_contextualized_thread_text(
reply = f"\n\n<@{msg_sender}>: {msg_text}"
thread_text += reply
# stop if len_replies exceeds chunk_size * 4 chars as the rest likely won't fit
len_replies += len(reply)
if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4:
thread_text += "\n..."
break
# replace user ids with names in the thread text using cached lookups
# Replace user IDs with names using cached lookups
userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text))
if team_id:
# Use cached batch lookup when team_id is available
user_profiles = batch_get_user_profiles(access_token, team_id, userids)
for userid, name in user_profiles.items():
thread_text = thread_text.replace(f"<@{userid}>", name)
else:
# Fallback to individual lookups (no caching) when team_id not available
for userid in userids:
try:
response = slack_client.users_profile_get(user=userid)
@@ -735,7 +774,7 @@ def get_contextualized_thread_text(
except SlackApiError as e:
if "user_not_found" in str(e):
logger.debug(
f"User {userid} not found in Slack workspace (likely deleted/deactivated)"
f"User {userid} not found (likely deleted/deactivated)"
)
else:
logger.warning(f"Could not fetch profile for user {userid}: {e}")
@@ -747,6 +786,115 @@ def get_contextualized_thread_text(
return thread_text
def fetch_thread_contexts_with_rate_limit_handling(
slack_messages: list[SlackMessage],
access_token: str,
team_id: str | None,
batch_size: int = SLACK_THREAD_CONTEXT_BATCH_SIZE,
max_messages: int | None = MAX_SLACK_THREAD_CONTEXT_MESSAGES,
) -> list[str]:
"""
Fetch thread contexts in controlled batches, stopping on rate limit.
Distinguishes between error types:
- Rate limit (429): Stop processing further batches
- Other errors: Continue processing (graceful degradation)
Args:
slack_messages: Messages to fetch thread context for (should be sorted by relevance)
access_token: Slack OAuth token
team_id: Slack team ID for user profile caching
batch_size: Number of concurrent API calls per batch
max_messages: Maximum messages to fetch thread context for (None = no limit)
Returns:
List of thread texts, one per input message.
Messages beyond max_messages or after rate limit get their original text.
"""
if not slack_messages:
return []
# Limit how many messages we fetch thread context for (if max_messages is set)
if max_messages and max_messages < len(slack_messages):
messages_for_context = slack_messages[:max_messages]
messages_without_context = slack_messages[max_messages:]
else:
messages_for_context = slack_messages
messages_without_context = []
logger.info(
f"Fetching thread context for {len(messages_for_context)} of {len(slack_messages)} messages "
f"(batch_size={batch_size}, max={max_messages or 'unlimited'})"
)
results: list[str] = []
rate_limited = False
total_batches = (len(messages_for_context) + batch_size - 1) // batch_size
rate_limit_batch = 0
# Process in batches
for i in range(0, len(messages_for_context), batch_size):
current_batch = i // batch_size + 1
if rate_limited:
# Skip remaining batches, use original message text
remaining = messages_for_context[i:]
skipped_batches = total_batches - rate_limit_batch
logger.warning(
f"Slack rate limit: skipping {len(remaining)} remaining messages "
f"({skipped_batches} of {total_batches} batches). "
f"Successfully enriched {len(results)} messages before rate limit."
)
results.extend([msg.text for msg in remaining])
break
batch = messages_for_context[i : i + batch_size]
# _fetch_thread_context returns ThreadContextResult (never raises)
# allow_failures=True is a safety net for any unexpected exceptions
batch_results: list[ThreadContextResult | None] = (
run_functions_tuples_in_parallel(
[
(
_fetch_thread_context,
(msg, access_token, team_id),
)
for msg in batch
],
allow_failures=True,
max_workers=batch_size,
)
)
# Process results - ThreadContextResult tells us exactly what happened
for j, result in enumerate(batch_results):
if result is None:
# Unexpected exception (shouldn't happen) - use original text, stop
logger.error(f"Unexpected None result for message {j} in batch")
results.append(batch[j].text)
rate_limited = True
rate_limit_batch = current_batch
elif result.is_rate_limited:
# Rate limit hit - use original text, stop further batches
results.append(result.text)
rate_limited = True
rate_limit_batch = current_batch
else:
# Success or recoverable error - use the text (enriched or original)
results.append(result.text)
if rate_limited:
logger.warning(
f"Slack rate limit (429) hit at batch {current_batch}/{total_batches} "
f"while fetching thread context. Stopping further API calls."
)
# Add original text for messages we didn't fetch context for
results.extend([msg.text for msg in messages_without_context])
return results
def convert_slack_score(slack_score: float) -> float:
"""
Convert slack score to a score between 0 and 1.
@@ -830,8 +978,8 @@ def slack_retrieval(
)
# Query slack with entity filtering
_, fast_llm = get_default_llms()
query_strings = build_slack_queries(query, fast_llm, entities, available_channels)
llm = get_default_llm()
query_strings = build_slack_queries(query, llm, entities, available_channels)
# Determine filtering based on entities OR context (bot)
include_dm = False
@@ -964,11 +1112,12 @@ def slack_retrieval(
if not slack_messages:
return []
thread_texts: list[str] = run_functions_tuples_in_parallel(
[
(get_contextualized_thread_text, (slack_message, access_token, team_id))
for slack_message in slack_messages
]
# Fetch thread context with rate limit handling and message limiting
# Messages are already sorted by relevance (slack_score), so top N get full context
thread_texts = fetch_thread_contexts_with_rate_limit_handling(
slack_messages=slack_messages,
access_token=access_token,
team_id=team_id,
)
for slack_message, thread_text in zip(slack_messages, thread_texts):
slack_message.text = thread_text

View File

@@ -90,6 +90,16 @@ def _build_index_filters(
if not source_filter and detected_source_filter:
source_filter = detected_source_filter
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
# source type is included in the filter, otherwise user files will be excluded!
if user_file_ids and source_filter:
from onyx.configs.constants import DocumentSource
# Add user_file to the source filter if not already present
if DocumentSource.USER_FILE not in source_filter:
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
@@ -104,6 +114,7 @@ def _build_index_filters(
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
)
return final_filters

View File

@@ -44,6 +44,7 @@ def query_analysis(query: str) -> tuple[bool, list[str]]:
return analysis_model.predict(query)
# TODO: This is unused code.
@log_function_time(print_only=True)
def retrieval_preprocessing(
search_request: SearchRequest,

View File

@@ -118,6 +118,7 @@ def combine_retrieval_results(
return sorted_chunks
# TODO: This is unused code.
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -348,6 +349,7 @@ def retrieve_chunks(
list(query.filters.source_type) if query.filters.source_type else None,
query.filters.document_set,
slack_context,
query.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -475,6 +477,7 @@ def search_chunks(
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
slack_context=slack_context,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
@@ -510,6 +513,7 @@ def search_chunks(
return top_chunks
# TODO: This is unused code.
def inference_sections_from_ids(
doc_identifiers: list[tuple[str, int]],
document_index: DocumentIndex,

View File

@@ -63,7 +63,7 @@ def get_live_users_count(db_session: Session) -> int:
This does NOT include invited users, "users" pulled in
from external connectors, or API keys.
"""
count_stmt = func.count(User.id) # type: ignore
count_stmt = func.count(User.id)
select_stmt = select(count_stmt)
select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False)
user_count = db_session.scalar(select_stmt_w_filters)
@@ -74,7 +74,7 @@ def get_live_users_count(db_session: Session) -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_context_manager() as session:
count_stmt = func.count(User.id) # type: ignore
count_stmt = func.count(User.id)
stmt = select(count_stmt)
stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users)
user_count = await session.scalar(stmt_w_filters)
@@ -100,10 +100,10 @@ class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def get_user_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]:
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) # type: ignore
yield SQLAlchemyUserAdminDB(session, User, OAuthAccount)
async def get_access_token_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]:
yield SQLAlchemyAccessTokenDatabase(session, AccessToken) # type: ignore
yield SQLAlchemyAccessTokenDatabase(session, AccessToken)

View File

@@ -626,7 +626,7 @@ def reserve_message_id(
chat_session_id=chat_session_id,
parent_message_id=parent_message,
latest_child_message_id=None,
message="Response was termination prior to completion, try regenerating.",
message="Response was terminated prior to completion, try regenerating.",
token_count=15,
message_type=message_type,
)
@@ -744,29 +744,61 @@ def update_search_docs_table_with_relevance(
db_session.commit()
def _sanitize_for_postgres(value: str) -> str:
"""Remove NUL (0x00) characters from strings as PostgreSQL doesn't allow them."""
sanitized = value.replace("\x00", "")
if value and not sanitized:
logger.warning("Sanitization removed all characters from string")
return sanitized
def _sanitize_list_for_postgres(values: list[str]) -> list[str]:
"""Remove NUL (0x00) characters from all strings in a list."""
return [_sanitize_for_postgres(v) for v in values]
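# Example behavior (illustrative): NUL bytes are stripped, everything else passes through.
#
#   _sanitize_for_postgres("abc\x00def")         -> "abcdef"
#   _sanitize_list_for_postgres(["a\x00", "b"])  -> ["a", "b"]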
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
commit: bool = True,
) -> DBSearchDoc:
# Sanitize string fields to remove NUL characters (PostgreSQL doesn't allow them)
db_search_doc = DBSearchDoc(
document_id=server_search_doc.document_id,
document_id=_sanitize_for_postgres(server_search_doc.document_id),
chunk_ind=server_search_doc.chunk_ind,
semantic_id=server_search_doc.semantic_identifier or "Unknown",
link=server_search_doc.link,
blurb=server_search_doc.blurb,
semantic_id=_sanitize_for_postgres(server_search_doc.semantic_identifier),
link=(
_sanitize_for_postgres(server_search_doc.link)
if server_search_doc.link is not None
else None
),
blurb=_sanitize_for_postgres(server_search_doc.blurb),
source_type=server_search_doc.source_type,
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
is_relevant=server_search_doc.is_relevant,
relevance_explanation=server_search_doc.relevance_explanation,
relevance_explanation=(
_sanitize_for_postgres(server_search_doc.relevance_explanation)
if server_search_doc.relevance_explanation is not None
else None
),
# For docs further down that aren't reranked, we can't use the retrieval score
score=server_search_doc.score or 0.0,
match_highlights=server_search_doc.match_highlights,
match_highlights=_sanitize_list_for_postgres(
server_search_doc.match_highlights
),
updated_at=server_search_doc.updated_at,
primary_owners=server_search_doc.primary_owners,
secondary_owners=server_search_doc.secondary_owners,
primary_owners=(
_sanitize_list_for_postgres(server_search_doc.primary_owners)
if server_search_doc.primary_owners is not None
else None
),
secondary_owners=(
_sanitize_list_for_postgres(server_search_doc.secondary_owners)
if server_search_doc.secondary_owners is not None
else None
),
is_internet=server_search_doc.is_internet,
)

View File

@@ -40,6 +40,21 @@ def check_connectors_exist(db_session: Session) -> bool:
return result.scalar() or False
def check_user_files_exist(db_session: Session) -> bool:
"""Check if any user files exist in the system.
This is used to determine if the search tool should be available
when there are no regular connectors but there are user files
(User Knowledge mode).
"""
from onyx.db.models import UserFile
from onyx.db.enums import UserFileStatus
stmt = select(exists(UserFile).where(UserFile.status == UserFileStatus.COMPLETED))
result = db_session.execute(stmt)
return result.scalar() or False
def fetch_connectors(
db_session: Session,
sources: list[DocumentSource] | None = None,

View File

@@ -290,7 +290,7 @@ def get_document_counts_for_cc_pairs(
)
)
for connector_id, credential_id, cnt in db_session.execute(stmt).all(): # type: ignore
for connector_id, credential_id, cnt in db_session.execute(stmt).all():
aggregated_counts[(connector_id, credential_id)] = cnt
# Convert aggregated results back to the expected sequence of tuples
@@ -1098,7 +1098,7 @@ def reset_all_document_kg_stages(db_session: Session) -> int:
# The hasattr check is needed for type checking, even though rowcount
# is guaranteed to exist at runtime for UPDATE operations
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
return result.rowcount if hasattr(result, "rowcount") else 0
def update_document_kg_stages(
@@ -1121,7 +1121,7 @@ def update_document_kg_stages(
result = db_session.execute(stmt)
# The hasattr check is needed for type checking, even though rowcount
# is guaranteed to exist at runtime for UPDATE operations
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
return result.rowcount if hasattr(result, "rowcount") else 0
def get_skipped_kg_documents(db_session: Session) -> list[str]:

View File

@@ -234,9 +234,6 @@ def upsert_llm_provider(
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.fast_default_model_name = (
llm_provider_upsert_request.fast_default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name

View File

@@ -1,7 +1,9 @@
import datetime
from typing import cast
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
@@ -24,7 +26,9 @@ logger = setup_logger()
# MCPServer operations
def get_all_mcp_servers(db_session: Session) -> list[MCPServer]:
"""Get all MCP servers"""
return list(db_session.scalars(select(MCPServer)).all())
return list(
db_session.scalars(select(MCPServer).order_by(MCPServer.created_at)).all()
)
def get_mcp_server_by_id(server_id: int, db_session: Session) -> MCPServer:
@@ -124,6 +128,7 @@ def update_mcp_server__no_commit(
auth_performer: MCPAuthenticationPerformer | None = None,
transport: MCPTransport | None = None,
status: MCPServerStatus | None = None,
last_refreshed_at: datetime.datetime | None = None,
) -> MCPServer:
"""Update an existing MCP server"""
server = get_mcp_server_by_id(server_id, db_session)
@@ -144,6 +149,8 @@ def update_mcp_server__no_commit(
server.transport = transport
if status is not None:
server.status = status
if last_refreshed_at is not None:
server.last_refreshed_at = last_refreshed_at
db_session.flush() # Don't commit yet, let caller decide when to commit
return server
@@ -330,3 +337,15 @@ def delete_user_connection_configs_for_server(
db_session.delete(config)
db_session.commit()
def delete_all_user_connection_configs_for_server_no_commit(
server_id: int, db_session: Session
) -> None:
"""Delete all user connection configs for a specific MCP server"""
db_session.execute(
delete(MCPConnectionConfig).where(
MCPConnectionConfig.mcp_server_id == server_id
)
)
db_session.flush() # Don't commit yet, let caller decide when to commit

View File

@@ -1,99 +0,0 @@
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import flag_modified
from onyx.configs.constants import MilestoneRecordType
from onyx.db.models import Milestone
from onyx.db.models import User
USER_ASSISTANT_PREFIX = "user_assistants_used_"
MULTI_ASSISTANT_USED = "multi_assistant_used"
def create_milestone(
user: User | None,
event_type: MilestoneRecordType,
db_session: Session,
) -> Milestone:
milestone = Milestone(
event_type=event_type,
user_id=user.id if user else None,
)
db_session.add(milestone)
db_session.commit()
return milestone
def create_milestone_if_not_exists(
user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
# Check if it exists
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one_or_none()
if milestone is not None:
return milestone, False
# If it doesn't exist, try to create it.
try:
milestone = create_milestone(user, event_type, db_session)
return milestone, True
except IntegrityError:
# Another thread or process inserted it in the meantime
db_session.rollback()
# Fetch again to return the existing record
milestone = db_session.execute(
select(Milestone).where(Milestone.event_type == event_type)
).scalar_one() # Now should exist
return milestone, False
def update_user_assistant_milestone(
milestone: Milestone,
user_id: str | None,
assistant_id: int,
db_session: Session,
) -> None:
event_tracker = milestone.event_tracker
if event_tracker is None:
milestone.event_tracker = event_tracker = {}
if event_tracker.get(MULTI_ASSISTANT_USED):
# No need to keep tracking and populating if the milestone has already been hit
return
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
if event_tracker.get(user_key) is None:
event_tracker[user_key] = [assistant_id]
elif assistant_id not in event_tracker[user_key]:
event_tracker[user_key].append(assistant_id)
flag_modified(milestone, "event_tracker")
db_session.commit()
def check_multi_assistant_milestone(
milestone: Milestone,
db_session: Session,
) -> tuple[bool, bool]:
"""Returns if the milestone was hit and if it was just hit for the first time"""
event_tracker = milestone.event_tracker
if event_tracker is None:
return False, False
if event_tracker.get(MULTI_ASSISTANT_USED):
return True, False
for key, value in event_tracker.items():
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
event_tracker[MULTI_ASSISTANT_USED] = True
flag_modified(milestone, "event_tracker")
db_session.commit()
return True, True
return False, False

View File

@@ -2215,6 +2215,8 @@ class ToolCall(Base):
# The tools with the same turn number (and parent) were called in parallel
# Ones with different turn numbers (and same parent) were called sequentially
turn_number: Mapped[int] = mapped_column(Integer)
# Index order of tool calls from the LLM for parallel tool calls
tab_index: Mapped[int] = mapped_column(Integer, default=0)
# Not a FK because we want to be able to delete the tool without deleting
# this entry
@@ -2382,7 +2384,6 @@ class LLMProvider(Base):
postgresql.JSONB(), nullable=True
)
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
@@ -2958,7 +2959,7 @@ class SlackChannelConfig(Base):
"slack_bot_id",
"is_default",
unique=True,
postgresql_where=(is_default is True), # type: ignore
postgresql_where=(is_default is True),
),
)
@@ -3673,6 +3674,9 @@ class MCPServer(Base):
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
last_refreshed_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Relationships
admin_connection_config: Mapped["MCPConnectionConfig | None"] = relationship(
@@ -3685,6 +3689,7 @@ class MCPServer(Base):
"MCPConnectionConfig",
foreign_keys="MCPConnectionConfig.mcp_server_id",
back_populates="mcp_server",
passive_deletes=True,
)
current_actions: Mapped[list["Tool"]] = relationship(
"Tool", back_populates="mcp_server", cascade="all, delete-orphan"
@@ -3913,3 +3918,22 @@ class ExternalGroupPermissionSyncAttempt(Base):
def is_finished(self) -> bool:
return self.status.is_terminal()
class License(Base):
"""Stores the signed license blob (singleton pattern - only one row)."""
__tablename__ = "license"
__table_args__ = (
# Singleton pattern - unique index on constant ensures only one row
Index("idx_license_singleton", text("(true)"), unique=True),
)
id: Mapped[int] = mapped_column(primary_key=True)
license_data: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
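# Hedged note: a unique index over the constant expression `(true)` makes every row
# produce the same indexed value, so a second INSERT violates the index and the table
# can hold at most one row. Roughly the DDL this mapping should emit (exact output
# depends on the SQLAlchemy/Postgres versions in use):
#
#   CREATE UNIQUE INDEX idx_license_singleton ON license ((true));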

View File

@@ -221,6 +221,7 @@ def create_tool_call_no_commit(
parent_tool_call_id: int | None = None,
reasoning_tokens: str | None = None,
generated_images: list[dict] | None = None,
tab_index: int = 0,
add_only: bool = True,
) -> ToolCall:
"""
@@ -239,6 +240,7 @@ def create_tool_call_no_commit(
parent_tool_call_id: Optional parent tool call ID (for nested tool calls)
reasoning_tokens: Optional reasoning tokens
generated_images: Optional list of generated image metadata for replay
tab_index: Index order of tool calls from the LLM for parallel tool calls
commit: If True, commit the transaction; if False, flush only
Returns:
@@ -249,6 +251,7 @@ def create_tool_call_no_commit(
parent_chat_message_id=parent_chat_message_id,
parent_tool_call_id=parent_tool_call_id,
turn_number=turn_number,
tab_index=tab_index,
tool_id=tool_id,
tool_call_id=tool_call_id,
reasoning_tokens=reasoning_tokens,

View File

@@ -257,7 +257,7 @@ def _get_users_by_emails(
"""given a list of lowercase emails,
returns a list[User] of Users whose emails match and a list[str]
the missing emails that had no User"""
stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) # type: ignore
stmt = select(User).filter(func.lower(User.email).in_(lower_emails))
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
# Extract found emails and convert to lowercase to avoid case sensitivity issues

Some files were not shown because too many files have changed in this diff.