Compare commits

..

279 Commits

Author SHA1 Message Date
pablodanswer
477f8eeb68 minor update 2025-02-05 16:53:04 -08:00
pablodanswer
737e37170d minor updates 2025-02-05 16:53:02 -08:00
Yuhong Sun
c58a7ef819 Slackbot to know its name (#3917) 2025-02-05 16:39:42 -08:00
rkuo-danswer
bd08e6d787 alert if revisions are null or query fails (#3910)
* alert if revisions are null or query fails

* comment

* mypy

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-05 23:45:38 +00:00
rkuo-danswer
47e6192b99 fix bug in validation logic (#3915)
* fix bug in validation logic

* test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-05 22:49:18 +00:00
evan-danswer
29f5f4edfa fixed citations when sections selected (#3914)
* removed some dead code and fixed citations when a search request is made with sections selected

* fix black formatting issue
2025-02-05 22:16:07 +00:00
pablonyx
b469a7eff4 Put components in components directory + remove unused shortcut commands (#3909) 2025-02-05 14:29:29 -08:00
pablonyx
d1e9760b92 Enforce Slack Channel Default Config
Enforce Slack Channel Default Config
2025-02-05 14:28:03 -08:00
pablodanswer
7153cb09f1 add default slack channel config 2025-02-05 14:26:26 -08:00
pablonyx
78153e5012 Merge pull request #3913 from onyx-dot-app/very_minor_ux
remove unused border
2025-02-05 11:57:41 -08:00
pablodanswer
b1ee1efecb remove minor border issue 2025-02-05 11:57:03 -08:00
Sam Warner
526932a7f6 fix chat image upload double read 2025-02-05 09:52:51 -08:00
Weves
6889152d81 Fix issue causing file connector to fail 2025-02-04 22:19:04 -08:00
pablonyx
4affc259a6 Password reset tenant (#3895)
* nots

* functional

* minor naming cleanup

* nit

* update constant

* k
2025-02-05 03:17:11 +00:00
pablonyx
0ec065f1fb Set GPT 4o as default and add O3 mini (#3899)
* quick update to models

* add reqs

* update version
2025-02-05 03:06:05 +00:00
Weves
8eb4320f76 Support not pausing connectors on initialization failure 2025-02-04 19:32:55 -08:00
Weves
1c12ab31f9 Fix extra __init__ file + allow adding API keys to user groups 2025-02-04 17:21:06 -08:00
Yuhong Sun
49fd76b336 Tool Call Error Display (#3897) 2025-02-04 16:12:50 -08:00
rkuo-danswer
5854b39dd4 Merge pull request #3893 from onyx-dot-app/mypy_random
Mypy random fixes
2025-02-04 16:02:18 -08:00
rkuo-danswer
c0271a948a Merge pull request #3856 from onyx-dot-app/feature/no_scan_iter
lessen usage of scan_iter
2025-02-04 15:57:03 -08:00
Richard Kuo (Danswer)
aff4ee5ebf commented code 2025-02-04 15:56:18 -08:00
Richard Kuo (Danswer)
675d2f3539 Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/no_scan_iter 2025-02-04 15:55:42 -08:00
rkuo-danswer
2974b57ef4 Merge pull request #3898 from onyx-dot-app/bugfix/temporary_xfail
xfail test until fixed
2025-02-04 15:54:44 -08:00
Richard Kuo (Danswer)
679bdd5e04 xfail test until fixed 2025-02-04 15:53:45 -08:00
Yuhong Sun
e6cb47fcb8 Prompt 2025-02-04 14:42:18 -08:00
Yuhong Sun
a514818e13 Citations 2025-02-04 14:34:44 -08:00
Yuhong Sun
89021cde90 Citation Prompt 2025-02-04 14:17:23 -08:00
Chris Weaver
32ecc282a2 Update README.md
Fix Cal link in README
2025-02-04 13:11:46 -08:00
Yuhong Sun
59b1d4673f Updating some Prompts (#3894) 2025-02-04 12:23:15 -08:00
pablodanswer
ec0c655c8d misc improvement 2025-02-04 12:06:11 -08:00
pablodanswer
42a0f45a96 update 2025-02-04 12:06:11 -08:00
pablodanswer
125e5eaab1 various mypy improvements 2025-02-04 12:06:10 -08:00
Richard Kuo (Danswer)
f2dab9ba89 Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/no_scan_iter 2025-02-04 12:01:57 -08:00
Richard Kuo
02a068a68b multiplier from 8 to 4 2025-02-03 23:59:36 -08:00
evan-danswer
91f0650071 Merge pull request #3749 from onyx-dot-app/agent-search-feature
Agent search
2025-02-03 21:31:46 -08:00
pablodanswer
b97819189b push various minor updates 2025-02-03 21:23:45 -08:00
Evan Lohn
b928201397 fixed rebase issue and some cleanup 2025-02-03 20:49:45 -08:00
Yuhong Sun
b500c914b0 cleanup 2025-02-03 20:10:51 -08:00
Yuhong Sun
4b0d22fae3 prompts 2025-02-03 20:10:51 -08:00
joachim-danswer
b46c09ac6c EL comments 2025-02-03 20:10:51 -08:00
joachim-danswer
3ce8923086 fix for citation update 2025-02-03 20:10:51 -08:00
joachim-danswer
7ac6d3ed50 logging level changes 2025-02-03 20:10:51 -08:00
joachim-danswer
3cd057d7a2 LangGraph comments 2025-02-03 20:10:51 -08:00
joachim-danswer
4834ee6223 new citation format 2025-02-03 20:10:51 -08:00
pablodanswer
cb85be41b1 add proper citation handling 2025-02-03 20:10:51 -08:00
joachim-danswer
eb227c0acc nit update 2025-02-03 20:10:51 -08:00
joachim-danswer
25a57e2292 add title and meta-data to doc 2025-02-03 20:10:51 -08:00
pablodanswer
3f3b04a4ee update width 2025-02-03 20:10:51 -08:00
Evan Lohn
3f6de7968a prompt improvements for wekaer models 2025-02-03 20:10:51 -08:00
pablodanswer
024207e2d9 update 2025-02-03 20:10:51 -08:00
Yuhong Sun
8f7db9212c k 2025-02-03 20:10:51 -08:00
pablodanswer
b1e9e03aa4 nit 2025-02-03 20:10:51 -08:00
pablodanswer
87a53d6d80 quick update 2025-02-03 20:10:51 -08:00
Yuhong Sun
59c65a4192 prompts 2025-02-03 20:10:51 -08:00
pablodanswer
c984c6c7f2 add pro search disable 2025-02-03 20:10:51 -08:00
Yuhong Sun
9a3ce504bc beta 2025-02-03 20:10:51 -08:00
Yuhong Sun
16265d27f5 k 2025-02-03 20:10:51 -08:00
Yuhong Sun
570fe43efb log level changes 2025-02-03 20:10:51 -08:00
Yuhong Sun
506a9f1b94 Yuhong 2025-02-03 20:10:51 -08:00
Yuhong Sun
a067b32467 Partial Prompt Updates (#3880) 2025-02-03 20:10:51 -08:00
pablodanswer
9b6e51b4fe k 2025-02-03 20:10:51 -08:00
joachim-danswer
e23dd0a3fa renames + fix of refined answer generation prompt 2025-02-03 20:10:51 -08:00
Evan Lohn
71304e4228 always persist in agent search 2025-02-03 20:10:51 -08:00
Evan Lohn
2adeaaeded loading object into model instead of json 2025-02-03 20:10:51 -08:00
Evan Lohn
a96728ff4d prompt piece optimizations 2025-02-03 20:10:51 -08:00
pablodanswer
eaffdee0dc broadly fixed minus some issues 2025-02-03 20:10:51 -08:00
pablodanswer
feaa3b653f fix misc issues 2025-02-03 20:10:51 -08:00
joachim-danswer
9438f9df05 removal of sone unused states/models 2025-02-03 20:10:51 -08:00
joachim-danswer
b90e0834a5 major renaming 2025-02-03 20:10:51 -08:00
Evan Lohn
29440f5482 alembic heads, basic citations, search pipeline state 2025-02-03 20:10:51 -08:00
Evan Lohn
5a95a5c9fd large number of PR comments addressed 2025-02-03 20:10:51 -08:00
Evan Lohn
118e8afbef reworked config to have logical structure 2025-02-03 20:10:51 -08:00
joachim-danswer
8342168658 initial variable renaming 2025-02-03 20:10:51 -08:00
joachim-danswer
d5661baf98 history summary fix
- adjusted prompt
 - adjusted citation removal
 - length cutoff by words, not characters
2025-02-03 20:10:51 -08:00
joachim-danswer
95fcc0019c history summary update 2025-02-03 20:10:51 -08:00
joachim-danswer
0ccd83e809 deep_search_a and agent_a_config renaming 2025-02-03 20:10:51 -08:00
joachim-danswer
732861a940 rename of documents to verified_reranked_documents 2025-02-03 20:10:51 -08:00
joachim-danswer
d53dd1e356 cited_docs -> cited_documents 2025-02-03 20:10:51 -08:00
joachim-danswer
1a2760edee improved logging through agent_state plus some default fixes 2025-02-03 20:10:51 -08:00
joachim-danswer
23ae4547ca default values of number of strings and other things 2025-02-03 20:10:51 -08:00
Evan Lohn
385b344a43 addressed TODOs 2025-02-03 20:10:51 -08:00
Evan Lohn
a340529de3 sync streaming impl 2025-02-03 20:10:51 -08:00
joachim-danswer
4a0b2a6c09 additional naming fixes 2025-02-03 20:10:51 -08:00
joachim-danswer
756a1cbf8f answer_refined_question_subgraphs 2025-02-03 20:10:51 -08:00
joachim-danswer
8af4f1da8e more renaming 2025-02-03 20:10:51 -08:00
Evan Lohn
4b82440915 finished rebase and fixed issues 2025-02-03 20:10:51 -08:00
Evan Lohn
bb6d55783e addressing PR comments 2025-02-03 20:10:51 -08:00
Evan Lohn
2b8cd63b34 main nodes renaming 2025-02-03 20:10:51 -08:00
joachim-danswer
b0c3098693 more renaming and consolidation 2025-02-03 20:10:51 -08:00
joachim-danswer
2517aa39b2 more renamings 2025-02-03 20:10:51 -08:00
joachim-danswer
ceaaa05af0 renamings and consolidation of formatting nodes in orig question retrieval 2025-02-03 20:10:51 -08:00
joachim-danswer
3b13380051 k 2025-02-03 20:10:51 -08:00
joachim-danswer
ef6e6f9556 more renaming 2025-02-03 20:10:51 -08:00
joachim-danswer
0a6808c4c1 rename initial_sub_question_creation 2025-02-03 20:10:51 -08:00
Evan Lohn
6442c56d82 remaining small find replace fix 2025-02-03 20:10:51 -08:00
Evan Lohn
e191e514b9 fixed find and replace issue 2025-02-03 20:10:51 -08:00
Evan Lohn
f33a2ffb01 node renaming 2025-02-03 20:10:51 -08:00
joachim-danswer
0578c31522 rename retrieval & consolidate_sub_answers (initial and refinement) 2025-02-03 20:10:51 -08:00
joachim-danswer
8cbdc6d8fe fix for refinement renaming 2025-02-03 20:10:51 -08:00
joachim-danswer
60fb06da4e rename initial_answer_generation pt 2 2025-02-03 20:10:51 -08:00
joachim-danswer
55ed6e2294 rename initial_answer_generation 2025-02-03 20:10:50 -08:00
joachim-danswer
42780d5f97 rename of individual_sub_answer_generation 2025-02-03 20:10:50 -08:00
Evan Lohn
f050d281fd refininement->refinement 2025-02-03 20:10:50 -08:00
joachim-danswer
3ca4d532b4 renamed directories, prompts, and small citation fix 2025-02-03 20:10:50 -08:00
pablodanswer
e3e855c526 potential question fix 2025-02-03 20:10:50 -08:00
pablodanswer
23bf50b90a address doc 2025-02-03 20:10:50 -08:00
Yuhong Sun
c43c2320e7 Tiny nits 2025-02-03 20:10:50 -08:00
Evan Lohn
01e6e9a2ba fixed errors on import 2025-02-03 20:10:50 -08:00
Evan Lohn
bd3b1943c4 WIP PR comments 2025-02-03 20:10:50 -08:00
Evan Lohn
1dbf561db0 fix revision to match internal alembic state 2025-02-03 20:10:50 -08:00
Evan Lohn
a43a6627eb fix revision to match internal alembic state 2025-02-03 20:10:50 -08:00
Evan Lohn
5bff8bc8ce collapsed db migrations post-rebase (added missing file) 2025-02-03 20:10:50 -08:00
Evan Lohn
7879ba6a77 collapsed db migrations post-rebase 2025-02-03 20:10:50 -08:00
pablodanswer
a63b341913 latex update 2025-02-03 20:10:50 -08:00
pablodanswer
c062097b2a post rebase fix 2025-02-03 20:10:50 -08:00
Evan Lohn
48e42af8e7 fix rebase issue 2025-02-03 20:10:50 -08:00
Evan Lohn
6c7f8eaefb first pass at dead code deletion 2025-02-03 20:10:50 -08:00
joachim-danswer
3d99ad7bc4 var initialization 2025-02-03 20:10:50 -08:00
joachim-danswer
8fea571f6e k 2025-02-03 20:10:50 -08:00
joachim-danswer
d70bbcc2ce k 2025-02-03 20:10:50 -08:00
joachim-danswer
73769c6cae k 2025-02-03 20:10:50 -08:00
joachim-danswer
7e98936c58 Enrichment prompts, prompt improvements, dispatch logging & reinsert empty tool response 2025-02-03 20:10:50 -08:00
joachim-danswer
4e17fc06ff variable renaming 2025-02-03 20:10:50 -08:00
joachim-danswer
ff4df6f3bf fix for merge error (#3814) 2025-02-03 20:10:50 -08:00
joachim-danswer
91b929d466 graph directory renamings 2025-02-03 20:10:50 -08:00
joachim-danswer
6bef5ca7a4 persona_prompt improvements 2025-02-03 20:10:50 -08:00
joachim-danswer
4817fa0bd1 average dispatch time collection for sub-answers 2025-02-03 20:10:50 -08:00
joachim-danswer
da4a086398 added total time to logging 2025-02-03 20:10:50 -08:00
joachim-danswer
69e8c5f0fc agent default changes/restructuring 2025-02-03 20:10:50 -08:00
joachim-danswer
12d1186888 increased logging 2025-02-03 20:10:50 -08:00
joachim-danswer
325892a21c cleanup of refined answer generation 2025-02-03 20:10:50 -08:00
joachim-danswer
18d92559b5 application of content limitation ion refined answer as well 2025-02-03 20:10:50 -08:00
joachim-danswer
f2aeeb7b3c Optimizations: docs for context & history
- summarize history if long
- introduced cited_docs from SQ as those must be provided to answer generations
- limit number of docs

TODO: same for refined flow
2025-02-03 20:10:50 -08:00
Evan Lohn
110c9f7e1b nit 2025-02-03 20:10:50 -08:00
Evan Lohn
1a22af4f27 AgentPromptConfig in Answer class 2025-02-03 20:10:50 -08:00
Evan Lohn
efa32a8c04 use reranking settings and persona during preprocessing in reranker 2025-02-03 20:10:50 -08:00
Evan Lohn
9bad12968f removed unused files 2025-02-03 20:10:50 -08:00
Evan Lohn
f1d96343a9 always send search response 2025-02-03 20:10:50 -08:00
Evan Lohn
0496ec3bb8 remove debug 2025-02-03 20:10:50 -08:00
pablodanswer
568f927b9b improve regeneration state 2025-02-03 20:10:50 -08:00
pablodanswer
f842e15d64 nit 2025-02-03 20:10:50 -08:00
pablodanswer
3a07093663 improved timing 2025-02-03 20:10:50 -08:00
Evan Lohn
1fe966d0f7 increased timeout to get rid of asyncio logger errors 2025-02-03 20:10:50 -08:00
joachim-danswer
812172f1bd addressing nits of EL 2025-02-03 20:10:50 -08:00
joachim-danswer
9e9bd440f4 updated answer_comparison prompt + small cleanup 2025-02-03 20:10:50 -08:00
joachim-danswer
7487b15522 refined search + question answering as sub-graphs 2025-02-03 20:10:50 -08:00
joachim-danswer
de5ce8a613 sub-graphs for initial question/search 2025-02-03 20:10:50 -08:00
joachim-danswer
8c9577aa95 refined search + question answering as sub-graphs 2025-02-03 20:10:50 -08:00
pablodanswer
4baf3dc484 minor update 2025-02-03 20:10:50 -08:00
pablodanswer
50ef5115e7 k 2025-02-03 20:10:50 -08:00
pablodanswer
a2247363af update switching logic 2025-02-03 20:10:50 -08:00
pablodanswer
a0af8ee91c fix toggling edge case 2025-02-03 20:10:50 -08:00
pablodanswer
25f6543443 update bool 2025-02-03 20:10:50 -08:00
pablodanswer
d52a0b96ac various improvements 2025-02-03 20:10:50 -08:00
pablodanswer
f14b282f0f quick nit 2025-02-03 20:10:50 -08:00
Evan Lohn
7d494cd65e allowed empty Search Tool for non-agentic search 2025-02-03 20:10:50 -08:00
pablodanswer
139374966f minor update - doc ordering 2025-02-03 20:10:50 -08:00
pablodanswer
bf06710215 k 2025-02-03 20:10:50 -08:00
pablodanswer
d4e0d0db05 quick nit 2025-02-03 20:10:50 -08:00
pablodanswer
f96a3ee29a k 2025-02-03 20:10:50 -08:00
joachim-danswer
3bf6b77319 Replaced additional limit with variable 2025-02-03 20:10:50 -08:00
joachim-danswer
3b3b0c8a87 Addressing EL's comments
- created vars for a couple of agent settings
 - moved agent configs
 - created a search function
2025-02-03 20:10:50 -08:00
joachim-danswer
aa8cb44a33 taking out Extraction for now 2025-02-03 20:10:50 -08:00
joachim-danswer
fc60fd0322 earlier entity extraction & sharper generation prompts 2025-02-03 20:10:50 -08:00
joachim-danswer
46402a97c7 tmp: force agent search 2025-02-03 20:10:50 -08:00
Evan Lohn
5bf6a47948 skip reranking for <=1 doc 2025-02-03 20:10:50 -08:00
Evan Lohn
2d8486bac4 stop infos when done streaming answers 2025-02-03 20:10:50 -08:00
Evan Lohn
eea6f2749a make field nullable 2025-02-03 20:10:50 -08:00
Evan Lohn
5e9b2e41ae persisting refined answer improvement 2025-02-03 20:10:50 -08:00
Evan Lohn
2bbe20edc3 address JR comments 2025-02-03 20:10:50 -08:00
Evan Lohn
db2004542e fixed chat tests 2025-02-03 20:10:50 -08:00
Evan Lohn
ddbfc65ad0 implemented top-level tool calling + force search 2025-02-03 20:10:50 -08:00
Evan Lohn
982040c792 WIP, but working basic search using initial tool choice node 2025-02-03 20:10:50 -08:00
pablodanswer
4b0a4a2741 k 2025-02-03 20:10:50 -08:00
pablodanswer
28ba01b361 updated + functional 2025-02-03 20:10:50 -08:00
pablodanswer
d32d1c6079 update- reorg 2025-02-03 20:10:50 -08:00
pablodanswer
dd494d2daa k 2025-02-03 20:10:50 -08:00
pablodanswer
eb6dbf49a1 build fix 2025-02-03 20:10:50 -08:00
joachim-danswer
e5fa411092 EL comments addressed 2025-02-03 20:10:50 -08:00
joachim-danswer
1ced8924b3 loser verification prompt 2025-02-03 20:10:50 -08:00
joachim-danswer
3c3900fac6 turning off initial search pre route decision 2025-02-03 20:10:50 -08:00
joachim-danswer
3b298e19bc change of sub-question answer if no docs recovered 2025-02-03 20:10:50 -08:00
joachim-danswer
71eafe04a8 various fixes from Yuhong's list 2025-02-03 20:10:50 -08:00
Yuhong Sun
80d248e02d Copy changes 2025-02-03 20:10:50 -08:00
Evan Lohn
2032fb10da removed print statements, fixed pass through handling 2025-02-03 20:10:50 -08:00
Evan Lohn
ca1f176c61 fixed basic flow citations and second test 2025-02-03 20:10:50 -08:00
Evan Lohn
3ced9bc28b fix for early cancellation test; solves issue with tasks being destroyed while pending 2025-02-03 20:10:50 -08:00
pablodanswer
deea9c8c3c add agent search frontend 2025-02-03 20:10:47 -08:00
Evan Lohn
4e47c81ed8 fix alembic history 2025-02-03 20:07:57 -08:00
joachim-danswer
00cee71c18 streaming + saving of search docs of no verified ones available
- sub-questions only
2025-02-03 20:07:57 -08:00
Evan Lohn
470c4d15dd reworked history messages in agent config 2025-02-03 20:07:57 -08:00
Evan Lohn
50bacc03b3 missed files from prev commit 2025-02-03 20:07:57 -08:00
Evan Lohn
dd260140b2 basic search restructure: WIP on fixing tests 2025-02-03 20:07:57 -08:00
joachim-danswer
8aa82be12a prompts that even further motivates to cite docs over sub-q's 2025-02-03 20:07:57 -08:00
joachim-danswer
b7f9e431a5 pydantic for LangGraph + changed ERT extraction flow 2025-02-03 20:07:57 -08:00
joachim-danswer
b9bd2ea4e2 history added to agent flow 2025-02-03 20:07:57 -08:00
pablodanswer
e4c93bed8b minor fixes to branch 2025-02-03 20:07:57 -08:00
Evan Lohn
4fd6e36c2f second clean commit 2025-02-03 20:07:57 -08:00
trial-danswer
715359c120 Helm chart refactoring (#3797)
* initial commit for helm chart refactoring

* Continue refactoring helm. I was able to use helm to deploy all of the apps to a cluster in aws. The bottleneck was setting up PVC dynamic provisioning.

* use default storage class

* Fix linter errors

* Fix broken helm test

---------

Co-authored-by: jpb80 <jordan.buttkevitz@gmail.com>
2025-02-03 10:56:07 -08:00
Richard Kuo (Danswer)
6f018d75ee use replica, remove some commented code 2025-02-03 10:10:05 -08:00
Richard Kuo (Danswer)
fd947aadea slow down to 8 again 2025-02-03 00:32:23 -08:00
Weves
e061ba2b93 another airtable fix 2025-02-02 20:58:24 -08:00
Weves
87bccc13cc Handle expiring attachments 2025-02-02 12:02:44 -08:00
Richard Kuo (Danswer)
3a950721b9 get rid of some more scan_iter 2025-02-02 01:14:10 -08:00
Weves
569639eb90 Improved attachment handling 2025-02-01 23:07:01 -08:00
pablodanswer
68cb1f3409 ensure tests don't run temporarily 2025-02-01 17:31:44 -08:00
pablonyx
11da0d9889 Add user specific chat session temperature (#3867)
* add user specific chat session temperature

* kbetter typing

* update
2025-02-01 17:29:58 -08:00
pablodanswer
6a7e2a8036 temporarily disable chat tests 2025-02-01 14:15:16 -08:00
pablodanswer
035f83c464 ensure tests pass (temporary dragging disabled) 2025-02-01 12:58:03 -08:00
pablonyx
3c34ddcc4f E2e assistant tests (#3869)
* adding llm override logic

* update

* general cleanup

* fix various tests

* rm

* update

* update

* better comments

* k

* k

* update to pass tests

* clarify content

* improve timeout
2025-02-01 20:05:53 +00:00
Richard Kuo (Danswer)
bbee2865e9 Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/no_scan_iter 2025-02-01 10:46:38 -08:00
pablonyx
a82cac5361 Ensure anonymous users can give feedback
Ensure anonymous users can give feedback
2025-02-01 10:36:14 -08:00
pablodanswer
83e5cb2d2f tested 2025-01-31 16:40:37 -08:00
Chris Weaver
a5d2f0d9ac Fix airtable connector w/ mt cloud + move telem logic to match new st… (#3868)
* Fix airtable connector w/ mt cloud + move telem logic to match new standard

* Address Greptile comment

* Small fixes/improvements

* Revert back monitoring frequency

* Small monitoring fix
2025-01-31 16:29:04 -08:00
Richard Kuo (Danswer)
d3cf18160e lower CLOUD_BEAT_SCHEDULE_MULTIPLIER to 4 2025-01-31 16:13:13 -08:00
Richard Kuo (Danswer)
618e4addd8 better signal names 2025-01-31 13:25:27 -08:00
Richard Kuo (Danswer)
69f16cc972 dont add to the lookup table if it already exists 2025-01-31 13:23:52 -08:00
Richard Kuo (Danswer)
2676d40065 mereging 2025-01-31 12:14:24 -08:00
Richard Kuo (Danswer)
b64545c7c7 build a lookup table every so often to handle cloud migration 2025-01-31 12:12:52 -08:00
Weves
7bc8554e01 Airtable fix 2025-01-31 10:42:27 -08:00
Richard Kuo (Danswer)
5232aeacad Merge branch 'main' of https://github.com/onyx-dot-app/onyx into feature/no_scan_iter
# Conflicts:
#	backend/onyx/background/celery/tasks/vespa/tasks.py
#	backend/onyx/redis/redis_connector_doc_perm_sync.py
2025-01-31 10:38:10 -08:00
rkuo-danswer
261150e81a Validate permission locks (#3799)
* WIP for external group sync lock fixes

* prototyping permissions validation

* validate permission sync tasks in celery

* mypy

* cleanup and wire off external group sync checks for now

* add active key to reset

* improve logging

* reset on payload format change

* return False on exception

* missed a return

* add count of tasks scanned

* add comment

* better logging

* add return

* more return

* catch payload exceptions

* code review fixes

* push to restart test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-31 17:33:07 +00:00
pablonyx
3e0d24a3f6 Update foreign key migration
Update foreign key migration
2025-01-31 08:45:19 -08:00
pablodanswer
ffe8ac168f update foreign key migration 2025-01-31 08:42:28 -08:00
pablonyx
17b280e59e Remove cloud_kubes from public repo
Remove `cloud_kubes` from public repo
2025-01-30 19:19:09 -08:00
pablonyx
5edba4a7f3 Foreign key input prompts
Foreign key input prompts
2025-01-30 19:18:49 -08:00
pablodanswer
d842fed37e foreign key updates 2025-01-30 19:17:32 -08:00
Weves
14981162fd Pin shapely version 2025-01-30 18:02:35 -08:00
Chris Weaver
288daa4e90 Add more airtable logging (#3862)
* Add more airtable logging

* Add multithreading

* Remove empty comment
2025-01-30 17:33:42 -08:00
Richard Kuo (Danswer)
30e8fb12e4 remove commented code 2025-01-30 15:34:00 -08:00
Richard Kuo (Danswer)
d8578bc1cb first full cut 2025-01-30 15:21:52 -08:00
pablonyx
5e21dc6cb3 Optimize /persona query (#3859)
* k

* delete

* k
2025-01-30 23:20:19 +00:00
Weves
39b3a503b4 Add more group sync logging 2025-01-30 14:42:14 -08:00
pablonyx
a70d472b5c Update e2e frontend tests (#3843)
* fix input prompts

* assistant ordering validation

* k

* Revert "fix input prompts"

This reverts commit a4b577bdd7.

* fix alembic

* foreign key updates

* Revert "foreign key updates"

This reverts commit fe17795a037f831790d69229e1067ccb5aab5bd9.

* improve e2e tests

* fix admin
2025-01-30 20:15:29 +00:00
devin-ai-integration[bot]
0ed2886ad0 Can't create starter messages for existing assistants. (#3825)
* fix: move starter messages out of advanced options for better visibility

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: ensure starter message input field is visible in edit flow

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: fix prettier formatting

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: fix prettier formatting for starter messages description

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: fix prettier formatting for starter messages initialization

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: prevent unintended deletion of second message in StarterMessagesList

Co-Authored-By: Chris Weaver <chris@onyx.app>

* Fix empty starter messages

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Weves <chrisweaver101@gmail.com>
2025-01-30 10:26:54 -08:00
pablodanswer
6b31e2f622 remove cloud_kubes from public repo 2025-01-30 09:52:57 -08:00
hagen-danswer
aabf8a99bc Fixed SharePoint connector polling (#3834)
* Fixed SharePoint connector polling

* finish

* fix sharepoint connector
2025-01-30 17:43:11 +00:00
Richard Kuo (Danswer)
7ccfe85ee5 WIP 2025-01-29 22:52:21 -08:00
Chris Weaver
95701db1bd Add more sync records + fix small bug in monitoring task causing deletion metrics to never be emitted (#3837)
Double check we don't double-emit + fix pruning metric

Add log

Fix comment

rename
2025-01-29 18:03:49 -08:00
rkuo-danswer
24105254ac fix race condition with permission sync and fences (#3841)
* fix race condition with permission sync and fences

* comments

* set the fence

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-29 23:40:44 +00:00
rkuo-danswer
4fe99d05fd add timings for syncing (#3798)
* add timings for syncing

* add more logging

* more debugging

* refactor multipass/db check out of VespaIndex

* circular imports?

* more debugging

* add logs

* various improvements

* additional logs to narrow down issue

* use global httpx pool for the main vespa flows in celery. Use in more places eventually.

* cleanup debug logging, etc

* remove debug logging

* this should use the secondary index

* mypy

* missed some logging

* review fixes

* refactor get_default_document_index to use search settings

* more missed logging

* fix circular refs

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: pablodanswer <pablo@danswer.ai>
2025-01-29 23:24:44 +00:00
pablonyx
d35f93b233 k (#3838) 2025-01-29 22:39:48 +00:00
hagen-danswer
766b0f35df Lowercase all user emails (#3830) 2025-01-29 19:09:06 +00:00
evan-danswer
a0470a96eb removed logic to search first message, fixed query override (#3812) 2025-01-29 19:02:29 +00:00
devin-ai-integration[bot]
b82123563b Fix Unicode sanitization for Vespa document indexing (#3831)
* Add support for filtering 0xFDD0-0xFDEF Unicode range

- Update remove_invalid_unicode_chars to handle 0xFDD0-0xFDEF range
- Add comprehensive test cases for Unicode character sanitization
- Fix issue with illegal code point 0xFDDB in Vespa indexing

Co-Authored-By: Chris Weaver <chris@onyx.app>

* Remove unused pytest import

Co-Authored-By: Chris Weaver <chris@onyx.app>

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Chris Weaver <chris@onyx.app>
2025-01-29 18:32:00 +00:00
rkuo-danswer
787e25cd78 Merge pull request #3823 from onyx-dot-app/bugfix/sharepoint_app_init
app should be initialized once per connector
2025-01-28 23:55:09 -08:00
pablonyx
c6375f8abf Tool id constants (#3827)
* tool id constants

* clarification
2025-01-29 06:33:31 +00:00
Richard Kuo (Danswer)
58e5deba01 Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/sharepoint_app_init
# Conflicts:
#	backend/onyx/connectors/sharepoint/connector.py
2025-01-28 21:11:13 -08:00
Chris Weaver
028e877342 Sharepoint fixes (#3826)
* Sharepoint connector fixes

* Refactor sharepoint to be better

* Improve env variable naming

* Fix

* Add new secrets

* Fix unstructured failure
2025-01-28 20:06:09 -08:00
Richard Kuo (Danswer)
47bff2b6a9 missed init 2025-01-28 19:11:38 -08:00
Richard Kuo (Danswer)
1502bcea12 do teams too 2025-01-28 19:03:54 -08:00
pablonyx
2701f83634 llm provider re-org (#3810)
* nit

* clean up logic

* update
2025-01-29 02:44:50 +00:00
pablonyx
601037abb5 Customer love (#3813)
* additional logs

* disable gdrive oauth

* Revert "additional ogs"

This reverts commit 1bd7f9d433.
2025-01-28 17:42:28 -08:00
devin-ai-integration[bot]
7e9b12403a Allow Slack workflow messages when respond_to_bots is enabled (#3819)
* Allow workflow 'bot_message' subtype when respond_to_bots is enabled

Co-Authored-By: Chris Weaver <chris@onyx.app>

* refactor: consolidate bot message checks to avoid redundant code

Co-Authored-By: Chris Weaver <chris@onyx.app>

* style: fix black formatting

Co-Authored-By: Chris Weaver <chris@onyx.app>

* Remove unnecessary call

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Weves <chrisweaver101@gmail.com>
2025-01-28 17:29:23 -08:00
devin-ai-integration[bot]
d903e5912a feat: add option to treat all non-attachment fields as metadata in Airtable connector (#3817)
* feat: add option to treat all non-attachment fields as metadata in Airtable connector

- Added new UI option 'treat_all_non_attachment_fields_as_metadata'
- Updated backend logic to support treating all fields except attachments as metadata
- Added tests for both default and all-metadata behaviors

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: handle missing environment variables gracefully in airtable tests

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: clean up test file and handle environment variables properly

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: add missing test fixture and fix formatting

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: fix black formatting

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: add type annotation for metadata dict in airtable tests

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: add type annotation for mock_get_api_key fixture

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: update Generator import to use collections.abc

Co-Authored-By: Chris Weaver <chris@onyx.app>

* refactor: make treat_all_non_attachment_fields_as_metadata a direct required parameter

- Move parameter from connector_config to direct class parameter
- Place parameter right under table_name_or_id argument
- Make parameter required in UI with no default value
- Update tests to use new parameter structure

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: fix black formatting

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: rename _METADATA_FIELD_TYPES to DEFAULT_METADATA_FIELD_TYPES and clarify usage

Co-Authored-By: Chris Weaver <chris@onyx.app>

* chore: fix black formatting in docstring

Co-Authored-By: Chris Weaver <chris@onyx.app>

* test: make airtable tests fail loudly on missing env vars

Co-Authored-By: Chris Weaver <chris@onyx.app>

* style: fix black formatting in test file

Co-Authored-By: Chris Weaver <chris@onyx.app>

* style: add required newline between test functions

Co-Authored-By: Chris Weaver <chris@onyx.app>

* test: update error message pattern in parameter validation test

Co-Authored-By: Chris Weaver <chris@onyx.app>

* style: fix black formatting in test file

Co-Authored-By: Chris Weaver <chris@onyx.app>

* test: fix error message pattern in parameter validation test

Co-Authored-By: Chris Weaver <chris@onyx.app>

* style: fix line length in test file

Co-Authored-By: Chris Weaver <chris@onyx.app>

* test: simplify error message pattern in parameter validation test

Co-Authored-By: Chris Weaver <chris@onyx.app>

* test: add type validation test for treat_all_non_attachment_fields_as_metadata

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: add missing required parameter in test

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: remove parameter from test to properly validate it is required

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: add type validation for treat_all_non_attachment_fields_as_metadata parameter

Co-Authored-By: Chris Weaver <chris@onyx.app>

* style: fix black formatting in airtable_connector.py

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: update type validation test to handle mypy errors

Co-Authored-By: Chris Weaver <chris@onyx.app>

* fix: specify mypy ignore type for call-arg

Co-Authored-By: Chris Weaver <chris@onyx.app>

* Also handle rows w/o sections

* style: fix black formatting in test assertion

Co-Authored-By: Chris Weaver <chris@onyx.app>

* add TODO

* Remove unnecessary check

* Fix test

* Do not break existing airtable connectors

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Weves <chrisweaver101@gmail.com>
2025-01-28 17:28:32 -08:00
pablonyx
d2aea63573 Merge pull request #3824 from onyx-dot-app/naming
Fix search tool name
2025-01-28 16:57:02 -08:00
pablodanswer
57b4639709 fix name 2025-01-28 16:52:00 -08:00
Richard Kuo (Danswer)
1308b6cbe8 app should be initialized once per connector 2025-01-28 15:55:52 -08:00
rkuo-danswer
98abd7d3fa Merge pull request #3821 from onyx-dot-app/bugfix/google_drive_test_fix
don't duplicate test module names
2025-01-28 15:29:55 -08:00
Richard Kuo (Danswer)
e4180cefba don't duplicate test module names 2025-01-28 15:24:05 -08:00
skylares
f67b5356fa Create google drive e2e test (#3635)
* Create e2e google drive test

* Drive sync issue

* Add endpoints for group syncing

* google e2e fixes/improvements and add xfail to zendesk tests

* mypy errors

* Key change

* Small changes

* Merged main to fix group sync issue

* Update test_permission_sync.py

* Update google_drive_api_utils.py

* Update test_zendesk_connector.py

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2025-01-28 14:12:57 -08:00
pablonyx
9bdb581220 Update slack configs (#3776)
* update

* fix build
2025-01-28 21:10:09 +00:00
pablonyx
42d6d935ae continue on internal error (#3728) 2025-01-28 20:19:07 +00:00
pablonyx
8d62b992ef Double check all chat accessible dependencies (#3801)
* double check all chat accessible dependencies

* k

* k

* k

* k

* k

* k
2025-01-28 17:38:32 +00:00
pablonyx
2ad86aa9a6 Unstructured fix (#3809)
* fix v1

* temporary patch for pdfs

* nit
2025-01-28 16:46:27 +00:00
pablonyx
74a472ece7 Remove checkmark
Remove checkmark
2025-01-27 22:38:22 -08:00
pablodanswer
b2ce848b53 add fix 2025-01-27 21:54:20 -08:00
pablonyx
519ec20d05 Feedback (#3800)
* k

* k:wq

* update user auth

* update
2025-01-28 03:13:21 +00:00
pablodanswer
3b1e26d0d4 remove checkmark 2025-01-27 19:12:49 -08:00
pablonyx
118d2b52e6 Improvements for web build (#3786)
* k

* improvements for web build
2025-01-27 20:40:06 +00:00
pablonyx
e625884702 Chat Touchups (#3775) 2025-01-27 12:30:43 -08:00
rkuo-danswer
fa78f50fe3 Bugfix/celery ignore result (#3770)
* try using a redis replica in some areas

* harden up replica usage

* ignore results

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-27 08:53:01 +00:00
Yuhong Sun
05ab94945b Fix Sharepoint Folder Parsing (#3791) 2025-01-26 16:45:24 -08:00
Yuhong Sun
7a64a25ff4 Fix Confluence Missing Labels (#3788) 2025-01-26 14:05:02 -08:00
pablonyx
7f10494bbe Better vespa interface (#3781)
* k

* much cleaner vespa util class

* log

* typing

* improvement

* improve
2025-01-26 21:22:44 +00:00
pablodanswer
f2d4024783 improve base page latency 2025-01-26 11:44:34 -08:00
pablonyx
70795a4047 Sync status improvements (#3782)
* minor improvments / clarity

* additional comment for clarity

* typing

* quick updates to monitoring

* connector deletion

* quick nit

* fix typing

* update values

* quick nit

* functioning

* improvements to monitoring

* update

* minutes -> seconds
2025-01-26 17:35:26 +00:00
rkuo-danswer
d8a17a7238 try using a redis replica in some areas (#3748)
* try using a redis replica in some areas

* harden up replica usage

* comment

* slow down cloud dispatch temporarily

* add ignored syncing list back

* raise multiplier to 8

* comment out per tenant code (no longer used by fanout)

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-26 03:48:25 +00:00
Yuhong Sun
cbf98c0128 Fix Seeding Link for Support Use Case (#3784) 2025-01-25 19:39:36 -08:00
411 changed files with 20236 additions and 4671 deletions

View File

@@ -8,6 +8,8 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
version: v3.17.0
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
uses: helm/chart-testing-action@v2.7.0
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,22 +37,6 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -62,7 +46,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
uses: helm/kind-action@v1.12.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -39,6 +39,12 @@ env:
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
# Sharepoint
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

4
.gitignore vendored
View File

@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json

View File

@@ -52,3 +52,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False
# Agent Search configs # TODO: Remove give proper namings
AGENT_RETRIEVAL_STATS=False # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20

View File

@@ -124,7 +124,7 @@ There are two editions of Onyx:
To try the Onyx Enterprise Edition:
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
## 💡 Contributing

View File

@@ -0,0 +1,36 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -0,0 +1,80 @@
"""foreign key input prompts
Revision ID: 33ea50e88f24
Revises: a6df6b88ef81
Create Date: 2025-01-29 10:54:22.141765
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "33ea50e88f24"
down_revision = "a6df6b88ef81"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Safely drop constraints if exists
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
"""
)
op.execute(
"""
ALTER TABLE inputprompt__user
DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
"""
)
# Recreate with ON DELETE CASCADE
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop the new FKs with ondelete
op.drop_constraint(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
op.drop_constraint(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
type_="foreignkey",
)
# Recreate them without cascading
op.create_foreign_key(
"inputprompt__user_input_prompt_id_fkey",
"inputprompt__user",
"inputprompt",
["input_prompt_id"],
["id"],
)
op.create_foreign_key(
"inputprompt__user_user_id_fkey",
"inputprompt__user",
"user",
["user_id"],
["id"],
)

View File

@@ -0,0 +1,37 @@
"""lowercase_user_emails
Revision ID: 4d58345da04a
Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
from alembic import op
from sqlalchemy.sql import text
# revision identifiers, used by Alembic.
revision = "4d58345da04a"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get database connection
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:
# Cannot restore original case of emails
pass

View File

@@ -0,0 +1,107 @@
"""agent_tracking
Revision ID: 98a5008d8711
Revises: 2f80c6a2550f
Create Date: 2025-01-29 17:00:00.000001
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision = "98a5008d8711"
down_revision = "2f80c6a2550f"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"agent__search_metrics",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("agent_type", sa.String(), nullable=False),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("base_duration_s", sa.Float(), nullable=False),
sa.Column("full_duration_s", sa.Float(), nullable=False),
sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
# Create sub_question table
op.create_table(
"agent__sub_question",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_question", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
sa.Column("sub_answer", sa.Text),
sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True),
sa.Column("level", sa.Integer(), nullable=False),
sa.Column("level_question_num", sa.Integer(), nullable=False),
)
# Create sub_query table
op.create_table(
"agent__sub_query",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column(
"parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id")
),
sa.Column(
"chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id")
),
sa.Column("sub_query", sa.Text),
sa.Column(
"time_created", sa.DateTime(timezone=True), server_default=sa.func.now()
),
)
# Create sub_query__search_doc association table
op.create_table(
"agent__sub_query__search_doc",
sa.Column(
"sub_query_id",
sa.Integer,
sa.ForeignKey("agent__sub_query.id"),
primary_key=True,
),
sa.Column(
"search_doc_id",
sa.Integer,
sa.ForeignKey("search_doc.id"),
primary_key=True,
),
)
op.add_column(
"chat_message",
sa.Column(
"refined_answer_improvement",
sa.Boolean(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_message", "refined_answer_improvement")
op.drop_table("agent__sub_query__search_doc")
op.drop_table("agent__sub_query")
op.drop_table("agent__sub_question")
op.drop_table("agent__search_metrics")

View File

@@ -0,0 +1,29 @@
"""remove recent assistants
Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_column("user", "recent_assistants")
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
),
)

View File

@@ -0,0 +1,76 @@
"""add default slack channel config
Revision ID: eaa3b5593925
Revises: 98a5008d8711
Create Date: 2025-02-03 18:07:56.552526
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eaa3b5593925"
down_revision = "98a5008d8711"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add is_default column
op.add_column(
"slack_channel_config",
sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"),
)
op.create_index(
"ix_slack_channel_config_slack_bot_id_default",
"slack_channel_config",
["slack_bot_id", "is_default"],
unique=True,
postgresql_where=sa.text("is_default IS TRUE"),
)
# Create default channel configs for existing slack bots without one
conn = op.get_bind()
slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall()
for slack_bot in slack_bots:
slack_bot_id = slack_bot[0]
existing_default = conn.execute(
sa.text(
"SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE"
),
{"bot_id": slack_bot_id},
).fetchone()
if not existing_default:
conn.execute(
sa.text(
"""
INSERT INTO slack_channel_config (
slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default
) VALUES (
:bot_id, NULL,
'{"channel_name": null, "respond_member_group_list": [], "answer_filters": [], "follow_up_tags": []}',
FALSE, TRUE
)
"""
),
{"bot_id": slack_bot_id},
)
def downgrade() -> None:
# Delete default slack channel configs
conn = op.get_bind()
conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE"))
# Remove index
op.drop_index(
"ix_slack_channel_config_slack_bot_id_default",
table_name="slack_channel_config",
)
# Remove is_default column
op.drop_column("slack_channel_config", "is_default")

View File

@@ -32,6 +32,7 @@ def perform_ttl_management_task(
@celery_app.task(
name="check_ttl_management_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
@@ -56,6 +57,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
@celery_app.task(
name="autogenerate_usage_report_task",
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:

View File

@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -359,6 +367,12 @@ def confluence_doc_sync(
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
@@ -367,4 +381,5 @@ def confluence_doc_sync(
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -14,6 +14,8 @@ def _build_group_member_email_map(
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
@@ -33,10 +35,17 @@ def _build_group_member_email_map(
logger.warning(f"user result missing email field: {user_result}")
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)
if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
return group_member_emails

View File

@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -28,7 +29,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -44,6 +45,12 @@ def gmail_doc_sync(
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -42,24 +43,22 @@ def _fetch_permissions_for_permission_ids(
if not permission_info or not doc_id:
return []
# Check cache first for all permission IDs
permissions = [
_PERMISSION_ID_PERMISSION_MAP[pid]
for pid in permission_ids
if pid in _PERMISSION_ID_PERMISSION_MAP
]
# If we found all permissions in cache, return them
if len(permissions) == len(permission_ids):
return permissions
owner_email = permission_info.get("owner_email")
drive_service = get_drive_service(
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# Otherwise, fetch all permissions and update cache
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
@@ -69,7 +68,6 @@ def _fetch_permissions_for_permission_ids(
)
permissions_for_doc_id = []
# Update cache and return all permissions
for permission in fetched_permissions:
permissions_for_doc_id.append(permission)
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
@@ -131,7 +129,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -149,6 +147,12 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -14,7 +15,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -114,7 +123,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair,
cc_pair=cc_pair, callback=callback
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]

View File

@@ -80,7 +80,7 @@ def oneoff_standard_answers(
def _handle_standard_answers(
message_info: SlackMessageInfo,
receiver_ids: list[str] | None,
slack_channel_config: SlackChannelConfig | None,
slack_channel_config: SlackChannelConfig,
prompt: Prompt | None,
logger: OnyxLoggingAdapter,
client: WebClient,
@@ -94,13 +94,10 @@ def _handle_standard_answers(
Returns True if standard answers are found to match the user's message and therefore,
we still need to respond to the users.
"""
# if no channel config, then no standard answers are configured
if not slack_channel_config:
return False
slack_thread_id = message_info.thread_to_respond
configured_standard_answer_categories = (
slack_channel_config.standard_answer_categories if slack_channel_config else []
slack_channel_config.standard_answer_categories
)
configured_standard_answers = set(
[

View File

@@ -10,6 +10,7 @@ from fastapi import Response
from ee.onyx.auth.users import decode_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import extract_tenant_from_api_key_header
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.engine import is_valid_schema_name
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from shared_configs.configs import MULTI_TENANT
@@ -43,6 +44,7 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) Reset token cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
@@ -90,3 +92,12 @@ async def _get_tenant_id_from_request(
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
finally:
# As a final step, check for explicit tenant_id cookie
tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME)
if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie):
return tenant_id_cookie
# If we've reached this point, return the default schema
return POSTGRES_DEFAULT_SCHEMA

View File

@@ -286,6 +286,7 @@ def prepare_authorization_request(
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
@@ -554,6 +555,7 @@ def handle_google_drive_oauth_callback(
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)

View File

@@ -179,6 +179,7 @@ def handle_simplified_chat_message(
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
use_agentic_search=chat_message_req.use_agentic_search,
)
packets = stream_chat_message_objects(
@@ -301,6 +302,7 @@ def handle_send_message_simple_with_history(
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
use_agentic_search=req.use_agentic_search,
)
packets = stream_chat_message_objects(

View File

@@ -57,6 +57,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# Last element is the new query. All previous elements are historical context
@@ -71,6 +74,8 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
class SimpleDoc(BaseModel):
@@ -120,9 +125,12 @@ class OneShotQARequest(ChunkContext):
# will also disable Thread-based Rewording if specified
query_override: str | None = None
# If True, skips generative an AI response to the search query
# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False
# If True, uses agentic search instead of basic search
use_agentic_search: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "OneShotQARequest":
if self.persona_override_config is None and self.persona_id is None:

View File

@@ -196,6 +196,8 @@ def get_answer_stream(
retrieval_details=query_request.retrieval_options,
rerank_settings=query_request.rerank_settings,
db_session=db_session,
use_agentic_search=query_request.use_agentic_search,
skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
)
packets = stream_chat_message_objects(

View File

@@ -34,6 +34,7 @@ from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -111,6 +112,7 @@ async def login_as_anonymous_user(
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,

View File

@@ -58,6 +58,7 @@ class UserGroup(BaseModel):
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
),
access_type=cc_pair_relationship.cc_pair.access_type,
)
for cc_pair_relationship in user_group_model.cc_pair_relationships
if cc_pair_relationship.is_current

View File

@@ -0,0 +1,97 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_graph_builder() -> StateGraph:
graph = StateGraph(
state_schema=BasicState,
input=BasicInput,
output=BasicOutput,
)
### Add nodes ###
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="tool_call",
action=tool_call,
)
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key=END,
)
return graph
def should_continue(state: BasicState) -> str:
return (
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "tool_call"
)
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(_unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
config, _ = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,
search_request=SearchRequest(query="How does onyx use FastAPI?"),
)
compiled_graph.invoke(input, config={"metadata": {"config": config}})

View File

@@ -0,0 +1,35 @@
from typing import TypedDict
from langchain_core.messages import AIMessageChunk
from pydantic import BaseModel
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.
## Graph Input State
class BasicInput(BaseModel):
# Langgraph needs a nonempty input, but we pass in all static
# data through a RunnableConfig.
_unused: bool = True
## Graph Output State
class BasicOutput(TypedDict):
tool_call_chunk: AIMessageChunk
## Graph State
class BasicState(
BasicInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
):
pass

View File

@@ -0,0 +1,64 @@
from collections.abc import Iterator
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
PassThroughAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.utils.logger import setup_logger
logger = setup_logger()
def process_llm_stream(
messages: Iterator[BaseMessage],
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")
if final_search_results and displayed_search_results:
answer_handler: AnswerResponseHandler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
else:
answer_handler = PassThroughAnswerResponseHandler()
full_answer = ""
# This stream will be the llm answer if no tool is chosen. When a tool is chosen,
# the stream will contain AIMessageChunks with tool call information.
for message in messages:
answer_piece = message.content
if not isinstance(answer_piece, str):
# this is only used for logging, so fine to
# just add the string representation
answer_piece = str(answer_piece)
full_answer += answer_piece
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
tool_call_chunk += message # type: ignore
elif should_stream_answer:
for response_part in answer_handler.handle_response_part(message, []):
write_custom_event(
"basic_response",
response_part,
writer,
)
logger.debug(f"Full answer: {full_answer}")
return cast(AIMessageChunk, tool_call_chunk)

View File

@@ -0,0 +1,21 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
class CoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
base_question: str = ""
log_messages: Annotated[list[str], add] = []
class SubgraphCoreState(BaseModel):
"""
This is the core state that is shared across all subgraphs.
"""
log_messages: Annotated[list[str], add]

View File

@@ -0,0 +1,31 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_retrieval(state: SubQuestionAnsweringInput) -> Send | Hashable:
"""
LangGraph edge to send a sub-question to the expanded retrieval.
"""
edge_start_time = datetime.now()
return Send(
"initial_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
base_search=False,
sub_question_id=state.question_id,
log_messages=[f"{edge_start_time} -- Sending to expanded retrieval"],
),
)

View File

@@ -0,0 +1,137 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.edges import (
send_to_expanded_retrieval,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_query_graph_builder() -> StateGraph:
"""
LangGraph sub-graph builder for the initial individual sub-answer generation.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# The sub-graph that executes the expanded retrieval process for a sub-question
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="initial_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# The node that ingests the retrieved documents and puts them into the proper
# state keys.
graph.add_node(
node="ingest_retrieval",
action=ingest_retrieved_documents,
)
# The node that generates the sub-answer
graph.add_node(
node="generate_sub_answer",
action=generate_sub_answer,
)
# The node that checks the sub-answer
graph.add_node(
node="answer_check",
action=check_sub_answer,
)
# The node that formats the sub-answer for the following initial answer generation
graph.add_node(
node="format_answer",
action=format_sub_answer,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_retrieval,
path_map=["initial_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="initial_sub_question_expanded_retrieval",
end_key="ingest_retrieval",
)
graph.add_edge(
start_key="ingest_retrieval",
end_key="generate_sub_answer",
)
graph.add_edge(
start_key="generate_sub_answer",
end_key="answer_check",
)
graph.add_edge(
start_key="answer_check",
end_key="format_answer",
)
graph.add_edge(
start_key="format_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
):
logger.debug(thing)

View File

@@ -0,0 +1,75 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerCheckUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
def check_sub_answer(
state: AnswerQuestionState, config: RunnableConfig
) -> SubQuestionAnswerCheckUpdate:
"""
LangGraph node to check the quality of the sub-answer. The answer
is represented as a boolean value.
"""
node_start_time = datetime.now()
level, question_num = parse_question_id(state.question_id)
if state.answer == UNKNOWN_ANSWER:
return SubQuestionAnswerCheckUpdate(
answer_quality=False,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result="unknown answer",
)
],
)
msg = [
HumanMessage(
content=SUB_ANSWER_CHECK_PROMPT.format(
question=state.question,
base_answer=state.answer,
)
)
]
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
response = list(
fast_llm.stream(
prompt=msg,
)
)
quality_str: str = merge_message_runs(response, chunk_separator="")[0].content
answer_quality = "yes" in quality_str.lower()
return SubQuestionAnswerCheckUpdate(
answer_quality=answer_quality,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="check sub answer",
node_start_time=node_start_time,
result=f"Answer quality: {quality_str}",
)
],
)

View File

@@ -0,0 +1,30 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
def format_sub_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
"""
LangGraph node to generate the sub-answer format.
"""
return AnswerQuestionOutput(
answer_results=[
SubQuestionAnswerResults(
question=state.question,
question_id=state.question_id,
verified_high_quality=state.answer_quality,
answer=state.answer,
sub_query_retrieval_results=state.expanded_retrieval_results,
verified_reranked_documents=state.verified_reranked_documents,
context_documents=state.context_documents,
cited_documents=state.cited_documents,
sub_question_retrieval_stats=state.sub_question_retrieval_stats,
)
],
)

View File

@@ -0,0 +1,137 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnswerGenerationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_sub_question_answer_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_sub_answer(
state: AnswerQuestionState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> SubQuestionAnswerGenerationUpdate:
"""
LangGraph node to generate a sub-answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = state.question
state.verified_reranked_documents
level, question_num = parse_question_id(state.question_id)
context_docs = state.context_documents[:AGENT_MAX_ANSWER_CONTEXT_DOCS]
persona_contextualized_prompt = get_persona_agent_prompt_expressions(
graph_config.inputs.search_request.persona
).contextualized_prompt
if len(context_docs) == 0:
answer_str = NO_RECOVERED_DOCS
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=answer_str,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
else:
fast_llm = graph_config.tooling.fast_llm
msg = build_sub_question_answer_prompt(
question=question,
original_question=graph_config.inputs.search_request.query,
docs=context_docs,
persona_specification=persona_contextualized_prompt,
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
for message in fast_llm.stream(
prompt=msg,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
logger.debug(
f"Average dispatch time: {sum(dispatch_timings) / len(dispatch_timings)}"
)
answer_citation_ids = get_answer_citation_ids(answer_str)
cited_documents = [
context_docs[id] for id in answer_citation_ids if id < len(context_docs)
]
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_ANSWER,
level=level,
level_question_num=question_num,
)
write_custom_event("stream_finished", stop_event, writer)
return SubQuestionAnswerGenerationUpdate(
answer=answer_str,
cited_documents=cited_documents,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate individual sub answer",
node_name="generate sub answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -0,0 +1,25 @@
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionRetrievalIngestionUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
def ingest_retrieved_documents(
state: ExpandedRetrievalOutput,
) -> SubQuestionRetrievalIngestionUpdate:
"""
LangGraph node to ingest the retrieved documents to format it for the sub-answer.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = [AgentChunkRetrievalStats()]
return SubQuestionRetrievalIngestionUpdate(
expanded_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
context_documents=state.expanded_retrieval_result.context_documents,
sub_question_retrieval_stats=sub_question_retrieval_stats,
)

View File

@@ -0,0 +1,75 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
## Update States
class SubQuestionAnswerCheckUpdate(LoggerUpdate, BaseModel):
answer_quality: bool = False
log_messages: list[str] = []
class SubQuestionAnswerGenerationUpdate(LoggerUpdate, BaseModel):
answer: str = ""
log_messages: list[str] = []
cited_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
# answer_stat: AnswerStats
class SubQuestionRetrievalIngestionUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_results: list[QueryRetrievalResult] = []
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
## Graph Input State
class SubQuestionAnsweringInput(SubgraphCoreState):
question: str = ""
question_id: str = (
"" # 0_0 is original question, everything else is <level>_<question_num>.
)
# level 0 is original question and first decomposition, level 1 is follow up, etc
# question_num is a unique number per original question per level.
## Graph State
class AnswerQuestionState(
SubQuestionAnsweringInput,
SubQuestionAnswerGenerationUpdate,
SubQuestionAnswerCheckUpdate,
SubQuestionRetrievalIngestionUpdate,
):
pass
## Graph Output State
class AnswerQuestionOutput(LoggerUpdate, BaseModel):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
answer_results: Annotated[list[SubQuestionAnswerResults], add] = []

View File

@@ -0,0 +1,50 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
we send empty answers to the initial answer generation, and that answer would be generated
solely based on the documents retrieved for the original question.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,96 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.generate_initial_answer import (
generate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.nodes.validate_initial_answer import (
validate_initial_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.graph_builder import (
generate_sub_answers_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.graph_builder import (
retrieve_orig_question_docs_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def generate_initial_answer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the initial answer generation.
"""
graph = StateGraph(
state_schema=SubQuestionRetrievalState,
input=SubQuestionRetrievalInput,
)
# The sub-graph that generates the initial sub-answers
generate_sub_answers = generate_sub_answers_graph_builder().compile()
graph.add_node(
node="generate_sub_answers_subgraph",
action=generate_sub_answers,
)
# The sub-graph that retrieves the original question documents. This is run
# in parallel with the sub-answer generation process
retrieve_orig_question_docs = retrieve_orig_question_docs_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph_wrapper",
action=retrieve_orig_question_docs,
)
# Node that generates the initial answer using the results of the previous
# two sub-graphs
graph.add_node(
node="generate_initial_answer",
action=generate_initial_answer,
)
# Node that validates the initial answer
graph.add_node(
node="validate_initial_answer",
action=validate_initial_answer,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="retrieve_orig_question_docs_subgraph_wrapper",
)
graph.add_edge(
start_key=START,
end_key="generate_sub_answers_subgraph",
)
# Wait for both, the original question docs and the sub-answers to be generated before proceeding
graph.add_edge(
start_key=[
"retrieve_orig_question_docs_subgraph_wrapper",
"generate_sub_answers_subgraph",
],
end_key="generate_initial_answer",
)
graph.add_edge(
start_key="generate_initial_answer",
end_key="validate_initial_answer",
)
graph.add_edge(
start_key="validate_initial_answer",
end_key=END,
)
return graph

View File

@@ -0,0 +1,313 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.operations import (
calculate_initial_agent_stats,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.context.search.models import InferenceSection
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_initial_answer(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialAnswerUpdate:
"""
LangGraph node to generate the initial answer, using the initial sub-questions/sub-answers and the
documents retrieved for the original question.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
sub_questions_cited_documents = state.cited_documents
orig_question_retrieval_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
orig_question_retrieval_documents
):
if original_doc_number not in sub_questions_cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs) < AGENT_MAX_ANSWER_CONTEXT_DOCS
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
sub_questions: list[str] = []
streamed_documents = (
relevant_docs
if len(relevant_docs) > 0
else state.orig_question_retrieved_documents[:15]
)
# Use the query info from the base document retrieval
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=streamed_documents,
final_context_sections=streamed_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=0,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(relevant_docs) == 0:
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=UNKNOWN_ANSWER,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
dispatch_main_answer_stop_info(0, writer)
answer = UNKNOWN_ANSWER
initial_agent_stats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
else:
sub_question_answer_results = state.sub_question_results
# Collect the sub-questions and sub-answers and construct an appropriate
# prompt string.
# Consider replacing by a function.
answered_sub_questions: list[str] = []
all_sub_questions: list[str] = [] # Separate list for tracking all questions
for idx, sub_question_answer_result in enumerate(
sub_question_answer_results, start=1
):
all_sub_questions.append(sub_question_answer_result.question)
is_valid_answer = (
sub_question_answer_result.verified_high_quality
and sub_question_answer_result.answer
and sub_question_answer_result.answer != UNKNOWN_ANSWER
)
if is_valid_answer:
answered_sub_questions.append(
SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_result.question,
sub_answer=sub_question_answer_result.answer,
sub_question_num=idx,
)
)
sub_question_answer_str = (
"\n\n------\n\n".join(answered_sub_questions)
if answered_sub_questions
else ""
)
# Use the appropriate prompt based on whether there are sub-questions.
base_prompt = (
INITIAL_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_questions
else INITIAL_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
sub_questions = all_sub_questions # Replace the original assignment
model = graph_config.tooling.fast_llm
doc_context = format_docs(relevant_docs)
doc_context = trim_prompt_piece(
config=model.config,
prompt_piece=doc_context,
reserved_str=(
base_prompt
+ sub_question_answer_str
+ prompt_enrichment_components.persona_prompts.contextualized_prompt
+ prompt_enrichment_components.history
+ prompt_enrichment_components.date_str
),
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=doc_context,
persona_specification=prompt_enrichment_components.persona_prompts.contextualized_prompt,
history=prompt_enrichment_components.history,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for initial answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(0, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
initial_agent_stats = calculate_initial_agent_stats(
state.sub_question_results, state.orig_question_retrieval_stats
)
logger.debug(
f"\n\nYYYYY--Sub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n"
)
if initial_agent_stats:
logger.debug(initial_agent_stats.original_question)
logger.debug(initial_agent_stats.sub_questions)
logger.debug(initial_agent_stats.agent_effectiveness)
agent_base_end_time = datetime.now()
if agent_base_end_time and state.agent_start_time:
duration_s = (agent_base_end_time - state.agent_start_time).total_seconds()
else:
duration_s = None
agent_base_metrics = AgentBaseMetrics(
num_verified_documents_total=len(relevant_docs),
num_verified_documents_core=state.orig_question_retrieval_stats.verified_count,
verified_avg_score_core=state.orig_question_retrieval_stats.verified_avg_scores,
num_verified_documents_base=initial_agent_stats.sub_questions.get(
"num_verified_documents"
),
verified_avg_score_base=initial_agent_stats.sub_questions.get(
"verified_avg_score"
),
base_doc_boost_factor=initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio"
),
support_boost_factor=initial_agent_stats.agent_effectiveness.get(
"support_ratio"
),
duration_s=duration_s,
)
return InitialAnswerUpdate(
initial_answer=answer,
initial_agent_stats=initial_agent_stats,
generated_sub_questions=sub_questions,
agent_base_end_time=agent_base_end_time,
agent_base_metrics=agent_base_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="generate initial answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def validate_initial_answer(
state: SubQuestionRetrievalState,
) -> InitialAnswerQualityUpdate:
"""
Check whether the initial answer sufficiently addresses the original user question.
"""
node_start_time = datetime.now()
logger.debug(
f"--------{node_start_time}--------Checking for base answer validity - for not set True/False manually"
)
verdict = True
return InitialAnswerQualityUpdate(
initial_answer_quality_eval=verdict,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate initial answer",
node_name="validate initial answer",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -0,0 +1,51 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerQualityUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionRetrievalInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionRetrievalState(
# This includes the core state
SubQuestionRetrievalInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
InitialAnswerQualityUpdate,
ExploratorySearchUpdate,
):
base_raw_search_result: Annotated[list[QuestionRetrievalResult], add]
## Graph Output State
class SubQuestionRetrievalOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,48 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_sub_question_subgraphs",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,81 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.graph_builder import (
answer_query_graph_builder,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.edges import (
parallelize_initial_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.decompose_orig_question import (
decompose_orig_question,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.nodes.format_initial_sub_answers import (
format_initial_sub_answers,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_sub_answers.states import (
SubQuestionAnsweringState,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def generate_sub_answers_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the initial sub-answer generation process.
It generates the initial sub-questions and produces the answers.
"""
graph = StateGraph(
state_schema=SubQuestionAnsweringState,
input=SubQuestionAnsweringInput,
)
# Decompose the original question into sub-questions
graph.add_node(
node="decompose_orig_question",
action=decompose_orig_question,
)
# The sub-graph that executes the initial sub-question answering for
# each of the sub-questions.
answer_sub_question_subgraphs = answer_query_graph_builder().compile()
graph.add_node(
node="answer_sub_question_subgraphs",
action=answer_sub_question_subgraphs,
)
# Node that collects and formats the initial sub-question answers
graph.add_node(
node="format_initial_sub_question_answers",
action=format_initial_sub_answers,
)
graph.add_edge(
start_key=START,
end_key="decompose_orig_question",
)
graph.add_conditional_edges(
source="decompose_orig_question",
path=parallelize_initial_sub_question_answering,
path_map=["answer_sub_question_subgraphs"],
)
graph.add_edge(
start_key=["answer_sub_question_subgraphs"],
end_key="format_initial_sub_question_answers",
)
graph.add_edge(
start_key="format_initial_sub_question_answers",
end_key=END,
)
return graph

View File

@@ -0,0 +1,153 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.prompts.agent_search import (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH,
)
from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def decompose_orig_question(
state: SubQuestionRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> InitialQuestionDecompositionUpdate:
"""
LangGraph node to decompose the original question into sub-questions.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
perform_initial_search_decomposition = (
graph_config.behavior.perform_initial_search_decomposition
)
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
history = build_history_prompt(graph_config, question)
# Use the initial search results to inform the decomposition
agent_start_time = datetime.now()
# Initial search to inform decomposition. Just get top 3 fits
if perform_initial_search_decomposition:
# Due to unfortunate state representation in LangGraph, we need here to double check that the retrieval has
# happened prior to this point, allowing silent failure here since it is not critical for decomposition in
# all queries.
if not state.exploratory_search_results:
logger.error("Initial search for decomposition failed")
sample_doc_str = "\n\n".join(
[
doc.combined_content
for doc in state.exploratory_search_results[
:AGENT_NUM_DOCS_FOR_DECOMPOSITION
]
]
)
decomposition_prompt = (
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS_AFTER_SEARCH.format(
question=question, sample_doc_str=sample_doc_str, history=history
)
)
else:
decomposition_prompt = INITIAL_QUESTION_DECOMPOSITION_PROMPT.format(
question=question, history=history
)
# Start decomposition
msg = [HumanMessage(content=decomposition_prompt)]
# Send the initial question as a subquestion with number 0
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=question,
level=0,
level_question_num=0,
),
writer,
)
# dispatches custom events for subquestion tokens, adding in subquestion ids.
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(0, writer)
)
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.SUB_QUESTIONS,
level=0,
)
write_custom_event("stream_finished", stop_event, writer)
deomposition_response = merge_content(*streamed_tokens)
# this call should only return strings. Commenting out for efficiency
# assert [type(tok) == str for tok in streamed_tokens]
# use no-op cast() instead of str() which runs code
# list_of_subquestions = clean_and_parse_list_string(cast(str, response))
list_of_subqs = cast(str, deomposition_response).split("\n")
decomp_list: list[str] = [sq.strip() for sq in list_of_subqs if sq.strip() != ""]
return InitialQuestionDecompositionUpdate(
initial_sub_questions=decomp_list,
agent_start_time=agent_start_time,
agent_refined_start_time=None,
agent_refined_end_time=None,
agent_refined_metrics=AgentRefinedMetrics(
refined_doc_boost_factor=None,
refined_question_boost_factor=None,
duration_s=None,
),
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="decompose original question",
node_start_time=node_start_time,
result=f"decomposed original question into {len(decomp_list)} subquestions",
)
],
)

View File

@@ -0,0 +1,50 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def format_initial_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to format the answers to the initial sub-questions, including
deduping verified documents and context documents.
"""
node_start_time = datetime.now()
documents = []
context_documents = []
cited_documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
context_documents.extend(answer_result.context_documents)
cited_documents.extend(answer_result.cited_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
context_documents=dedup_inference_sections(context_documents, []),
cited_documents=dedup_inference_sections(cited_documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="initial - generate sub answers",
node_name="format initial sub answers",
node_start_time=node_start_time,
result="",
)
],
)

View File

@@ -0,0 +1,34 @@
from typing import TypedDict
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.states import (
InitialAnswerUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
InitialQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.context.search.models import InferenceSection
### States ###
class SubQuestionAnsweringInput(CoreState):
exploratory_search_results: list[InferenceSection]
## Graph State
class SubQuestionAnsweringState(
# This includes the core state
SubQuestionAnsweringInput,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
):
pass
## Graph Output State
class SubQuestionAnsweringOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,81 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_input import (
format_orig_question_search_input,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.nodes.format_orig_question_search_output import (
format_orig_question_search_output,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchInput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchOutput,
)
from onyx.agents.agent_search.deep_search.initial.retrieve_orig_question_docs.states import (
BaseRawSearchState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
def retrieve_orig_question_docs_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the retrieval of documents
that are relevant to the original question. This is
largely a wrapper around the expanded retrieval process to
ensure parallelism with the sub-question answer process.
"""
graph = StateGraph(
state_schema=BaseRawSearchState,
input=BaseRawSearchInput,
output=BaseRawSearchOutput,
)
### Add nodes ###
# Format the original question search output
graph.add_node(
node="format_orig_question_search_output",
action=format_orig_question_search_output,
)
# The sub-graph that executes the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="retrieve_orig_question_docs_subgraph",
action=expanded_retrieval,
)
# Format the original question search input
graph.add_node(
node="format_orig_question_search_input",
action=format_orig_question_search_input,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="format_orig_question_search_input")
graph.add_edge(
start_key="format_orig_question_search_input",
end_key="retrieve_orig_question_docs_subgraph",
)
graph.add_edge(
start_key="retrieve_orig_question_docs_subgraph",
end_key="format_orig_question_search_output",
)
graph.add_edge(
start_key="format_orig_question_search_output",
end_key=END,
)
return graph
if __name__ == "__main__":
pass

View File

@@ -0,0 +1,28 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_input(
state: CoreState, config: RunnableConfig
) -> ExpandedRetrievalInput:
"""
LangGraph node to format the search input for the original question.
"""
logger.debug("generate_raw_search_data")
graph_config = cast(GraphConfig, config["metadata"]["config"])
return ExpandedRetrievalInput(
question=graph_config.inputs.search_request.query,
base_search=True,
sub_question_id=None, # This graph is always and only used for the original question
log_messages=[],
)

View File

@@ -0,0 +1,30 @@
from onyx.agents.agent_search.deep_search.main.states import OrigQuestionRetrievalUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.utils.logger import setup_logger
logger = setup_logger()
def format_orig_question_search_output(
state: ExpandedRetrievalOutput,
) -> OrigQuestionRetrievalUpdate:
"""
LangGraph node to format the search result for the original question into the
proper format.
"""
sub_question_retrieval_stats = state.expanded_retrieval_result.retrieval_stats
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
else:
sub_question_retrieval_stats = sub_question_retrieval_stats
return OrigQuestionRetrievalUpdate(
orig_question_verified_reranked_documents=state.expanded_retrieval_result.verified_reranked_documents,
orig_question_sub_query_retrieval_results=state.expanded_retrieval_result.expanded_query_results,
orig_question_retrieved_documents=state.retrieved_documents,
orig_question_retrieval_stats=sub_question_retrieval_stats,
log_messages=[],
)

View File

@@ -0,0 +1,29 @@
from onyx.agents.agent_search.deep_search.main.states import (
OrigQuestionRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
## Graph Input State
class BaseRawSearchInput(ExpandedRetrievalInput):
pass
## Graph Output State
class BaseRawSearchOutput(OrigQuestionRetrievalUpdate):
"""
This is a list of results even though each call of this subgraph only returns one result.
This is because if we parallelize the answer query subgraph, there will be multiple
results in a list so the add operator is used to add them together.
"""
# base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
## Graph State
class BaseRawSearchState(
BaseRawSearchInput, BaseRawSearchOutput, OrigQuestionRetrievalUpdate
):
pass

View File

@@ -0,0 +1,113 @@
from collections.abc import Hashable
from datetime import datetime
from typing import cast
from typing import Literal
from langchain_core.runnables import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to agent search.
"""
agent_config = cast(GraphConfig, config["metadata"]["config"])
if state.tool_choice is not None:
if (
agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and state.tool_choice.tool.name == agent_config.tooling.search_tool.name
):
return "start_agent_search"
else:
return "tool_call"
else:
return "logging_node"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinemenEvalUpdate,
) -> Literal["create_refined_sub_questions", "logging_node"]:
if state.require_refined_answer_eval:
return "create_refined_sub_questions"
else:
return "logging_node"
def parallelize_refined_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.refined_sub_questions) > 0:
return [
Send(
"answer_refined_question_subgraphs",
SubQuestionAnsweringInput(
question=question_data.sub_question,
question_id=make_question_id(1, question_num),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Refined Sub-question Answering"
],
),
)
for question_num, question_data in state.refined_sub_questions.items()
]
else:
return [
Send(
"ingest_refined_sub_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]

View File

@@ -0,0 +1,265 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.graph_builder import (
generate_initial_answer_graph_builder,
)
from onyx.agents.agent_search.deep_search.main.edges import (
continue_to_refined_answer_or_end,
)
from onyx.agents.agent_search.deep_search.main.edges import (
parallelize_refined_sub_question_answering,
)
from onyx.agents.agent_search.deep_search.main.edges import (
route_initial_tool_choice,
)
from onyx.agents.agent_search.deep_search.main.nodes.compare_answers import (
compare_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.create_refined_sub_questions import (
create_refined_sub_questions,
)
from onyx.agents.agent_search.deep_search.main.nodes.decide_refinement_need import (
decide_refinement_need,
)
from onyx.agents.agent_search.deep_search.main.nodes.extract_entities_terms import (
extract_entities_terms,
)
from onyx.agents.agent_search.deep_search.main.nodes.generate_refined_answer import (
generate_refined_answer,
)
from onyx.agents.agent_search.deep_search.main.nodes.ingest_refined_sub_answers import (
ingest_refined_sub_answers,
)
from onyx.agents.agent_search.deep_search.main.nodes.persist_agent_results import (
persist_agent_results,
)
from onyx.agents.agent_search.deep_search.main.nodes.start_agent_search import (
start_agent_search,
)
from onyx.agents.agent_search.deep_search.main.states import MainInput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def main_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the main agent search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
# Prepare the tool input
graph.add_node(
node="prepare_tool_input",
action=prepare_tool_input,
)
# Choose the initial tool
graph.add_node(
node="initial_tool_choice",
action=llm_tool_choice,
)
# Call the tool, if required
graph.add_node(
node="tool_call",
action=tool_call,
)
# Use the tool response
graph.add_node(
node="basic_use_tool_response",
action=basic_use_tool_response,
)
# Start the agent search process
graph.add_node(
node="start_agent_search",
action=start_agent_search,
)
# The sub-graph for the initial answer generation
generate_initial_answer_subgraph = generate_initial_answer_graph_builder().compile()
graph.add_node(
node="generate_initial_answer_subgraph",
action=generate_initial_answer_subgraph,
)
# Create the refined sub-questions
graph.add_node(
node="create_refined_sub_questions",
action=create_refined_sub_questions,
)
# Subgraph for the refined sub-answer generation
answer_refined_question = answer_refined_query_graph_builder().compile()
graph.add_node(
node="answer_refined_question_subgraphs",
action=answer_refined_question,
)
# Ingest the refined sub-answers
graph.add_node(
node="ingest_refined_sub_answers",
action=ingest_refined_sub_answers,
)
# Node to generate the refined answer
graph.add_node(
node="generate_refined_answer",
action=generate_refined_answer,
)
# Early node to extract the entities and terms from the initial answer,
# This information is used to inform the creation the refined sub-questions
graph.add_node(
node="extract_entity_term",
action=extract_entities_terms,
)
# Decide if the answer needs to be refined (currently always true)
graph.add_node(
node="decide_refinement_need",
action=decide_refinement_need,
)
# Compare the initial and refined answers, and determine whether
# the refined answer is sufficiently better
graph.add_node(
node="compare_answers",
action=compare_answers,
)
# Log the results. This will log the stats as well as the answers, sub-questions, and sub-answers
graph.add_node(
node="logging_node",
action=persist_agent_results,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(
start_key="prepare_tool_input",
end_key="initial_tool_choice",
)
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["tool_call", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(
start_key="basic_use_tool_response",
end_key="logging_node",
)
graph.add_edge(
start_key="start_agent_search",
end_key="generate_initial_answer_subgraph",
)
graph.add_edge(
start_key="start_agent_search",
end_key="extract_entity_term",
)
# Wait for the initial answer generation and the entity/term extraction to be complete
# before deciding if a refinement is needed.
graph.add_edge(
start_key=["generate_initial_answer_subgraph", "extract_entity_term"],
end_key="decide_refinement_need",
)
graph.add_conditional_edges(
source="decide_refinement_need",
path=continue_to_refined_answer_or_end,
path_map=["create_refined_sub_questions", "logging_node"],
)
graph.add_conditional_edges(
source="create_refined_sub_questions",
path=parallelize_refined_sub_question_answering,
path_map=["answer_refined_question_subgraphs"],
)
graph.add_edge(
start_key="answer_refined_question_subgraphs",
end_key="ingest_refined_sub_answers",
)
graph.add_edge(
start_key="ingest_refined_sub_answers",
end_key="generate_refined_answer",
)
graph.add_edge(
start_key="generate_refined_answer",
end_key="compare_answers",
)
graph.add_edge(
start_key="compare_answers",
end_key="logging_node",
)
graph.add_edge(
start_key="logging_node",
end_key=END,
)
return graph
if __name__ == "__main__":
pass
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = main_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
search_request = SearchRequest(query="Who created Excel?")
graph_config = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = MainInput(
base_question=graph_config.inputs.search_request.query, log_messages=[]
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,36 @@
from pydantic import BaseModel
class RefinementSubQuestion(BaseModel):
sub_question: str
sub_question_id: str
verified: bool
answered: bool
answer: str
class AgentTimings(BaseModel):
base_duration_s: float | None
refined_duration_s: float | None
full_duration_s: float | None
class AgentBaseMetrics(BaseModel):
num_verified_documents_total: int | None
num_verified_documents_core: int | None
verified_avg_score_core: float | None
num_verified_documents_base: int | float | None
verified_avg_score_base: float | None = None
base_doc_boost_factor: float | None = None
support_boost_factor: float | None = None
duration_s: float | None = None
class AgentRefinedMetrics(BaseModel):
refined_doc_boost_factor: float | None = None
refined_question_boost_factor: float | None = None
duration_s: float | None = None
class AgentAdditionalMetrics(BaseModel):
pass

View File

@@ -0,0 +1,71 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.states import (
InitialRefinedAnswerComparisonUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
def compare_answers(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> InitialRefinedAnswerComparisonUpdate:
"""
LangGraph node to compare the initial answer and the refined answer and determine if the
refined answer is sufficiently better than the initial answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
initial_answer = state.initial_answer
refined_answer = state.refined_answer
compare_answers_prompt = INITIAL_REFINED_ANSWER_COMPARISON_PROMPT.format(
question=question, initial_answer=initial_answer, refined_answer=refined_answer
)
msg = [HumanMessage(content=compare_answers_prompt)]
# Get the rewritten queries in a defined format
model = graph_config.tooling.fast_llm
# no need to stream this
resp = model.invoke(msg)
refined_answer_improvement = (
isinstance(resp.content, str) and "yes" in resp.content.lower()
)
write_custom_event(
"refined_answer_improvement",
RefinedAnswerImprovement(
refined_answer_improvement=refined_answer_improvement,
),
writer,
)
return InitialRefinedAnswerComparisonUpdate(
refined_answer_improvement_eval=refined_answer_improvement,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="compare answers",
node_start_time=node_start_time,
result=f"Answer comparison: {refined_answer_improvement}",
)
],
)

View File

@@ -0,0 +1,131 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.deep_search.main.operations import (
dispatch_subquestion,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedQuestionDecompositionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
format_entity_term_extraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.prompts.agent_search import (
REFINEMENT_QUESTION_DECOMPOSITION_PROMPT,
)
from onyx.tools.models import ToolCallKickoff
def create_refined_sub_questions(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedQuestionDecompositionUpdate:
"""
LangGraph node to create refined sub-questions based on the initial answer, the history,
the entity term extraction results found earlier, and the sub-questions that were answered and failed.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
write_custom_event(
"start_refined_answer_creation",
ToolCallKickoff(
tool_name="agent_search_1",
tool_args={
"query": graph_config.inputs.search_request.query,
"answer": state.initial_answer,
},
),
writer,
)
node_start_time = datetime.now()
agent_refined_start_time = datetime.now()
question = graph_config.inputs.search_request.query
base_answer = state.initial_answer
history = build_history_prompt(graph_config, question)
# get the entity term extraction dict and properly format it
entity_retlation_term_extractions = state.entity_relation_term_extractions
entity_term_extraction_str = format_entity_term_extraction(
entity_retlation_term_extractions
)
initial_question_answers = state.sub_question_results
addressed_question_list = [
x.question for x in initial_question_answers if x.verified_high_quality
]
failed_question_list = [
x.question for x in initial_question_answers if not x.verified_high_quality
]
msg = [
HumanMessage(
content=REFINEMENT_QUESTION_DECOMPOSITION_PROMPT.format(
question=question,
history=history,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = graph_config.tooling.fast_llm
streamed_tokens = dispatch_separated(
model.stream(msg), dispatch_subquestion(1, writer)
)
response = merge_content(*streamed_tokens)
if isinstance(response, str):
parsed_response = [q for q in response.split("\n") if q.strip() != ""]
else:
raise ValueError("LLM response is not a string")
refined_sub_question_dict = {}
for sub_question_num, sub_question in enumerate(parsed_response):
refined_sub_question = RefinementSubQuestion(
sub_question=sub_question,
sub_question_id=make_question_id(1, sub_question_num + 1),
verified=False,
answered=False,
answer="",
)
refined_sub_question_dict[sub_question_num + 1] = refined_sub_question
return RefinedQuestionDecompositionUpdate(
refined_sub_questions=refined_sub_question_dict,
agent_refined_start_time=agent_refined_start_time,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="create refined sub questions",
node_start_time=node_start_time,
result=f"Created {len(refined_sub_question_dict)} refined sub questions",
)
],
)

View File

@@ -0,0 +1,47 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RequireRefinemenEvalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def decide_refinement_need(
state: MainState, config: RunnableConfig
) -> RequireRefinemenEvalUpdate:
"""
LangGraph node to decide if refinement is needed based on the initial answer and the question.
At present, we always refine.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
decision = True # TODO: just for current testing purposes
log_messages = [
get_langgraph_node_log_string(
graph_component="main",
node_name="decide refinement need",
node_start_time=node_start_time,
result=f"Refinement decision: {decision}",
)
]
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)

View File

@@ -0,0 +1,116 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import (
EntityTermExtractionUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import EntityExtractionResult
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
def extract_entities_terms(
state: MainState, config: RunnableConfig
) -> EntityTermExtractionUpdate:
"""
LangGraph node to extract entities, relationships, and terms from the initial search results.
This data is used to inform particularly the sub-questions that are created for the refined answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
if not graph_config.behavior.allow_refinement:
return EntityTermExtractionUpdate(
entity_relation_term_extractions=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
result="Refinement is not allowed",
)
],
)
# first four lines duplicates from generate_initial_answer
question = graph_config.inputs.search_request.query
initial_search_docs = state.exploratory_search_results[:NUM_EXPLORATORY_DOCS]
# start with the entity/term/extraction
doc_context = format_docs(initial_search_docs)
# Calculation here is only approximate
doc_context = trim_prompt_piece(
graph_config.tooling.fast_llm.config,
doc_context,
ENTITY_TERM_EXTRACTION_PROMPT
+ question
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
msg = [
HumanMessage(
content=ENTITY_TERM_EXTRACTION_PROMPT.format(
question=question, context=doc_context
)
+ ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE,
)
]
fast_llm = graph_config.tooling.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
)
return EntityTermExtractionUpdate(
entity_relation_term_extractions=entity_extraction_result.retrieved_entities_relationships,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="extract entities terms",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,339 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_content
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.main.states import (
RefinedAnswerUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
get_prompt_enrichment_components,
)
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.models import InferenceSection
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS,
)
from onyx.prompts.agent_search import (
SUB_QUESTION_ANSWER_TEMPLATE_REFINED,
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def generate_refined_answer(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> RefinedAnswerUpdate:
"""
LangGraph node to generate the refined answer.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
prompt_enrichment_components = get_prompt_enrichment_components(graph_config)
persona_contextualized_prompt = (
prompt_enrichment_components.persona_prompts.contextualized_prompt
)
verified_reranked_documents = state.verified_reranked_documents
sub_questions_cited_documents = state.cited_documents
original_question_verified_documents = (
state.orig_question_verified_reranked_documents
)
original_question_retrieved_documents = state.orig_question_retrieved_documents
consolidated_context_docs: list[InferenceSection] = sub_questions_cited_documents
counter = 0
for original_doc_number, original_doc in enumerate(
original_question_verified_documents
):
if original_doc_number not in sub_questions_cited_documents:
if (
counter <= AGENT_MIN_ORIG_QUESTION_DOCS
or len(consolidated_context_docs)
< 1.5
* AGENT_MAX_ANSWER_CONTEXT_DOCS # allow for larger context in refinement
):
consolidated_context_docs.append(original_doc)
counter += 1
# sort docs by their scores - though the scores refer to different questions
relevant_docs = dedup_inference_sections(
consolidated_context_docs, consolidated_context_docs
)
streaming_docs = (
relevant_docs
if len(relevant_docs) > 0
else original_question_retrieved_documents[:15]
)
query_info = get_query_info(state.orig_question_sub_query_retrieval_results)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
# stream refined answer docs, or original question docs if no relevant docs are found
relevance_list = relevance_from_docs(relevant_docs)
for tool_response in yield_search_responses(
query=question,
reranked_sections=streaming_docs,
final_context_sections=streaming_docs,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=1,
level_question_num=0, # 0, 0 is the base question
),
writer,
)
if len(verified_reranked_documents) > 0:
refined_doc_effectiveness = len(relevant_docs) / len(
verified_reranked_documents
)
else:
refined_doc_effectiveness = 10.0
sub_question_answer_results = state.sub_question_results
answered_sub_question_answer_list: list[str] = []
sub_questions: list[str] = []
initial_answered_sub_questions: set[str] = set()
refined_answered_sub_questions: set[str] = set()
for i, result in enumerate(sub_question_answer_results, 1):
question_level, _ = parse_question_id(result.question_id)
sub_questions.append(result.question)
if (
result.verified_high_quality
and result.answer
and result.answer != UNKNOWN_ANSWER
):
sub_question_type = "initial" if question_level == 0 else "refined"
question_set = (
initial_answered_sub_questions
if question_level == 0
else refined_answered_sub_questions
)
question_set.add(result.question)
answered_sub_question_answer_list.append(
SUB_QUESTION_ANSWER_TEMPLATE_REFINED.format(
sub_question=result.question,
sub_answer=result.answer,
sub_question_num=i,
sub_question_type=sub_question_type,
)
)
# Calculate efficiency
total_answered_questions = (
initial_answered_sub_questions | refined_answered_sub_questions
)
revision_question_efficiency = (
len(total_answered_questions) / len(initial_answered_sub_questions)
if initial_answered_sub_questions
else 10.0
if refined_answered_sub_questions
else 1.0
)
sub_question_answer_str = "\n\n------\n\n".join(
set(answered_sub_question_answer_list)
)
initial_answer = state.initial_answer or ""
# Choose appropriate prompt template
base_prompt = (
REFINED_ANSWER_PROMPT_W_SUB_QUESTIONS
if answered_sub_question_answer_list
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(relevant_docs)
relevant_docs_str = trim_prompt_piece(
model.config,
relevant_docs_str,
base_prompt
+ question
+ sub_question_answer_str
+ initial_answer
+ persona_contextualized_prompt
+ prompt_enrichment_components.history,
)
msg = [
HumanMessage(
content=base_prompt.format(
question=question,
history=prompt_enrichment_components.history,
answered_sub_questions=remove_document_citations(
sub_question_answer_str
),
relevant_docs=relevant_docs_str,
initial_answer=remove_document_citations(initial_answer)
if initial_answer
else None,
persona_specification=persona_contextualized_prompt,
date_prompt=prompt_enrichment_components.date_str,
)
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
for message in model.stream(msg):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
streamed_tokens.append(content)
logger.debug(
f"Average dispatch time for refined answer: {sum(dispatch_timings) / len(dispatch_timings)}"
)
dispatch_main_answer_stop_info(1, writer)
response = merge_content(*streamed_tokens)
answer = cast(str, response)
refined_agent_stats = RefinedAgentStats(
revision_doc_efficiency=refined_doc_effectiveness,
revision_question_efficiency=revision_question_efficiency,
)
logger.debug(f"\n\n---INITIAL ANSWER ---\n\n Answer:\n Agent: {initial_answer}")
logger.debug("-" * 10)
logger.debug(f"\n\n---REVISED AGENT ANSWER ---\n\n Answer:\n Agent: {answer}")
logger.debug("-" * 100)
if state.initial_agent_stats:
initial_doc_boost_factor = state.initial_agent_stats.agent_effectiveness.get(
"utilized_chunk_ratio", "--"
)
initial_support_boost_factor = (
state.initial_agent_stats.agent_effectiveness.get("support_ratio", "--")
)
num_initial_verified_docs = state.initial_agent_stats.original_question.get(
"num_verified_documents", "--"
)
initial_verified_docs_avg_score = (
state.initial_agent_stats.original_question.get("verified_avg_score", "--")
)
initial_sub_questions_verified_docs = (
state.initial_agent_stats.sub_questions.get("num_verified_documents", "--")
)
logger.debug("INITIAL AGENT STATS")
logger.debug(f"Document Boost Factor: {initial_doc_boost_factor}")
logger.debug(f"Support Boost Factor: {initial_support_boost_factor}")
logger.debug(f"Originally Verified Docs: {num_initial_verified_docs}")
logger.debug(
f"Originally Verified Docs Avg Score: {initial_verified_docs_avg_score}"
)
logger.debug(
f"Sub-Questions Verified Docs: {initial_sub_questions_verified_docs}"
)
if refined_agent_stats:
logger.debug("-" * 10)
logger.debug("REFINED AGENT STATS")
logger.debug(
f"Revision Doc Factor: {refined_agent_stats.revision_doc_efficiency}"
)
logger.debug(
f"Revision Question Factor: {refined_agent_stats.revision_question_efficiency}"
)
agent_refined_end_time = datetime.now()
if state.agent_refined_start_time:
agent_refined_duration = (
agent_refined_end_time - state.agent_refined_start_time
).total_seconds()
else:
agent_refined_duration = None
agent_refined_metrics = AgentRefinedMetrics(
refined_doc_boost_factor=refined_agent_stats.revision_doc_efficiency,
refined_question_boost_factor=refined_agent_stats.revision_question_efficiency,
duration_s=agent_refined_duration,
)
return RefinedAnswerUpdate(
refined_answer=answer,
refined_answer_quality=True, # TODO: replace this with the actual check value
refined_agent_stats=refined_agent_stats,
agent_refined_end_time=agent_refined_end_time,
agent_refined_metrics=agent_refined_metrics,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="generate refined answer",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,42 @@
from datetime import datetime
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.main.states import (
SubQuestionResultsUpdate,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
def ingest_refined_sub_answers(
state: AnswerQuestionOutput,
) -> SubQuestionResultsUpdate:
"""
LangGraph node to ingest and format the refined sub-answers and retrieved documents.
"""
node_start_time = datetime.now()
documents = []
answer_results = state.answer_results
for answer_result in answer_results:
documents.extend(answer_result.verified_reranked_documents)
return SubQuestionResultsUpdate(
# Deduping is done by the documents operator for the main graph
# so we might not need to dedup here
verified_reranked_documents=dedup_inference_sections(documents, []),
sub_question_results=answer_results,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="ingest refined answers",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,129 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.agents.agent_search.deep_search.main.operations import logger
from onyx.agents.agent_search.deep_search.main.states import MainOutput
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.db.chat import log_agent_metrics
from onyx.db.chat import log_agent_sub_question_results
def persist_agent_results(state: MainState, config: RunnableConfig) -> MainOutput:
"""
LangGraph node to persist the agent results, including agent logging data.
"""
node_start_time = datetime.now()
agent_start_time = state.agent_start_time
agent_base_end_time = state.agent_base_end_time
agent_refined_start_time = state.agent_refined_start_time
agent_refined_end_time = state.agent_refined_end_time
agent_end_time = agent_refined_end_time or agent_base_end_time
agent_base_duration = None
if agent_base_end_time and agent_start_time:
agent_base_duration = (agent_base_end_time - agent_start_time).total_seconds()
agent_refined_duration = None
if agent_refined_start_time and agent_refined_end_time:
agent_refined_duration = (
agent_refined_end_time - agent_refined_start_time
).total_seconds()
agent_full_duration = None
if agent_end_time and agent_start_time:
agent_full_duration = (agent_end_time - agent_start_time).total_seconds()
agent_type = "refined" if agent_refined_duration else "base"
agent_base_metrics = state.agent_base_metrics
agent_refined_metrics = state.agent_refined_metrics
combined_agent_metrics = CombinedAgentMetrics(
timings=AgentTimings(
base_duration_s=agent_base_duration,
refined_duration_s=agent_refined_duration,
full_duration_s=agent_full_duration,
),
base_metrics=agent_base_metrics,
refined_metrics=agent_refined_metrics,
additional_metrics=AgentAdditionalMetrics(),
)
persona_id = None
graph_config = cast(GraphConfig, config["metadata"]["config"])
if graph_config.inputs.search_request.persona:
persona_id = graph_config.inputs.search_request.persona.id
user_id = None
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
user = graph_config.tooling.search_tool.user
if user:
user_id = user.id
# log the agent metrics
if graph_config.persistence:
if agent_base_duration is not None:
log_agent_metrics(
db_session=graph_config.persistence.db_session,
user_id=user_id,
persona_id=persona_id,
agent_type=agent_type,
start_time=agent_start_time,
agent_metrics=combined_agent_metrics,
)
# Persist the sub-answer in the database
db_session = graph_config.persistence.db_session
chat_session_id = graph_config.persistence.chat_session_id
primary_message_id = graph_config.persistence.message_id
sub_question_answer_results = state.sub_question_results
log_agent_sub_question_results(
db_session=db_session,
chat_session_id=chat_session_id,
primary_message_id=primary_message_id,
sub_question_answer_results=sub_question_answer_results,
)
main_output = MainOutput(
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="persist agent results",
node_start_time=node_start_time,
)
],
)
for log_message in state.log_messages:
logger.debug(log_message)
if state.agent_base_metrics:
logger.debug(f"Initial loop: {state.agent_base_metrics.duration_s}")
if state.agent_refined_metrics:
logger.debug(f"Refined loop: {state.agent_refined_metrics.duration_s}")
if (
state.agent_base_metrics
and state.agent_refined_metrics
and state.agent_base_metrics.duration_s
and state.agent_refined_metrics.duration_s
):
logger.debug(
f"Total time: {float(state.agent_base_metrics.duration_s) + float(state.agent_refined_metrics.duration_s)}"
)
return main_output

View File

@@ -0,0 +1,52 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables import RunnableConfig
from onyx.agents.agent_search.deep_search.main.states import (
ExploratorySearchUpdate,
)
from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
build_history_prompt,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import retrieve_search_docs
from onyx.configs.agent_configs import AGENT_EXPLORATORY_SEARCH_RESULTS
from onyx.context.search.models import InferenceSection
def start_agent_search(
state: MainState, config: RunnableConfig
) -> ExploratorySearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
node_start_time = datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
history = build_history_prompt(graph_config, question)
# Initial search to inform decomposition. Just get top 3 fits
search_tool = graph_config.tooling.search_tool
assert search_tool, "search_tool must be provided for agentic search"
retrieved_docs: list[InferenceSection] = retrieve_search_docs(search_tool, question)
exploratory_search_results = retrieved_docs[:AGENT_EXPLORATORY_SEARCH_RESULTS]
return ExploratorySearchUpdate(
exploratory_search_results=exploratory_search_results,
previous_history_summary=history,
log_messages=[
get_langgraph_node_log_string(
graph_component="main",
node_name="start agent search",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,132 @@
from collections.abc import Callable
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquestion(
level: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def _helper(sub_question_part: str, sep_num: int) -> None:
write_custom_event(
"decomp_qs",
SubQuestionPiece(
sub_question=sub_question_part,
level=level,
level_question_num=sep_num,
),
writer,
)
return _helper
def calculate_initial_agent_stats(
decomp_answer_results: list[SubQuestionAnswerResults],
original_question_stats: AgentChunkRetrievalStats,
) -> InitialAgentResultStats:
initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
sub_questions={},
original_question={},
agent_effectiveness={},
)
orig_verified = original_question_stats.verified_count
orig_support_score = original_question_stats.verified_avg_scores
verified_document_chunk_ids = []
support_scores = 0.0
for decomp_answer_result in decomp_answer_results:
verified_document_chunk_ids += (
decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
)
if (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
is not None
):
support_scores += (
decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
)
verified_document_chunk_ids = list(set(verified_document_chunk_ids))
# Calculate sub-question stats
if (
verified_document_chunk_ids
and len(verified_document_chunk_ids) > 0
and support_scores is not None
):
sub_question_stats: dict[str, float | int | None] = {
"num_verified_documents": len(verified_document_chunk_ids),
"verified_avg_score": float(support_scores / len(decomp_answer_results)),
}
else:
sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}
initial_agent_result_stats.sub_questions.update(sub_question_stats)
# Get original question stats
initial_agent_result_stats.original_question.update(
{
"num_verified_documents": original_question_stats.verified_count,
"verified_avg_score": original_question_stats.verified_avg_scores,
}
)
# Calculate chunk utilization ratio
sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]
chunk_ratio: float | None = None
if sub_verified is not None and orig_verified is not None and orig_verified > 0:
chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
elif sub_verified is not None and sub_verified > 0:
chunk_ratio = 10.0
initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio
if (
orig_support_score is None
or orig_support_score == 0.0
and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
):
initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
elif orig_support_score is None or orig_support_score == 0.0:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
else:
initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
initial_agent_result_stats.sub_questions["verified_avg_score"]
/ orig_support_score
)
return initial_agent_result_stats
def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
# Use the query info from the base document retrieval
# this is used for some fields that are the same across the searches done
query_info = None
for result in results:
if result.query_info is not None:
query_info = result.query_info
break
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)

View File

@@ -0,0 +1,172 @@
from datetime import datetime
from operator import add
from typing import Annotated
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import (
RefinementSubQuestion,
)
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_question_answer_results,
)
from onyx.context.search.models import InferenceSection
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class RefinedAgentStartStats(BaseModel):
agent_refined_start_time: datetime | None = None
class RefinedAgentEndStats(BaseModel):
agent_refined_end_time: datetime | None = None
agent_refined_metrics: AgentRefinedMetrics = AgentRefinedMetrics()
class InitialQuestionDecompositionUpdate(
RefinedAgentStartStats, RefinedAgentEndStats, LoggerUpdate
):
agent_start_time: datetime | None = None
previous_history: str | None = None
initial_sub_questions: list[str] = []
class ExploratorySearchUpdate(LoggerUpdate):
exploratory_search_results: list[InferenceSection] = []
previous_history_summary: str | None = None
class InitialRefinedAnswerComparisonUpdate(LoggerUpdate):
"""
Evaluation of whether the refined answer is better than the initial answer
"""
refined_answer_improvement_eval: bool = False
class InitialAnswerUpdate(LoggerUpdate):
"""
Initial answer information
"""
initial_answer: str | None = None
initial_agent_stats: InitialAgentResultStats | None = None
generated_sub_questions: list[str] = []
agent_base_end_time: datetime | None = None
agent_base_metrics: AgentBaseMetrics | None = None
class RefinedAnswerUpdate(RefinedAgentEndStats, LoggerUpdate):
"""
Refined answer information
"""
refined_answer: str | None = None
refined_agent_stats: RefinedAgentStats | None = None
refined_answer_quality: bool = False
class InitialAnswerQualityUpdate(LoggerUpdate):
"""
Initial answer quality evaluation
"""
initial_answer_quality_eval: bool = False
class RequireRefinemenEvalUpdate(LoggerUpdate):
require_refined_answer_eval: bool = True
class SubQuestionResultsUpdate(LoggerUpdate):
verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
context_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
cited_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = [] # cited docs from sub-answers are used for answer context
sub_question_results: Annotated[
list[SubQuestionAnswerResults], dedup_question_answer_results
] = []
class OrigQuestionRetrievalUpdate(LoggerUpdate):
orig_question_retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_verified_reranked_documents: Annotated[
list[InferenceSection], dedup_inference_sections
]
orig_question_sub_query_retrieval_results: list[QueryRetrievalResult] = []
orig_question_retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()
class EntityTermExtractionUpdate(LoggerUpdate):
entity_relation_term_extractions: EntityRelationshipTermExtraction = (
EntityRelationshipTermExtraction()
)
class RefinedQuestionDecompositionUpdate(RefinedAgentStartStats, LoggerUpdate):
refined_sub_questions: dict[int, RefinementSubQuestion] = {}
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
InitialQuestionDecompositionUpdate,
InitialAnswerUpdate,
SubQuestionResultsUpdate,
OrigQuestionRetrievalUpdate,
EntityTermExtractionUpdate,
InitialAnswerQualityUpdate,
RequireRefinemenEvalUpdate,
RefinedQuestionDecompositionUpdate,
RefinedAnswerUpdate,
RefinedAgentStartStats,
RefinedAgentEndStats,
InitialRefinedAnswerComparisonUpdate,
ExploratorySearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -0,0 +1,33 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def send_to_expanded_refined_retrieval(
state: SubQuestionAnsweringInput,
) -> Send | Hashable:
"""
LangGraph edge to sends a refined sub-question extended retrieval.
"""
logger.debug("sending to expanded retrieval for follow up question via edge")
datetime.now()
return Send(
"refined_sub_question_expanded_retrieval",
ExpandedRetrievalInput(
question=state.question,
sub_question_id=state.question_id,
base_search=False,
log_messages=[f"{datetime.now()} -- Sending to expanded retrieval"],
),
)

View File

@@ -0,0 +1,132 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.check_sub_answer import (
check_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.format_sub_answer import (
format_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.generate_sub_answer import (
generate_sub_answer,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.nodes.ingest_retrieved_documents import (
ingest_retrieved_documents,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionState,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.edges import (
send_to_expanded_refined_retrieval,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.graph_builder import (
expanded_retrieval_graph_builder,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def answer_refined_query_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the refined sub-answer generation process.
"""
graph = StateGraph(
state_schema=AnswerQuestionState,
input=SubQuestionAnsweringInput,
output=AnswerQuestionOutput,
)
### Add nodes ###
# Subgraph for the expanded retrieval process
expanded_retrieval = expanded_retrieval_graph_builder().compile()
graph.add_node(
node="refined_sub_question_expanded_retrieval",
action=expanded_retrieval,
)
# Ingest the retrieved documents
graph.add_node(
node="ingest_refined_retrieval",
action=ingest_retrieved_documents,
)
# Generate the refined sub-answer
graph.add_node(
node="generate_refined_sub_answer",
action=generate_sub_answer,
)
# Check if the refined sub-answer is correct
graph.add_node(
node="refined_sub_answer_check",
action=check_sub_answer,
)
# Format the refined sub-answer
graph.add_node(
node="format_refined_sub_answer",
action=format_sub_answer,
)
### Add edges ###
graph.add_conditional_edges(
source=START,
path=send_to_expanded_refined_retrieval,
path_map=["refined_sub_question_expanded_retrieval"],
)
graph.add_edge(
start_key="refined_sub_question_expanded_retrieval",
end_key="ingest_refined_retrieval",
)
graph.add_edge(
start_key="ingest_refined_retrieval",
end_key="generate_refined_sub_answer",
)
graph.add_edge(
start_key="generate_refined_sub_answer",
end_key="refined_sub_answer_check",
)
graph.add_edge(
start_key="refined_sub_answer_check",
end_key="format_refined_sub_answer",
)
graph.add_edge(
start_key="format_refined_sub_answer",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = answer_refined_query_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
inputs = SubQuestionAnsweringInput(
question="what can you do with onyx?",
question_id="0_0",
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
stream_mode="custom",
):
logger.debug(thing)

View File

@@ -0,0 +1,42 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
def parallel_retrieval_edge(
state: ExpandedRetrievalState, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the retrieval process for each of the
generated sub-queries and the original question.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else graph_config.inputs.search_request.query
)
query_expansions = state.expanded_queries + [question]
return [
Send(
"retrieve_documents",
RetrievalInput(
query_to_retrieve=query,
question=question,
base_search=False,
sub_question_id=state.sub_question_id,
log_messages=[],
),
)
for query in query_expansions
]

View File

@@ -0,0 +1,161 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.edges import (
parallel_retrieval_edge,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.expand_queries import (
expand_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_queries import (
format_queries,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.format_results import (
format_results,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.kickoff_verification import (
kickoff_verification,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.rerank_documents import (
rerank_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.retrieve_documents import (
retrieve_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.nodes.verify_documents import (
verify_documents,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalOutput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
logger = setup_logger()
def expanded_retrieval_graph_builder() -> StateGraph:
"""
LangGraph graph builder for the expanded retrieval process.
"""
graph = StateGraph(
state_schema=ExpandedRetrievalState,
input=ExpandedRetrievalInput,
output=ExpandedRetrievalOutput,
)
### Add nodes ###
# Convert the question into multiple sub-queries
graph.add_node(
node="expand_queries",
action=expand_queries,
)
# Format the sub-queries into a list of strings
graph.add_node(
node="format_queries",
action=format_queries,
)
# Retrieve the documents for each sub-query
graph.add_node(
node="retrieve_documents",
action=retrieve_documents,
)
# Start verification process that the documents are relevant to the question (not the query)
graph.add_node(
node="kickoff_verification",
action=kickoff_verification,
)
# Verify that a given document is relevant to the question (not the query)
graph.add_node(
node="verify_documents",
action=verify_documents,
)
# Rerank the documents that have been verified
graph.add_node(
node="rerank_documents",
action=rerank_documents,
)
# Format the results into a list of strings
graph.add_node(
node="format_results",
action=format_results,
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expand_queries",
)
graph.add_edge(
start_key="expand_queries",
end_key="format_queries",
)
graph.add_conditional_edges(
source="format_queries",
path=parallel_retrieval_edge,
path_map=["retrieve_documents"],
)
graph.add_edge(
start_key="retrieve_documents",
end_key="kickoff_verification",
)
graph.add_edge(
start_key="verify_documents",
end_key="rerank_documents",
)
graph.add_edge(
start_key="rerank_documents",
end_key="format_results",
)
graph.add_edge(
start_key="format_results",
end_key=END,
)
return graph
if __name__ == "__main__":
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.context.search.models import SearchRequest
graph = expanded_retrieval_graph_builder()
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="what can you do with onyx or danswer?",
)
with get_session_context_manager() as db_session:
graph_config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
inputs = ExpandedRetrievalInput(
question="what can you do with onyx?",
base_search=False,
sub_question_id=None,
log_messages=[],
)
for thing in compiled_graph.stream(
input=inputs,
config={"configurable": {"config": graph_config}},
stream_mode="custom",
subgraphs=True,
):
logger.debug(thing)

View File

@@ -0,0 +1,13 @@
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.context.search.models import InferenceSection
class QuestionRetrievalResult(BaseModel):
expanded_query_results: list[QueryRetrievalResult] = []
retrieved_documents: list[InferenceSection] = []
verified_reranked_documents: list[InferenceSection] = []
context_documents: list[InferenceSection] = []
retrieval_stats: AgentChunkRetrievalStats = AgentChunkRetrievalStats()

View File

@@ -0,0 +1,75 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
dispatch_subquery,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import dispatch_separated
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
def expand_queries(
state: ExpandedRetrievalInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> QueryExpansionUpdate:
"""
LangGraph node to expand a question into multiple search queries.
"""
# Sometimes we want to expand the original question, sometimes we want to expand a sub-question.
# When we are running this node on the original question, no question is explictly passed in.
# Instead, we use the original question from the search request.
graph_config = cast(GraphConfig, config["metadata"]["config"])
node_start_time = datetime.now()
question = state.question
llm = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
else:
level, question_num = parse_question_id(sub_question_id)
msg = [
HumanMessage(
content=QUERY_REWRITING_PROMPT.format(question=question),
)
]
llm_response_list = dispatch_separated(
llm.stream(prompt=msg), dispatch_subquery(level, question_num, writer)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
rewritten_queries = llm_response.split("\n")
return QueryExpansionUpdate(
expanded_queries=rewritten_queries,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="expand queries",
node_start_time=node_start_time,
result=f"Number of expanded queries: {len(rewritten_queries)}",
)
],
)

View File

@@ -0,0 +1,19 @@
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
QueryExpansionUpdate,
)
def format_queries(
state: ExpandedRetrievalState, config: RunnableConfig
) -> QueryExpansionUpdate:
"""
LangGraph node to format the expanded queries into a list of strings.
"""
return QueryExpansionUpdate(
expanded_queries=state.expanded_queries,
)

View File

@@ -0,0 +1,91 @@
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.deep_search.main.operations import get_query_info
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
calculate_sub_question_retrieval_stats,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import relevance_from_docs
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import ExtendedToolResponse
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
def format_results(
state: ExpandedRetrievalState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ExpandedRetrievalUpdate:
"""
LangGraph node that constructs the proper expanded retrieval format.
"""
level, question_num = parse_question_id(state.sub_question_id or "0_0")
query_info = get_query_info(state.query_retrieval_results)
graph_config = cast(GraphConfig, config["metadata"]["config"])
# Main question docs will be sent later after aggregation and deduping with sub-question docs
reranked_documents = state.reranked_documents
if not (level == 0 and question_num == 0):
if len(reranked_documents) == 0:
# The sub-question is used as the last query. If no verified documents are found, stream
# the top 3 for that one. We may want to revisit this.
reranked_documents = state.query_retrieval_results[-1].retrieved_documents[
:3
]
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
):
write_custom_event(
"tool_response",
ExtendedToolResponse(
id=tool_response.id,
response=tool_response.response,
level=level,
level_question_num=question_num,
),
writer,
)
sub_question_retrieval_stats = calculate_sub_question_retrieval_stats(
verified_documents=state.verified_documents,
expanded_retrieval_results=state.query_retrieval_results,
)
if sub_question_retrieval_stats is None:
sub_question_retrieval_stats = AgentChunkRetrievalStats()
return ExpandedRetrievalUpdate(
expanded_retrieval_result=QuestionRetrievalResult(
expanded_query_results=state.query_retrieval_results,
retrieved_documents=state.retrieved_documents,
verified_reranked_documents=reranked_documents,
context_documents=state.reranked_documents,
retrieval_stats=sub_question_retrieval_stats,
),
)

View File

@@ -0,0 +1,44 @@
from typing import Literal
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Command
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
def kickoff_verification(
state: ExpandedRetrievalState,
config: RunnableConfig,
) -> Command[Literal["verify_documents"]]:
"""
LangGraph node (Command node!) that kicks off the verification process for the retrieved documents.
Note that this is a Command node and does the routing as well. (At present, no state updates
are done here, so this could be replaced with an edge. But we may choose to make state
updates later.)
"""
retrieved_documents = state.retrieved_documents
verification_question = state.question
sub_question_id = state.sub_question_id
return Command(
update={},
goto=[
Send(
node="verify_documents",
arg=DocVerificationInput(
retrieved_document_to_verify=document,
question=verification_question,
base_search=False,
sub_question_id=sub_question_id,
log_messages=[],
),
)
for document in retrieved_documents
],
)

View File

@@ -0,0 +1,105 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
logger,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocRerankingUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
ExpandedRetrievalState,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.db.engine import get_session_context_manager
def rerank_documents(
state: ExpandedRetrievalState, config: RunnableConfig
) -> DocRerankingUpdate:
"""
LangGraph node to rerank the retrieved and verified documents. A part of the
pre-existing pipeline is used here.
"""
node_start_time = datetime.now()
verified_documents = state.verified_documents
# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
state.question if state.question else graph_config.inputs.search_request.query
)
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user, # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)
# skip section filtering
if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:
fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})
return DocRerankingUpdate(
reranked_documents=[
doc for doc in reranked_documents if type(doc) == InferenceSection
][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
sub_question_retrieval_stats=fit_scores,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="rerank documents",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,113 @@
from datetime import datetime
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.operations import (
logger,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocRetrievalUpdate,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
RetrievalInput,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
def retrieve_documents(
state: RetrievalInput, config: RunnableConfig
) -> DocRetrievalUpdate:
"""
LangGraph node to retrieve documents from the search tool.
"""
node_start_time = datetime.now()
query_to_retrieve = state.query_to_retrieve
graph_config = cast(GraphConfig, config["metadata"]["config"])
search_tool = graph_config.tooling.search_tool
retrieved_docs: list[InferenceSection] = []
if not query_to_retrieve.strip():
logger.warning("Empty query, skipping retrieval")
return DocRetrievalUpdate(
query_retrieval_results=[],
retrieved_documents=[],
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="retrieve documents",
node_start_time=node_start_time,
result="Empty query, skipping retrieval",
)
],
)
query_info = None
if search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
callback_container: list[list[InferenceSection]] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
query_info = SearchQueryInfo(
predicted_search=response.predicted_search,
final_filters=response.final_filters,
recency_bias_multiplier=response.recency_bias_multiplier,
)
break
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,
)
else:
fit_scores = None
expanded_retrieval_result = QueryRetrievalResult(
query=query_to_retrieve,
retrieved_documents=retrieved_docs,
stats=fit_scores,
query_info=query_info,
)
return DocRetrievalUpdate(
query_retrieval_results=[expanded_retrieval_result],
retrieved_documents=retrieved_docs,
log_messages=[
get_langgraph_node_log_string(
graph_component="shared - expanded retrieval",
node_name="retrieve documents",
node_start_time=node_start_time,
)
],
)

View File

@@ -0,0 +1,62 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationInput,
)
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.states import (
DocVerificationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
def verify_documents(
state: DocVerificationInput, config: RunnableConfig
) -> DocVerificationUpdate:
"""
LangGraph node to check whether the document is relevant for the original user question
Args:
state (DocVerificationInput): The current state
config (RunnableConfig): Configuration containing ProSearchConfig
Updates:
verified_documents: list[InferenceSection]
"""
question = state.question
retrieved_document_to_verify = state.retrieved_document_to_verify
document_content = retrieved_document_to_verify.combined_content
graph_config = cast(GraphConfig, config["metadata"]["config"])
fast_llm = graph_config.tooling.fast_llm
document_content = trim_prompt_piece(
fast_llm.config, document_content, DOCUMENT_VERIFICATION_PROMPT + question
)
msg = [
HumanMessage(
content=DOCUMENT_VERIFICATION_PROMPT.format(
question=question, document_content=document_content
)
)
]
response = fast_llm.invoke(msg)
verified_documents = []
if isinstance(response.content, str) and "yes" in response.content.lower():
verified_documents.append(retrieved_document_to_verify)
return DocVerificationUpdate(
verified_documents=verified_documents,
)

View File

@@ -0,0 +1,93 @@
from collections import defaultdict
from collections.abc import Callable
import numpy as np
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.models import AgentChunkRetrievalStats
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import SubQueryPiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def dispatch_subquery(
level: int, question_num: int, writer: StreamWriter
) -> Callable[[str, int], None]:
def helper(token: str, num: int) -> None:
write_custom_event(
"subqueries",
SubQueryPiece(
sub_query=token,
level=level,
level_question_num=question_num,
query_id=num,
),
writer,
)
return helper
def calculate_sub_question_retrieval_stats(
verified_documents: list[InferenceSection],
expanded_retrieval_results: list[QueryRetrievalResult],
) -> AgentChunkRetrievalStats:
chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
lambda: defaultdict(list)
)
for expanded_retrieval_result in expanded_retrieval_results:
for doc in expanded_retrieval_result.retrieved_documents:
doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
if doc.center_chunk.score is not None:
chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)
verified_doc_chunk_ids = [
f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
for verified_document in verified_documents
]
dismissed_doc_chunk_ids = []
raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
for doc_chunk_id, chunk_data in chunk_scores.items():
valid_chunk_scores = [
score for score in chunk_data["score"] if score is not None
]
key = "verified" if doc_chunk_id in verified_doc_chunk_ids else "rejected"
raw_chunk_stats_counts[f"{key}_count"] += 1
raw_chunk_stats_scores[f"{key}_scores"] += float(np.mean(valid_chunk_scores))
if key == "rejected":
dismissed_doc_chunk_ids.append(doc_chunk_id)
if raw_chunk_stats_counts["verified_count"] == 0:
verified_avg_scores = 0.0
else:
verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
raw_chunk_stats_counts["verified_count"]
)
rejected_scores = raw_chunk_stats_scores.get("rejected_scores")
if rejected_scores is not None:
rejected_avg_scores = rejected_scores / float(
raw_chunk_stats_counts["rejected_count"]
)
else:
rejected_avg_scores = None
chunk_stats = AgentChunkRetrievalStats(
verified_count=raw_chunk_stats_counts["verified_count"],
verified_avg_scores=verified_avg_scores,
rejected_count=raw_chunk_stats_counts["rejected_count"],
rejected_avg_scores=rejected_avg_scores,
verified_doc_chunk_ids=verified_doc_chunk_ids,
dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
)
return chunk_stats

View File

@@ -0,0 +1,91 @@
from operator import add
from typing import Annotated
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import SubgraphCoreState
from onyx.agents.agent_search.deep_search.main.states import LoggerUpdate
from onyx.agents.agent_search.deep_search.shared.expanded_retrieval.models import (
QuestionRetrievalResult,
)
from onyx.agents.agent_search.shared_graph_utils.models import QueryRetrievalResult
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_sections,
)
from onyx.context.search.models import InferenceSection
### States ###
## Graph Input State
class ExpandedRetrievalInput(SubgraphCoreState):
question: str = ""
base_search: bool = False
sub_question_id: str | None = None
## Update/Return States
class QueryExpansionUpdate(LoggerUpdate, BaseModel):
expanded_queries: list[str] = []
log_messages: list[str] = []
class DocVerificationUpdate(BaseModel):
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
class DocRetrievalUpdate(LoggerUpdate, BaseModel):
query_retrieval_results: Annotated[list[QueryRetrievalResult], add] = []
retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
class DocRerankingUpdate(LoggerUpdate, BaseModel):
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections] = []
sub_question_retrieval_stats: RetrievalFitStats | None = None
class ExpandedRetrievalUpdate(LoggerUpdate, BaseModel):
expanded_retrieval_result: QuestionRetrievalResult
## Graph Output State
class ExpandedRetrievalOutput(LoggerUpdate, BaseModel):
expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
base_expanded_retrieval_result: QuestionRetrievalResult = QuestionRetrievalResult()
retrieved_documents: Annotated[
list[InferenceSection], dedup_inference_sections
] = []
## Graph State
class ExpandedRetrievalState(
# This includes the core state
ExpandedRetrievalInput,
QueryExpansionUpdate,
DocRetrievalUpdate,
DocVerificationUpdate,
DocRerankingUpdate,
ExpandedRetrievalOutput,
):
pass
## Conditional Input States
class DocVerificationInput(ExpandedRetrievalInput):
retrieved_document_to_verify: InferenceSection
class RetrievalInput(ExpandedRetrievalInput):
query_to_retrieve: str = ""

View File

@@ -0,0 +1,90 @@
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
class GraphInputs(BaseModel):
"""Input data required for the graph execution"""
search_request: SearchRequest
prompt_builder: AnswerPromptBuilder
files: list[InMemoryChatFile] | None = None
structured_response_format: dict | None = None
class Config:
arbitrary_types_allowed = True
class GraphTooling(BaseModel):
"""Tools and LLMs available to the graph"""
primary_llm: LLM
fast_llm: LLM
search_tool: SearchTool | None = None
tools: list[Tool]
# Whether to force use of a tool, or to
# force tool args IF the tool is used
force_use_tool: ForceUseTool
using_tool_calling_llm: bool = False
class Config:
arbitrary_types_allowed = True
class GraphPersistence(BaseModel):
"""Configuration for data persistence"""
chat_session_id: UUID
# The message ID of the to-be-created first agent message
# in response to the user message that triggered the Pro Search
message_id: int
# The database session the user and initial agent
# message were flushed to; only needed for agentic search
db_session: Session
class Config:
arbitrary_types_allowed = True
class GraphSearchConfig(BaseModel):
"""Configuration controlling search behavior"""
use_agentic_search: bool = False
# Whether to perform initial search to inform decomposition
perform_initial_search_decomposition: bool = True
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
class GraphConfig(BaseModel):
"""
Main container for data needed for Langgraph execution
"""
inputs: GraphInputs
tooling: GraphTooling
behavior: GraphSearchConfig
# Only needed for agentic search
persistence: GraphPersistence
@model_validator(mode="after")
def validate_search_tool(self) -> "GraphConfig":
if self.behavior.use_agentic_search and self.tooling.search_tool is None:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True

View File

@@ -0,0 +1,77 @@
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
agent_config = cast(GraphConfig, config["metadata"]["config"])
structured_response_format = agent_config.inputs.structured_response_format
llm = agent_config.tooling.primary_llm
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Tool choice is None")
tool = tool_choice.tool
prompt_builder = agent_config.inputs.prompt_builder
if state.tool_call_output is None:
raise ValueError("Tool call output is None")
tool_call_output = state.tool_call_output
tool_call_summary = tool_call_output.tool_call_summary
tool_call_responses = tool_call_output.tool_call_responses
new_prompt_builder = tool.build_next_prompt(
prompt_builder=prompt_builder,
tool_call_summary=tool_call_summary,
tool_responses=tool_call_responses,
using_tool_calling_llm=agent_config.tooling.using_tool_calling_llm,
)
final_search_results = []
initial_search_results = []
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:
stream = llm.stream(
prompt=new_prompt_builder.build(),
structured_response_format=structured_response_format,
)
# For now, we don't do multiple tool calls, so we ignore the tool_message
new_tool_call_chunk = process_llm_stream(
stream,
True,
writer,
final_search_results=final_search_results,
# when the search tool is called with specific doc ids, initial search
# results are not output. But, we still want i.e. citations to be processed.
displayed_search_results=initial_search_results or final_search_results,
)
return BasicOutput(tool_call_chunk=new_tool_call_chunk)

View File

@@ -0,0 +1,154 @@
from typing import cast
from uuid import uuid4
from langchain_core.messages import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoice
from onyx.agents.agent_search.orchestration.states import ToolChoiceState
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
logger = setup_logger()
# TODO: break this out into an implementation function
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolChoiceUpdate:
"""
This node is responsible for calling the LLM to choose a tool. If no tool is chosen,
The node MAY emit an answer, depending on whether state["should_stream_answer"] is set.
"""
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
llm = agent_config.tooling.primary_llm
skip_gen_ai_answer_generation = agent_config.behavior.skip_gen_ai_answer_generation
structured_response_format = agent_config.inputs.structured_response_format
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
tool_name, tool_args = (
force_use_tool.tool_name,
force_use_tool.args,
)
tool = get_tool_by_name(tools, tool_name)
# special pre-logic for non-tool calling LLM case
elif not using_tool_calling_llm and tools:
chosen_tool_and_args = get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool=force_use_tool,
tools=tools,
prompt_builder=prompt_builder,
llm=llm,
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
),
)
# if we're skipping gen ai answer generation, we should only
# continue if we're forcing a tool call (which will be emitted by
# the tool calling llm in the stream() below)
if skip_gen_ai_answer_generation and not force_use_tool.force_use:
return ToolChoiceUpdate(
tool_choice=None,
)
built_prompt = (
prompt_builder.build()
if isinstance(prompt_builder, AnswerPromptBuilder)
else prompt_builder.built_prompt
)
# At this point, we are either using a tool calling LLM or we are skipping the tool call.
# DEBUG: good breakpoint
stream = llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
structured_response_format=structured_response_format,
)
tool_message = process_llm_stream(
stream,
should_stream_answer
and not agent_config.behavior.skip_gen_ai_answer_generation,
writer,
)
# If no tool calls are emitted by the LLM, we should not choose a tool
if len(tool_message.tool_calls) == 0:
logger.debug("No tool calls emitted by LLM")
return ToolChoiceUpdate(
tool_choice=None,
)
# TODO: here we could handle parallel tool calls. Right now
# we just pick the first one that matches.
selected_tool: Tool | None = None
selected_tool_call_request: ToolCall | None = None
for tool_call_request in tool_message.tool_calls:
known_tools_by_name = [
tool for tool in tools if tool.name == tool_call_request["name"]
]
if known_tools_by_name:
selected_tool = known_tools_by_name[0]
selected_tool_call_request = tool_call_request
break
logger.error(
"Tool call requested with unknown name field. \n"
f"tools: {tools}"
f"tool_call_request: {tool_call_request}"
)
if not selected_tool or not selected_tool_call_request:
raise ValueError(
f"Tool call attempted with tool {selected_tool}, request {selected_tool_call_request}"
)
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
),
)

View File

@@ -0,0 +1,17 @@
from typing import Any
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
def prepare_tool_input(state: Any, config: RunnableConfig) -> ToolChoiceInput:
agent_config = cast(GraphConfig, config["metadata"]["config"])
return ToolChoiceInput(
# NOTE: this node is used at the top level of the agent, so we always stream
should_stream_answer=True,
prompt_snapshot=None, # uses default prompt builder
tools=[tool.name for tool in (agent_config.tooling.tools or [])],
)

View File

@@ -0,0 +1,79 @@
from typing import cast
from langchain_core.messages import AIMessageChunk
from langchain_core.messages.tool import ToolCall
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.orchestration.states import ToolCallOutput
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AnswerPacket
from onyx.tools.message import build_tool_message
from onyx.tools.message import ToolCallSummary
from onyx.tools.tool_runner import ToolRunner
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ToolCallException(Exception):
"""Exception raised for errors during tool calls."""
def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
write_custom_event("basic_response", packet, writer)
def tool_call(
state: ToolChoiceUpdate,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ToolCallUpdate:
"""Calls the tool specified in the state and updates the state with the result"""
cast(GraphConfig, config["metadata"]["config"])
tool_choice = state.tool_choice
if tool_choice is None:
raise ValueError("Cannot invoke tool call node without a tool choice")
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)
try:
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
emit_packet(response, writer)
tool_final_result = tool_runner.tool_final_result()
emit_packet(tool_final_result, writer)
except Exception as e:
raise ToolCallException(
f"Error during tool call for {tool.display_name}: {e}"
) from e
tool_call = ToolCall(name=tool.name, args=tool_args, id=tool_id)
tool_call_summary = ToolCallSummary(
tool_call_request=AIMessageChunk(content="", tool_calls=[tool_call]),
tool_call_result=build_tool_message(
tool_call, tool_runner.tool_message_content()
),
)
tool_call_output = ToolCallOutput(
tool_call_summary=tool_call_summary,
tool_call_kickoff=tool_kickoff,
tool_call_responses=tool_responses,
tool_call_final_result=tool_final_result,
)
return ToolCallUpdate(tool_call_output=tool_call_output)

View File

@@ -0,0 +1,48 @@
from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
# TODO: adapt the tool choice/tool call to allow for parallel tool calls by
# creating a subgraph that can be invoked in parallel via Send/Command APIs
class ToolChoiceInput(BaseModel):
should_stream_answer: bool = True
# default to the prompt builder from the config, but
# allow overrides for arbitrary tool calls
prompt_snapshot: PromptSnapshot | None = None
# names of tools to use for tool calling. Filters the tools available in the config
tools: list[str] = []
class ToolCallOutput(BaseModel):
tool_call_summary: ToolCallSummary
tool_call_kickoff: ToolCallKickoff
tool_call_responses: list[ToolResponse]
tool_call_final_result: ToolCallFinalResult
class ToolCallUpdate(BaseModel):
tool_call_output: ToolCallOutput | None = None
class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
class Config:
arbitrary_types_allowed = True
class ToolChoiceUpdate(BaseModel):
tool_choice: ToolChoice | None = None
class ToolChoiceState(ToolChoiceUpdate, ToolChoiceInput):
pass

View File

@@ -0,0 +1,213 @@
from collections.abc import Iterable
from datetime import datetime
from typing import cast
from langchain_core.runnables.schema import CustomStreamEvent
from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
from onyx.agents.agent_search.deep_search.main.states import (
MainInput as MainInput_a,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.configs.agent_configs import ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms
from onyx.tools.tool_runner import ToolCallKickoff
from onyx.utils.logger import setup_logger
logger = setup_logger()
_COMPILED_GRAPH: CompiledStateGraph | None = None
def _parse_agent_event(
event: StreamEvent,
) -> AnswerPacket | None:
"""
Parse the event into a typed object.
Return None if we are not interested in the event.
"""
event_type = event["event"]
# We always just yield the event data, but this piece is useful for two development reasons:
# 1. It's a list of the names of every place we dispatch a custom event
# 2. We maintain the intended types yielded by each event
if event_type == "on_custom_event":
if event["name"] == "decomp_qs":
return cast(SubQuestionPiece, event["data"])
elif event["name"] == "subqueries":
return cast(SubQueryPiece, event["data"])
elif event["name"] == "sub_answers":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "stream_finished":
return cast(StreamStopInfo, event["data"])
elif event["name"] == "initial_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "refined_agent_answer":
return cast(AgentAnswerPiece, event["data"])
elif event["name"] == "start_refined_answer_creation":
return cast(ToolCallKickoff, event["data"])
elif event["name"] == "tool_response":
return cast(ToolResponse, event["data"])
elif event["name"] == "basic_response":
return cast(AnswerPacket, event["data"])
elif event["name"] == "refined_answer_improvement":
return cast(RefinedAnswerImprovement, event["data"])
return None
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput_a,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
stream_mode="custom",
input=graph_input,
config={"metadata": {"config": config, "thread_id": str(message_id)}},
):
yield cast(CustomStreamEvent, event)
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput_a,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
)
config.behavior.allow_refinement = ALLOW_REFINEMENT
for event in manage_sync_streaming(
compiled_graph=compiled_graph, config=config, graph_input=input
):
if not (parsed_object := _parse_agent_event(event)):
continue
yield parsed_object
# It doesn't actually take very long to load the graph, but we'd rather
# not compile it again on every request.
def load_compiled_graph() -> CompiledStateGraph:
global _COMPILED_GRAPH
if _COMPILED_GRAPH is None:
graph = main_graph_builder_a()
_COMPILED_GRAPH = graph.compile()
return _COMPILED_GRAPH
def run_main_graph(
config: GraphConfig,
) -> AnswerStream:
compiled_graph = load_compiled_graph()
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
tool_name="agent_search_0",
tool_args={"query": config.inputs.search_request.query},
)
yield from run_graph(compiled_graph, config, input)
def run_basic_graph(
config: GraphConfig,
) -> AnswerStream:
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput()
return run_graph(compiled_graph, config, input)
if __name__ == "__main__":
for _ in range(1):
query_start_time = datetime.now()
logger.debug(f"Start at {query_start_time}")
graph = main_graph_builder_a()
compiled_graph = graph.compile()
query_end_time = datetime.now()
logger.debug(f"Graph compiled in {query_end_time - query_start_time} seconds")
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
# query="what can you do with gitlab?",
# query="What are the guiding principles behind the development of cockroachDB",
# query="What are the temperatures in Munich, Hawaii, and New York?",
# query="When was Washington born?",
# query="What is Onyx?",
# query="What is the difference between astronomy and astrology?",
query="Do a search to tell me what is the difference between astronomy and astrology?",
)
with get_session_context_manager() as db_session:
config = get_test_config(db_session, primary_llm, fast_llm, search_request)
assert (
config.persistence is not None
), "set a chat session id to run this test"
# search_request.persona = get_persona_by_id(1, None, db_session)
# config.perform_initial_search_path_decision = False
config.behavior.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.inputs.search_request.query, log_messages=[]
)
tool_responses: list = []
for output in run_graph(compiled_graph, config, input):
if isinstance(output, ToolCallKickoff):
pass
elif isinstance(output, ExtendedToolResponse):
tool_responses.append(output.response)
logger.info(
f" ---- ET {output.level} - {output.level_question_num} | "
)
elif isinstance(output, SubQueryPiece):
logger.info(
f"Sq {output.level} - {output.level_question_num} - {output.sub_query} | "
)
elif isinstance(output, SubQuestionPiece):
logger.info(
f"SQ {output.level} - {output.level_question_num} - {output.sub_question} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_sub_answer"
):
logger.info(
f" ---- SA {output.level} - {output.level_question_num} {output.answer_piece} | "
)
elif (
isinstance(output, AgentAnswerPiece)
and output.answer_type == "agent_level_answer"
):
logger.info(
f" ---------- FA {output.level} - {output.level_question_num} {output.answer_piece} | "
)
elif isinstance(output, RefinedAnswerImprovement):
logger.info(
f" ---------- RE {output.refined_answer_improvement} | "
)

View File

@@ -0,0 +1,152 @@
from langchain.schema import AIMessage
from langchain.schema import HumanMessage
from langchain.schema import SystemMessage
from langchain_core.messages.tool import ToolMessage
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.models import (
AgentPromptEnrichmentComponents,
)
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_persona_agent_prompt_expressions,
)
from onyx.agents.agent_search.shared_graph_utils.utils import remove_document_citations
from onyx.agents.agent_search.shared_graph_utils.utils import summarize_history
from onyx.configs.agent_configs import AGENT_MAX_STATIC_HISTORY_WORD_LENGTH
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.prompts.agent_search import HISTORY_FRAMING_PROMPT
from onyx.prompts.agent_search import SUB_QUESTION_RAG_PROMPT
from onyx.prompts.prompt_utils import build_date_time_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
def build_sub_question_answer_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
date_str = build_date_time_string()
# TODO: This should include document metadata and title
docs_format_list = [
f"Document Number: [D{doc_num + 1}]\nContent: {doc.combined_content}\n\n"
for doc_num, doc in enumerate(docs)
]
docs_str = "\n\n".join(docs_format_list)
docs_str = trim_prompt_piece(
config,
docs_str,
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=SUB_QUESTION_RAG_PROMPT.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
return [system_message, human_message]
def trim_prompt_piece(config: LLMConfig, prompt_piece: str, reserved_str: str) -> str:
# TODO: save the max input tokens in LLMConfig
max_tokens = get_max_input_tokens(
model_provider=config.model_provider,
model_name=config.model_name,
)
# no need to trim if a conservative estimate of one token
# per character is already less than the max tokens
if len(prompt_piece) + len(reserved_str) < max_tokens:
return prompt_piece
llm_tokenizer = get_tokenizer(
provider_type=config.model_provider,
model_name=config.model_name,
)
# slightly conservative trimming
return tokenizer_trim_content(
content=prompt_piece,
desired_length=max_tokens - len(llm_tokenizer.encode(reserved_str)),
tokenizer=llm_tokenizer,
)
def build_history_prompt(config: GraphConfig, question: str) -> str:
prompt_builder = config.inputs.prompt_builder
persona_base = get_persona_agent_prompt_expressions(
config.inputs.search_request.persona
).base_prompt
if prompt_builder is None:
return ""
if prompt_builder.single_message_history is not None:
history = prompt_builder.single_message_history
else:
history_components = []
previous_message_type = None
for message in prompt_builder.raw_message_history:
if message.message_type == MessageType.USER:
history_components.append(f"User: {message.message}\n")
previous_message_type = MessageType.USER
elif message.message_type == MessageType.ASSISTANT:
# Previously there could be multiple assistant messages in a row
# Now this is handled at the message history construction
assert previous_message_type is not MessageType.ASSISTANT
history_components.append(f"You/Agent: {message.message}\n")
previous_message_type = MessageType.ASSISTANT
else:
# Other message types are not included here, currently there should be no other message types
logger.error(
f"Unhandled message type: {message.message_type} with message: {message.message}"
)
continue
history = "\n".join(history_components)
history = remove_document_citations(history)
if len(history.split()) > AGENT_MAX_STATIC_HISTORY_WORD_LENGTH:
history = summarize_history(
history=history,
question=question,
persona_specification=persona_base,
llm=config.tooling.fast_llm,
)
return HISTORY_FRAMING_PROMPT.format(history=history) if history else ""
def get_prompt_enrichment_components(
config: GraphConfig,
) -> AgentPromptEnrichmentComponents:
persona_prompts = get_persona_agent_prompt_expressions(
config.inputs.search_request.persona
)
history = build_history_prompt(config, config.inputs.search_request.query)
date_str = build_date_time_string()
return AgentPromptEnrichmentComponents(
persona_prompts=persona_prompts,
history=history,
date_str=date_str,
)

View File

@@ -0,0 +1,98 @@
import numpy as np
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agents.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
def unique_chunk_id(doc: InferenceSection) -> str:
return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
shift = 0
for rank_first, doc_id in enumerate(list1[:top_n], 1):
try:
rank_second = list2.index(doc_id) + 1
except ValueError:
rank_second = len(list2) # Document not found in second list
shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)
return shift / top_n
def get_fit_scores(
pre_reranked_results: list[InferenceSection],
post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
"""
Calculate retrieval metrics for search purposes
"""
if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
return None
ranked_sections = {
"initial": pre_reranked_results,
"reranked": post_reranked_results,
}
fit_eval: RetrievalFitStats = RetrievalFitStats(
fit_score_lift=0,
rerank_effect=0,
fit_scores={
"initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
"reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
},
)
for rank_type, docs in ranked_sections.items():
logger.debug(f"rank_type: {rank_type}")
for i in [1, 5, 10]:
fit_eval.fit_scores[rank_type].scores[str(i)] = (
sum(
[
float(doc.center_chunk.score)
for doc in docs[:i]
if type(doc) == InferenceSection
and doc.center_chunk.score is not None
]
)
/ i
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = (
1
/ 3
* (
fit_eval.fit_scores[rank_type].scores["1"]
+ fit_eval.fit_scores[rank_type].scores["5"]
+ fit_eval.fit_scores[rank_type].scores["10"]
)
)
fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
rank_type
].scores["1"]
fit_eval.fit_scores[rank_type].chunk_ids = [
unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
]
fit_eval.fit_score_lift = (
fit_eval.fit_scores["reranked"].scores["fit_score"]
/ fit_eval.fit_scores["initial"].scores["fit_score"]
)
fit_eval.rerank_effect = calculate_rank_shift(
fit_eval.fit_scores["initial"].chunk_ids,
fit_eval.fit_scores["reranked"].chunk_ids,
)
return fit_eval

View File

@@ -0,0 +1,128 @@
from pydantic import BaseModel
from onyx.agents.agent_search.deep_search.main.models import (
AgentAdditionalMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentBaseMetrics
from onyx.agents.agent_search.deep_search.main.models import (
AgentRefinedMetrics,
)
from onyx.agents.agent_search.deep_search.main.models import AgentTimings
from onyx.context.search.models import InferenceSection
from onyx.tools.models import SearchQueryInfo
# Pydantic models for structured outputs
# class RewrittenQueries(BaseModel):
# rewritten_queries: list[str]
# class BinaryDecision(BaseModel):
# decision: Literal["yes", "no"]
# class BinaryDecisionWithReasoning(BaseModel):
# reasoning: str
# decision: Literal["yes", "no"]
class RetrievalFitScoreMetrics(BaseModel):
scores: dict[str, float]
chunk_ids: list[str]
class RetrievalFitStats(BaseModel):
fit_score_lift: float
rerank_effect: float
fit_scores: dict[str, RetrievalFitScoreMetrics]
# class AgentChunkScores(BaseModel):
# scores: dict[str, dict[str, list[int | float]]]
class AgentChunkRetrievalStats(BaseModel):
verified_count: int | None = None
verified_avg_scores: float | None = None
rejected_count: int | None = None
rejected_avg_scores: float | None = None
verified_doc_chunk_ids: list[str] = []
dismissed_doc_chunk_ids: list[str] = []
class InitialAgentResultStats(BaseModel):
sub_questions: dict[str, float | int | None]
original_question: dict[str, float | int | None]
agent_effectiveness: dict[str, float | int | None]
class RefinedAgentStats(BaseModel):
revision_doc_efficiency: float | None
revision_question_efficiency: float | None
class Term(BaseModel):
term_name: str = ""
term_type: str = ""
term_similar_to: list[str] = []
### Models ###
class Entity(BaseModel):
entity_name: str = ""
entity_type: str = ""
class Relationship(BaseModel):
relationship_name: str = ""
relationship_type: str = ""
relationship_entities: list[str] = []
class EntityRelationshipTermExtraction(BaseModel):
entities: list[Entity] = []
relationships: list[Relationship] = []
terms: list[Term] = []
class EntityExtractionResult(BaseModel):
retrieved_entities_relationships: EntityRelationshipTermExtraction
class QueryRetrievalResult(BaseModel):
query: str
retrieved_documents: list[InferenceSection]
stats: RetrievalFitStats | None
query_info: SearchQueryInfo | None
class SubQuestionAnswerResults(BaseModel):
question: str
question_id: str
answer: str
verified_high_quality: bool
sub_query_retrieval_results: list[QueryRetrievalResult]
verified_reranked_documents: list[InferenceSection]
context_documents: list[InferenceSection]
cited_documents: list[InferenceSection]
sub_question_retrieval_stats: AgentChunkRetrievalStats
class CombinedAgentMetrics(BaseModel):
timings: AgentTimings
base_metrics: AgentBaseMetrics | None
refined_metrics: AgentRefinedMetrics
additional_metrics: AgentAdditionalMetrics
class PersonaPromptExpressions(BaseModel):
contextualized_prompt: str
base_prompt: str | None
class AgentPromptEnrichmentComponents(BaseModel):
persona_prompts: PersonaPromptExpressions
history: str
date_str: str

View File

@@ -0,0 +1,31 @@
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection
def dedup_inference_sections(
list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
deduped = _merge_sections(list1 + list2)
return deduped
def dedup_question_answer_results(
question_answer_results_1: list[SubQuestionAnswerResults],
question_answer_results_2: list[SubQuestionAnswerResults],
) -> list[SubQuestionAnswerResults]:
deduped_question_answer_results: list[
SubQuestionAnswerResults
] = question_answer_results_1
utilized_question_ids: set[str] = set(
[x.question_id for x in question_answer_results_1]
)
for question_answer_result in question_answer_results_2:
if question_answer_result.question_id not in utilized_question_ids:
deduped_question_answer_results.append(question_answer_result)
utilized_question_ids.add(question_answer_result.question_id)
return deduped_question_answer_results

View File

@@ -0,0 +1,433 @@
import os
import re
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from typing import cast
from typing import Literal
from typing import TypedDict
from uuid import UUID
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langgraph.types import StreamWriter
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.shared_graph_utils.models import (
EntityRelationshipTermExtraction,
)
from onyx.agents.agent_search.shared_graph_utils.models import PersonaPromptExpressions
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import DISPATCH_SEP_CHAR
from onyx.configs.constants import FORMAT_DOCS_SEPARATOR
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_PERSONA,
)
from onyx.prompts.agent_search import (
HISTORY_CONTEXT_SUMMARY_PROMPT,
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
BaseMessage_Content = str | list[str | dict[str, Any]]
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
formatted_doc_list = []
for doc_num, doc in enumerate(docs):
title: str | None = doc.center_chunk.title
metadata: dict[str, str | list[str]] | None = (
doc.center_chunk.metadata if doc.center_chunk.metadata else None
)
doc_str = f"**Document: D{doc_num + 1}**"
if title:
doc_str += f"\nTitle: {title}"
if metadata:
metadata_str = ""
for key, value in metadata.items():
if isinstance(value, str):
metadata_str += f" - {key}: {value}"
elif isinstance(value, list):
metadata_str += f" - {key}: {', '.join(value)}"
doc_str += f"\nMetadata: {metadata_str}"
doc_str += f"\nContent:\n{doc.combined_content}"
formatted_doc_list.append(doc_str)
return FORMAT_DOCS_SEPARATOR.join(formatted_doc_list)
def format_entity_term_extraction(
entity_term_extraction_dict: EntityRelationshipTermExtraction,
) -> str:
entities = entity_term_extraction_dict.entities
terms = entity_term_extraction_dict.terms
relationships = entity_term_extraction_dict.relationships
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity.entity_name} ({entity.entity_type})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_name = relationship.relationship_name
relationship_type = relationship.relationship_type
relationship_entities = relationship.relationship_entities
relationship_str = (
f"""{relationship_name} ({relationship_type}): {relationship_entities}"""
)
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term.term_name} ({term.term_type}): similar to {', '.join(term.term_similar_to)}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def get_test_config(
db_session: Session,
primary_llm: LLM,
fast_llm: LLM,
search_request: SearchRequest,
use_agentic_search: bool = True,
) -> GraphConfig:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else MAX_CHUNKS_FED_TO_CHAT
),
max_window_percentage=CHAT_TARGET_CHUNK_PERCENTAGE,
)
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
# The docs retrieved by this flow are already relevance-filtered
all_docs_useful=True
),
document_pruning_config=document_pruning_config,
structured_response_format=None,
)
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=RetrievalDetails(), # may want to set dedupe_docs=True
rerank_settings=None, # Can use this to change reranking model
selected_sections=None,
latest_query_files=None,
bypass_acl=False,
)
prompt_config = PromptConfig.from_model(persona.prompts[0])
search_tool = SearchTool(
db_session=db_session,
user=None,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=primary_llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=AnswerPromptBuilder(
user_message=HumanMessage(content=search_request.query),
message_history=[],
llm_config=primary_llm.config,
raw_user_query=search_request.query,
raw_user_uploaded_files=[],
),
structured_response_format=answer_style_config.structured_response_format,
)
using_tool_calling_llm = explicit_tool_calling_supported(
primary_llm.config.model_provider, primary_llm.config.model_name
)
graph_tooling = GraphTooling(
primary_llm=primary_llm,
fast_llm=fast_llm,
search_tool=search_tool,
tools=[search_tool],
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
using_tool_calling_llm=using_tool_calling_llm,
)
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=UUID(chat_session_id),
message_id=1,
)
search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=False,
allow_refinement=True,
)
graph_config = GraphConfig(
inputs=graph_inputs,
tooling=graph_tooling,
persistence=graph_persistence,
behavior=search_behavior_config,
)
return graph_config
def get_persona_agent_prompt_expressions(
persona: Persona | None,
) -> PersonaPromptExpressions:
if persona is None or len(persona.prompts) == 0:
# TODO base_prompt should be None, but no time to properly fix
return PersonaPromptExpressions(
contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_DEFAULT, base_prompt=""
)
# Only a 1:1 mapping between personas and prompts currently
prompt = persona.prompts[0]
prompt_config = PromptConfig.from_model(prompt)
datetime_aware_system_prompt = handle_onyx_date_awareness(
prompt_str=prompt_config.system_prompt,
prompt_config=prompt_config,
add_additional_info_if_no_tag=prompt.datetime_aware,
)
return PersonaPromptExpressions(
contextualized_prompt=ASSISTANT_SYSTEM_PROMPT_PERSONA.format(
persona_prompt=datetime_aware_system_prompt
),
base_prompt=datetime_aware_system_prompt,
)
def make_question_id(level: int, question_num: int) -> str:
return f"{level}_{question_num}"
def parse_question_id(question_id: str) -> tuple[int, int]:
level, question_num = question_id.split("_")
return int(level), int(question_num)
def _dispatch_nonempty(
content: str, dispatch_event: Callable[[str, int], None], sep_num: int
) -> None:
"""
Dispatch a content string if it is not empty using the given callback.
This function is used in the context of dispatching some arbitrary number
of similar objects which are separated by a separator during the LLM stream.
The callback expects a sep_num denoting which object is being dispatched; these
numbers go from 1 to however many strings the LLM decides to stream.
"""
if content != "":
dispatch_event(content, sep_num)
def dispatch_separated(
tokens: Iterator[BaseMessage],
dispatch_event: Callable[[str, int], None],
sep: str = DISPATCH_SEP_CHAR,
) -> list[BaseMessage_Content]:
num = 1
streamed_tokens: list[BaseMessage_Content] = []
for token in tokens:
content = cast(str, token.content)
if sep in content:
sub_question_parts = content.split(sep)
_dispatch_nonempty(sub_question_parts[0], dispatch_event, num)
num += 1
_dispatch_nonempty(
"".join(sub_question_parts[1:]).strip(), dispatch_event, num
)
else:
_dispatch_nonempty(content, dispatch_event, num)
streamed_tokens.append(content)
return streamed_tokens
def dispatch_main_answer_stop_info(level: int, writer: StreamWriter) -> None:
stop_event = StreamStopInfo(
stop_reason=StreamStopReason.FINISHED,
stream_type=StreamType.MAIN_ANSWER,
level=level,
)
write_custom_event("stream_finished", stop_event, writer)
def retrieve_search_docs(
search_tool: SearchTool, question: str
) -> list[InferenceSection]:
retrieved_docs: list[InferenceSection] = []
# new db session to avoid concurrency issues
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
response = cast(SearchResponseSummary, tool_response.response)
retrieved_docs = response.top_sections
break
return retrieved_docs
def get_answer_citation_ids(answer_str: str) -> list[int]:
"""
Extract citation numbers of format [D<number>] from the answer string.
"""
citation_ids = re.findall(r"\[D(\d+)\]", answer_str)
return list(set([(int(id) - 1) for id in citation_ids]))
def summarize_history(
history: str, question: str, persona_specification: str | None, llm: LLM
) -> str:
history_context_prompt = remove_document_citations(
HISTORY_CONTEXT_SUMMARY_PROMPT.format(
persona_specification=persona_specification,
question=question,
history=history,
)
)
history_response = llm.invoke(history_context_prompt)
assert isinstance(history_response.content, str)
return history_response.content
# taken from langchain_core.runnables.schema
# we don't use the one from their library because
# it includes ids they generate
class CustomStreamEvent(TypedDict):
# Overwrite the event field to be more specific.
event: Literal["on_custom_event"] # type: ignore[misc]
"""The event type."""
name: str
"""User defined name for the event."""
data: Any
"""The data associated with the event. Free form and can be anything."""
def write_custom_event(
name: str, event: AnswerPacket, stream_writer: StreamWriter
) -> None:
stream_writer(CustomStreamEvent(event="on_custom_event", name=name, data=event))
def relevance_from_docs(
relevant_docs: list[InferenceSection],
) -> list[SectionRelevancePiece]:
return [
SectionRelevancePiece(
relevant=True,
content=doc.center_chunk.content,
document_id=doc.center_chunk.document_id,
chunk_id=doc.center_chunk.chunk_id,
)
for doc in relevant_docs
]
def get_langgraph_node_log_string(
graph_component: str,
node_name: str,
node_start_time: datetime,
result: str | None = None,
) -> str:
duration = datetime.now() - node_start_time
results_str = "" if result is None else f" -- Result: {result}"
return f"{node_start_time} -- {graph_component} - {node_name} -- Time taken: {duration}{results_str}"
def remove_document_citations(text: str) -> str:
"""
Removes citation expressions of format '[[D1]]()' from text.
The number after D can vary.
Args:
text: Input text containing citations
Returns:
Text with citations removed
"""
# Pattern explanation:
# \[(?:D|Q)?\d+\] matches:
# \[ - literal [ character
# (?:D|Q)? - optional D or Q character
# \d+ - one or more digits
# \] - literal ] character
return re.sub(r"\[(?:D|Q)?\d+\]", "", text)

View File

@@ -10,6 +10,7 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
@@ -65,9 +66,13 @@ def send_forgot_password_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
# Keep search param same name as cookie for simplicity
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)

View File

@@ -42,6 +42,10 @@ class UserCreate(schemas.BaseUserCreate):
tenant_id: str | None = None
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
class UserUpdate(schemas.BaseUserUpdate):
"""
Role updates are not allowed through the user update endpoint for security reasons

View File

@@ -57,7 +57,7 @@ from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.auth.schemas import UserUpdateWithRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
@@ -73,6 +73,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
@@ -216,9 +217,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def get_by_email(self, user_email: str) -> User:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user_email)
async with get_async_session_with_tenant(tenant_id) as db_session:
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
user = await tenant_user_db.get_by_email(user_email)
else:
user = await self.user_db.get_by_email(user_email)
if not user:
raise exceptions.UserNotExists()
return user
async def create(
self,
user_create: schemas.UC | UserCreate,
@@ -246,10 +264,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
referral_source=referral_source,
request=request,
)
user: User
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
@@ -268,16 +286,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdate(
user_update = UserUpdateWithRole(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
else:
@@ -285,7 +303,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
@@ -372,6 +389,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User
try:
# Attempt to get user by OAuth account
user = await self.get_by_oauth_account(oauth_name, account_id)
@@ -504,9 +523,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Your admin has not enbaled this feature.",
"Your admin has not enabled this feature.",
)
send_forgot_password_email(user.email, token)
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(email=user.email)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -580,6 +605,7 @@ async def get_user_manager(
cookie_transport = CookieTransport(
cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
cookie_secure=WEB_DOMAIN.startswith("https"),
cookie_name=FASTAPI_USERS_AUTH_COOKIE_NAME,
)
@@ -1047,6 +1073,8 @@ async def api_key_dep(
if AUTH_TYPE == AuthType.DISABLED:
return None
user: User | None = None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")

View File

@@ -24,6 +24,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.httpx.httpx_pool import HttpxPool
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete
@@ -197,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_redis_client(tenant_id=None)
@@ -316,6 +318,8 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
HttpxPool.close_all()
if not celery_is_worker_primary(sender):
return

View File

@@ -1,6 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -8,7 +7,6 @@ from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
@@ -132,21 +130,25 @@ class DynamicTenantScheduler(PersistentScheduler):
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
current_tenants = set()
for task_name, _ in current_schedule:
task_name = cast(str, task_name)
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
continue
# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actualy need this scheduler any more and should
# test reverting to a regular beat schedule implementation
if "_" in task_name:
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# -> "12345678-abcd-efgh-ijkl-12345678"
current_tenants.add(task_name.split("_")[-1])
logger.info(f"Found {len(current_tenants)} existing items in schedule")
# current_tenants = set()
# for task_name, _ in current_schedule:
# task_name = cast(str, task_name)
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# continue
for tenant_id in tenant_ids:
if tenant_id not in current_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
# if "_" in task_name:
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# # -> "12345678-abcd-efgh-ijkl-12345678"
# current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
# for tenant_id in tenant_ids:
# if tenant_id not in current_tenants:
# logger.info(f"Processing new tenant: {tenant_id}")
new_schedule = self._generate_schedule(tenant_ids)

View File

@@ -10,6 +10,10 @@ from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
@@ -54,12 +58,23 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
logger.info("worker_init signal received.")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -21,13 +21,16 @@ from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
@@ -141,23 +144,16 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
RedisGlobalConnectorCredentialPair.reset_all(r)
RedisDocumentSet.reset_all(r)
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
RedisConnectorPrune.reset_all(r)
RedisConnectorIndex.reset_all(r)
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed

View File

@@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return False
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.
"""
task_set: set[str] = set()
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
task_id = task_dict.get("headers", {}).get("id")
if task_id:
task_set.add(task_id)
return task_set
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.

View File

@@ -1,10 +1,13 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import httpx
from sqlalchemy.orm import Session
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -17,6 +20,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.server.documents.models import DeletionAttemptSnapshot
@@ -154,3 +158,25 @@ def celery_is_worker_primary(worker: Any) -> bool:
return True
return False
def httpx_init_vespa_pool(
max_keepalive_connections: int,
timeout: int = VESPA_REQUEST_TIMEOUT,
ssl_cert: str | None = None,
ssl_key: str | None = None,
) -> None:
httpx_cert = None
httpx_verify = False
if ssl_cert and ssl_key:
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
httpx_verify = True
HttpxPool.init_client(
name="vespa",
cert=httpx_cert,
verify=httpx_verify,
timeout=timeout,
http2=False,
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
)

Some files were not shown because too many files have changed in this diff Show More