Compare commits

...

183 Commits

Author SHA1 Message Date
Weves
30983657ec Fix indexing of whitespace only 2024-01-05 19:35:38 -08:00
Yuhong Sun
6b6b3daab7 Reenable option to run Danswer without Gen AI (#906) 2024-01-03 18:31:16 -08:00
Chris Weaver
20441df4a4 Add Tag Filter UI + other UI cleanup (#905) 2024-01-02 11:30:36 -08:00
Yuhong Sun
d7141df5fc Metadata and Title Search (#903) 2024-01-02 11:25:50 -08:00
Yuhong Sun
615bb7b095 Update CONTRIBUTING.md 2024-01-01 18:07:50 -08:00
Yuhong Sun
e759718c3e Update CONTRIBUTING.md 2024-01-01 18:06:56 -08:00
Yuhong Sun
06d8d0e53c Update CONTRIBUTING.md 2024-01-01 18:06:17 -08:00
Weves
ae9b556876 Revamp new chat screen for chat UI 2023-12-30 18:13:24 -08:00
Chris Weaver
f883611e94 Add query editing in Chat UI (#899) 2023-12-30 12:46:48 -08:00
Yuhong Sun
13c536c033 Final Backend CVEs (#900) 2023-12-30 11:57:49 -08:00
Yuhong Sun
2e6be57880 Model Server CVEs (#898) 2023-12-29 21:14:08 -08:00
Weves
b352d83b8c Increase max upload size 2023-12-29 21:11:57 -08:00
Yuhong Sun
aa67768c79 CVEs continued (#889) 2023-12-29 20:42:16 -08:00
Weves
6004e540f3 Improve Vespa invalid char cleanup 2023-12-29 20:36:03 -08:00
eukub
64d2cea396 reduced redunduncy and changed concatenation of strings to f-strings 2023-12-29 00:35:04 -08:00
Weves
b5947a1c74 Add illegal char stripping to title field 2023-12-29 00:17:40 -08:00
Weves
cdf260b277 FIx chat refresh + add stop button 2023-12-28 23:33:41 -08:00
Weves
73483b5e09 Fix more auth disabled flakiness 2023-12-27 01:23:29 -08:00
Yuhong Sun
a6a444f365 Bump Python Version for security (#887) 2023-12-26 16:15:14 -08:00
Yuhong Sun
449a403c73 Automatic Security Scan (#886) 2023-12-26 14:41:23 -08:00
Yuhong Sun
4aebf824d2 Fix broken build SHA issue (#885) 2023-12-26 14:36:40 -08:00
Weves
26946198de Fix disabled auth 2023-12-26 12:51:58 -08:00
Yuhong Sun
e5035b8992 Move some util functions around (#883) 2023-12-26 00:38:29 -08:00
Weves
2e9af3086a Remove old comment 2023-12-25 21:36:54 -08:00
Weves
dab3ba8a41 Add support for basic auth on FE 2023-12-25 21:19:59 -08:00
Yuhong Sun
1e84b0daa4 Fix escape character handling in DanswerBot (#880) 2023-12-25 12:28:35 -08:00
Yuhong Sun
f4c8abdf21 Remove Extraneous Persona Config (#878) 2023-12-24 22:48:48 -08:00
sweep-ai[bot]
ccc5bb1e67 Configure Sweep (#875)
* Create sweep.yaml

* Create sweep template

* Update sweep.yaml

---------

Co-authored-by: sweep-ai[bot] <128439645+sweep-ai[bot]@users.noreply.github.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2023-12-24 19:04:52 -08:00
Yuhong Sun
c3cf9134bb Telemetry Revision (#868) 2023-12-24 17:39:37 -08:00
Weves
0370b9b38d Stop copying local node_modules / .next dir into web docker image 2023-12-24 15:27:11 -08:00
Weves
95bf1c13ad Add http2 dependency 2023-12-24 14:49:31 -08:00
Yuhong Sun
00c1f93b12 Zendesk Tiny Cleanup (#867) 2023-12-23 16:39:15 -08:00
Yuhong Sun
a122510cee Zendesk Connector Metadata and small batch fix (#866) 2023-12-23 16:34:48 -08:00
Weves
dca4f7a72b Adding http2 support to Vespa 2023-12-23 16:23:24 -08:00
Weves
535dc265c5 Fix boost resetting on document update + fix refresh on re-index 2023-12-23 15:23:21 -08:00
Weves
56882367ba Fix migrations 2023-12-23 12:58:00 -08:00
Weves
d9fbd7ffe2 Add hiding + re-ordering to personas 2023-12-22 23:04:43 -08:00
Yuhong Sun
8b7d01fb3b Allow Duplicate Naming for CC-Pair (#862) 2023-12-22 23:03:44 -08:00
voarsh2
016a087b10 Refactor environment variable handling using ConfigMap for Kubernetes deployment (#515)
---------

Co-authored-by: Reese Jenner <reesevader@hotmail.co.uk>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2023-12-22 21:33:36 -08:00
Sam Jakos
241b886976 fix: parse INDEX_BATCH_SIZE to an int (#858) 2023-12-22 13:03:21 -08:00
Yuhong Sun
ff014e4f5a Bump Transformer Version (#857) 2023-12-22 01:47:18 -08:00
Aliaksandr_С
0318507911 Indexing settings and logging improve (#821)
---------

Co-authored-by: Aliaksandr Chernak <aliaksandr_chernak@epam.com>
Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2023-12-22 01:13:24 -08:00
Yuhong Sun
6650f01dc6 Multilingual Docs Updates (#856) 2023-12-22 00:26:00 -08:00
Yuhong Sun
962e3f726a Slack Feedback Message Tweaks (#855) 2023-12-21 20:52:11 -08:00
mattboret
25a73b9921 Slack bot improve source feedback (#827)
---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
2023-12-21 20:33:20 -08:00
Yuhong Sun
dc0b3672ac git push --set-upstream origin danswerbot-format (#854) 2023-12-21 18:46:30 -08:00
Yuhong Sun
c4ad03a65d Handle DanswerBot case where no updated at (#853) 2023-12-21 18:33:42 -08:00
mattboret
c6f354fd03 Add the latest document update to the Slack bot answer (#817)
* Add the latest source update to the Slack bot answer

* fix mypy errors

---------

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
2023-12-21 18:16:05 -08:00
Yuhong Sun
2f001c23b7 Confluence add tag to replaced names (#852) 2023-12-21 18:03:56 -08:00
mattboret
4d950aa60d Replace user id by the user display name in the exported Confluence page (#815)
Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
2023-12-21 17:52:28 -08:00
Yuhong Sun
56406a0b53 Bump Vespa to 8.277.17 (#851) 2023-12-21 17:23:27 -08:00
sam lockart
eb31c08461 Update Vespa to 8.267.29 (#812) 2023-12-21 17:18:16 -08:00
Weves
26f94c9890 Improve re-sizing 2023-12-21 10:03:03 -08:00
Weves
a9570e01e2 Make document sidebar scrollbar darker 2023-12-21 10:03:03 -08:00
Weves
402d83e167 Make it so docs without links aren't clickable in chat citations 2023-12-21 10:03:03 -08:00
Ikko Eltociear Ashimine
10dcd49fc8 Update CONTRIBUTING.md
Nagivate -> Navigate
2023-12-21 09:10:52 -08:00
Yuhong Sun
0fdad0e777 Update Demo Video 2023-12-20 19:05:23 -08:00
Weves
fab767d794 Fix persona document sets 2023-12-20 15:24:32 -08:00
Weves
7dd70ca4c0 Change danswer header link in chat page 2023-12-20 11:38:33 -08:00
Weves
370760eeee Fix editing deleted personas, editing personas with no prompts, and model selection 2023-12-19 14:42:13 -08:00
Weves
24a62cb33d Fix persona + prompt apis 2023-12-19 10:23:06 -08:00
Weves
9e4a4ddf39 Update search helper styling 2023-12-19 07:08:11 -08:00
Yuhong Sun
c281859509 Google Drive handle invalid PDFs (#838) 2023-12-18 23:39:45 -08:00
Yuhong Sun
2180a40bd3 Disable Chain of Thought for now (#837) 2023-12-18 21:44:47 -08:00
Weves
997f9c3191 Fix ccPair pages crashing 2023-12-17 23:28:26 -08:00
Weves
677c32ea79 Fix issue where a message that errors out creates a bad state 2023-12-17 23:28:26 -08:00
Yuhong Sun
edfc849652 Search more frequently (#834) 2023-12-17 22:45:46 -08:00
Yuhong Sun
9d296b623b Shield Update (#833) 2023-12-17 22:17:44 -08:00
Yuhong Sun
5957b888a5 DanswerBot Chat (#831) 2023-12-17 18:18:48 -08:00
Chris Weaver
c7a91b1819 Allow re-sizing of document sidebar + make central chat smaller on small screens (#832) 2023-12-17 18:17:43 -08:00
Weves
a099f8e296 Rework header a bit + remove assumption of all personas having a prompt 2023-12-14 23:06:39 -08:00
Weves
16c8969028 Chat UI 2023-12-14 22:18:42 -08:00
Yuhong Sun
65fde8f1b3 Chat Backend (#801) 2023-12-14 22:14:37 -08:00
Yuhong Sun
229db47e5d Update LLM Key Check Logic (#825) 2023-12-09 13:41:31 -08:00
Weves
2e3397feb0 Check for slack bot token changes every 60 seconds 2023-12-08 14:14:22 -08:00
Weves
d5658ce477 Persona enhancements 2023-12-07 14:29:37 -08:00
Weves
ddf3f99da4 Add support for global API prefix env variable 2023-12-07 12:42:17 -08:00
Weves
56785e6065 Add model choice to Persona 2023-12-07 00:20:42 -08:00
Weves
26e808d2a1 Fix welcome modal 2023-12-06 21:07:34 -08:00
Yuhong Sun
e3ac373f05 Make Default Fast LLM not identical to main LLM (#818) 2023-12-06 16:14:04 -08:00
Yuhong Sun
9e9a578921 Option to speed up DanswerBot by turning off chain of thought (#816) 2023-12-05 00:43:45 -08:00
Weves
f7172612e1 Allow persona usage for Slack bots 2023-12-04 19:20:03 -08:00
Yuhong Sun
5aa2de7a40 Fix Weak Models Concurrency Issue (#811) 2023-12-04 15:40:10 -08:00
Yuhong Sun
e0b87d9d4e Fix Weak Model Prompt (#810) 2023-12-04 15:02:08 -08:00
Weves
5607fdcddd Make Slack Bot setup UI more similar to Persona setup 2023-12-03 23:36:54 -08:00
Yuhong Sun
651de071f7 Improve English rephrasing for multilingual use case (#808) 2023-12-03 14:34:12 -08:00
John Bergvall
5629ca7d96 Copy SearchQuery model with updated attribute due to Config.frozen=True (#806)
Fixes the following TypeError:

api_server_1     |   File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 33, in run_sync
api_server_1     |     return await get_asynclib().run_sync_in_worker_thread(
api_server_1     |            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api_server_1     |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 877, in run_sync_in_worker_thread
api_server_1     |     return await future
api_server_1     |            ^^^^^^^^^^^^
api_server_1     |   File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 807, in run
api_server_1     |     result = context.run(func, *args)
api_server_1     |              ^^^^^^^^^^^^^^^^^^^^^^^^
api_server_1     |   File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 53, in _next
api_server_1     |     return next(iterator)
api_server_1     |            ^^^^^^^^^^^^^^
api_server_1     |   File "/app/danswer/utils/timing.py", line 47, in wrapped_func
api_server_1     |     value = next(gen)
api_server_1     |             ^^^^^^^^^
api_server_1     |   File "/app/danswer/direct_qa/answer_question.py", line 243, in answer_qa_query_stream
api_server_1     |     top_chunks = cast(list[InferenceChunk], next(search_generator))
api_server_1     |                                             ^^^^^^^^^^^^^^^^^^^^^^
api_server_1     |   File "/app/danswer/search/search_runner.py", line 469, in full_chunk_search_generator
api_server_1     |     retrieved_chunks = retrieve_chunks(
api_server_1     |                        ^^^^^^^^^^^^^^^^
api_server_1     |   File "/app/danswer/search/search_runner.py", line 353, in retrieve_chunks
api_server_1     |     q_copy.query = rephrase
api_server_1     |     ^^^^^^^^^^^^
api_server_1     |   File "pydantic/main.py", line 359, in pydantic.main.BaseModel.__setattr__
api_server_1     | TypeError: "SearchQuery" is immutable and does not support item assignment
2023-12-03 13:47:11 -08:00
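For illustration, a minimal sketch of the pattern behind this fix (a simplified SearchQuery; the actual model has more fields): with pydantic's `Config.frozen = True`, attribute assignment raises the TypeError shown in the traceback above, so the model is copied with the updated value instead.

```python
from pydantic import BaseModel


class SearchQuery(BaseModel):
    query: str

    class Config:
        frozen = True  # instances are immutable (and hashable)


q = SearchQuery(query="original phrasing")
# q.query = "rephrased"  # would raise: "SearchQuery" is immutable and does not support item assignment

# Instead, copy the model with the updated attribute, as the fix does:
q_copy = q.copy(update={"query": "rephrased"})
print(q_copy.query)  # "rephrased"; the original instance is untouched
```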
Yuhong Sun
bc403d97f2 Organize Prompts for Chat implementation (#807) 2023-12-03 13:27:11 -08:00
Weves
292c78b193 Always pull latest data when visiting main search page 2023-12-03 03:25:13 -08:00
Weves
ac35719038 FE improvements to make initial setup more intuitive 2023-12-02 16:40:44 -08:00
Yuhong Sun
02095e9281 Restructure APIs (#803) 2023-12-02 14:48:08 -08:00
Yuhong Sun
8954a04602 Reorder Tables for cleaner extending (#800) 2023-12-01 17:46:13 -08:00
Yuhong Sun
8020db9e9a Update connector interface with optional Owners information (#798) 2023-11-30 23:08:16 -08:00
Yuhong Sun
17c2f06338 Add more metadata options for File connector (#797) 2023-11-30 13:24:22 -08:00
Weves
9cff294a71 Increase retries for google drive connector 2023-11-30 03:03:26 -08:00
Weves
e983aaeca7 Add more logging on existing jobs 2023-11-30 02:58:37 -08:00
Weves
7ea774f35b Change in-progress status color 2023-11-29 20:57:45 -08:00
Weves
d1846823ba Associate a user with web/file connectors 2023-11-29 18:18:56 -08:00
Yuhong Sun
fda89ac810 Expert Recommendation Heuristic Only (#791) 2023-11-29 15:53:57 -08:00
Yuhong Sun
006fd4c438 Ingestion API now always updates regardless of document updated_at (#786) 2023-11-29 02:08:50 -08:00
Weves
9b7069a043 Disallow re-indexing for File connector 2023-11-29 02:01:11 -08:00
Weves
c64c25b2e1 Fix temp file deletion 2023-11-29 02:00:20 -08:00
Yuhong Sun
c2727a3f19 Custom OpenAI Model Server (#782) 2023-11-29 01:41:56 -08:00
Chris Weaver
37daf4f3e4 Remove AI Thoughts by default (#783)
- Removes AI Thoughts by default - only shows when validation fails
- Removes punctuation "words" from queries in addition to stopwords (Vespa ignores punctuation anyways)
- Fixes Vespa deletion script for larger doc counts
2023-11-29 01:00:53 -08:00
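A rough sketch of the punctuation cleanup described above (the helper name is hypothetical; the real query-processing code lives in the backend and may differ): tokens consisting only of punctuation are dropped alongside stopwords before the query reaches Vespa.

```python
import string


def drop_punctuation_words(query: str) -> str:
    """Remove tokens made up entirely of punctuation (hypothetical helper)."""
    kept = [
        tok
        for tok in query.split()
        if not all(ch in string.punctuation for ch in tok)
    ]
    return " ".join(kept)


print(drop_punctuation_words("what is danswer ? !"))  # -> "what is danswer"
```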
Yuhong Sun
fcb7f6fcc0 Accept files with character issues (#781) 2023-11-28 22:43:58 -08:00
Weves
429016d4a2 Fix zulip page 2023-11-28 16:28:51 -08:00
Chris Weaver
c83a450ec4 Remove personal connectors page(#779) 2023-11-28 16:11:42 -08:00
Yuhong Sun
187b94a7d8 Blurb Key Error (#778) 2023-11-28 16:09:33 -08:00
Weves
30225fd4c5 Fix filter hiding 2023-11-28 04:13:11 -08:00
Weves
a4f053fa5b Fix persona refresh 2023-11-28 02:53:18 -08:00
Weves
eab4fe83a0 Remove Slack bot personas from web UI 2023-11-28 02:53:18 -08:00
Chris Weaver
78d1ae0379 Customizable personas (#772)
Also includes a small fix to LLM filtering when combined with reranking
2023-11-28 00:57:48 -08:00
Yuhong Sun
87beb1f4d1 Log LLM details on server start (#773) 2023-11-27 21:32:48 -08:00
Yuhong Sun
05c2b7d34e Update LLM related Libs (#771) 2023-11-26 19:54:16 -08:00
Yuhong Sun
39d09a162a Danswer APIs Document Ingestion Endpoint (#716) 2023-11-26 19:09:22 -08:00
Yuhong Sun
d291fea020 Turn off Reranking for Streaming Flows (#770) 2023-11-26 16:45:23 -08:00
Yuhong Sun
2665bff78e Option to turn off LLM for eval script (#769) 2023-11-26 15:31:03 -08:00
Yuhong Sun
65d38ac8c3 Slack to respect LLM chunk filter settings (#768) 2023-11-26 01:06:12 -08:00
Yuhong Sun
8391d89bea Fix Indexing Concurrency (#767) 2023-11-25 21:40:36 -08:00
Yuhong Sun
ac2ed31726 Indexing Jobs to have shorter lived DB sessions (#766) 2023-11-24 21:38:16 -08:00
Chris Weaver
47f947b045 Use torch.multiprocessing + enable SimpleJobClient by default (#765) 2023-11-24 18:29:28 -08:00
dependabot[bot]
63b051b342 Bump sharp from 0.32.5 to 0.32.6 in /web
Bumps [sharp](https://github.com/lovell/sharp) from 0.32.5 to 0.32.6.
- [Release notes](https://github.com/lovell/sharp/releases)
- [Changelog](https://github.com/lovell/sharp/blob/main/docs/changelog.md)
- [Commits](https://github.com/lovell/sharp/compare/v0.32.5...v0.32.6)

---
updated-dependencies:
- dependency-name: sharp
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-11-24 18:14:45 -08:00
Weves
a5729e2fa6 Add new model server env vars to the compose file 2023-11-24 00:12:04 -08:00
Weves
3cec854c5c Allow different model servers for different models / indexing jobs 2023-11-23 23:39:03 -08:00
Weves
26c6651a03 Improve LLM answer parsing 2023-11-23 15:03:35 -08:00
Yuhong Sun
13001ede98 Search Regression Test and Save/Load State updates (#761) 2023-11-23 00:00:30 -08:00
Yuhong Sun
fda377a2fa Regression Script for Search quality (#760) 2023-11-22 19:33:28 -08:00
Yuhong Sun
bdfb894507 Slack Role Override (#755) 2023-11-22 17:47:18 -08:00
Weves
35c3511daa Increase Vespa timeout 2023-11-22 01:42:59 -08:00
Chris Weaver
c1e19d0d93 Add selected docs in UI + rework the backend flow a bit(#754)
Changes the flow so that the selected docs are sent over in a separate packet rather than as part of the initial packet for the streaming QA endpoint.
2023-11-21 19:46:12 -08:00
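As a sketch of the packet split described above (packet keys and shapes are assumptions, not Danswer's actual wire format): the selected documents are streamed as their own JSON object instead of riding along in the initial packet.

```python
import json
from collections.abc import Iterator


def stream_qa_packets(selected_docs: list[dict], answer_tokens: list[str]) -> Iterator[str]:
    # Initial packet no longer bundles the user's selected documents.
    yield json.dumps({"type": "qa_metadata", "answerable": True}) + "\n"
    # Selected documents now arrive in a separate, dedicated packet.
    yield json.dumps({"type": "selected_documents", "documents": selected_docs}) + "\n"
    # Answer pieces follow as they are generated.
    for token in answer_tokens:
        yield json.dumps({"type": "answer_piece", "text": token}) + "\n"


for packet in stream_qa_packets([{"document_id": "doc-1"}], ["Danswer ", "supports ", "chat."]):
    print(packet, end="")
```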
mattboret
e78aefb408 Add script to analyse the sources selection (#721)
---------

Co-authored-by: Matthieu Boret <matthieu.boret@fr.clara.net>
2023-11-21 18:35:26 -08:00
Bryan Peterson
aa2e859b46 add missing dependencies in model_server dockerfile (#752)
Thanks for catching this! Super helpful!
2023-11-21 17:59:28 -08:00
Yuhong Sun
c0c8ae6c08 Minor Tuning for Filters (#753) 2023-11-21 15:47:58 -08:00
Weves
1225c663eb Add new env variable to compose file 2023-11-20 21:40:54 -08:00
Weves
e052d607d5 Add option to log Vespa timing info 2023-11-20 21:37:22 -08:00
Yuhong Sun
8e5e11a554 Add md files to File Connector (#749) 2023-11-20 19:56:06 -08:00
Yuhong Sun
57f0323f52 NLP Model Warmup Reworked (#748) 2023-11-20 17:28:23 -08:00
Weves
6e9f31d1e9 Fix ResourceLogger blocking main thread 2023-11-20 16:46:18 -08:00
Weves
eeb844e35e Fix bug with Google Drive shortcut error case 2023-11-20 16:34:07 -08:00
Sid Ravinutala
d6a84ab413 fix for url parsing google site 2023-11-20 16:08:43 -08:00
Weves
68160d49dd Small mods to enable deployment on AWS EKS 2023-11-20 01:42:48 -08:00
Yuhong Sun
0cc3d65839 Add option to run a faster/cheaper LLM for secondary flows (#742) 2023-11-19 17:48:42 -08:00
Weves
df37387146 Fix a couple bugs with google sites link finding 2023-11-19 15:35:54 -08:00
Yuhong Sun
f72825cd46 Provide Metadata to the LLM (#740) 2023-11-19 12:28:45 -08:00
Yuhong Sun
6fb07d20cc Multilingual Query Expansion (#737) 2023-11-19 10:55:55 -08:00
Chris Weaver
b258ec1bed Adjust checks for removal from existing_jobs dict + add more logging + only one scheduled job for a connector at a time (#739) 2023-11-19 02:03:17 -08:00
Yuhong Sun
4fd55b8928 Fix GPT4All (#738) 2023-11-18 21:21:02 -08:00
Yuhong Sun
b3ea53fa46 Fix Build Version (#736) 2023-11-18 17:16:25 -08:00
Yuhong Sun
fa0d19cc8c LLM Chunk Filtering (#735) 2023-11-18 17:12:24 -08:00
Weves
d5916e420c Fix duplicated query event for 'answer_qa_query_stream' and missing llm_answer in 'answer_qa_query' 2023-11-17 21:10:23 -08:00
Weves
39b912befd Enable show GPT answer option immediately 2023-11-17 17:08:38 -08:00
Weves
37c5f24d91 Fix logout redirect 2023-11-17 16:43:24 -08:00
Weves
ae72cd56f8 Add a bit more logging in indexing pipeline 2023-11-16 12:00:19 -08:00
Yuhong Sun
be5ef77896 Optional Anonymous Telemetry (#727) 2023-11-16 09:22:36 -08:00
Weves
0ed8f14015 Improve Vespa filtering performance 2023-11-15 14:30:12 -08:00
Weves
a03e443541 Add root_page_id option for Notion connector 2023-11-15 12:46:41 -08:00
Weves
4935459798 Fix hover being transparent 2023-11-15 11:52:40 -08:00
Weves
efb52873dd Prettier fix 2023-11-14 22:22:42 -08:00
Bradley
442f7595cc Added connector configuration link and external link icon to web connector page. 2023-11-14 22:19:00 -08:00
Weves
81cbcbb403 Fix connector deletion bug 2023-11-14 09:07:59 -08:00
Weves
0a0e672b35 Fix no letsencrypt 2023-11-13 14:32:51 -08:00
Yuhong Sun
69644b266e Hybrid Search Alpha Parameter (#714) 2023-11-09 17:11:10 -08:00
Yuhong Sun
5a4820c55f Skip Index on Docs with no newer updated at (#713) 2023-11-09 16:27:32 -08:00
Weves
a5d69bb392 Add back end time to Gong 2023-11-09 14:03:46 -08:00
Weves
23ee45c033 Enhance document explorer 2023-11-09 00:58:51 -08:00
Yuhong Sun
31bfd015ae Request Tracker Connector (#709)
Contributed by Evan! Thanks for the contribution!

- Minor linting and rebasing done by Yuhong, everything else from Evan

---------

Co-authored-by: Evan Sarmiento <e.sarmiento@soax.com>
Co-authored-by: Evan <esarmien@fas.harvard.edu>
2023-11-07 16:55:10 -08:00
Yuhong Sun
0125d8a0f6 Source Filter Extraction (#708) 2023-11-07 14:21:04 -08:00
Yuhong Sun
4f64444f0f Fix Version from Tag not picked up (#705) 2023-11-06 20:01:20 -08:00
Weves
abf9cc3248 Add timeout to all Notion calls 2023-11-06 19:29:42 -08:00
Chris Weaver
f5bf2e6374 Fix experimental checkpointing + move check for disabled connector to the start of the batch (#703) 2023-11-06 17:14:31 -08:00
Yuhong Sun
24b3b1fa9e Fix GitHub Actions Naming (#702) 2023-11-06 16:40:49 -08:00
Yuhong Sun
7433dddac3 Model Server (#695)
Provides the ability to pull out the NLP models into a separate model server which can then be hosted on a GPU instance if desired.
2023-11-06 16:36:09 -08:00
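A minimal sketch of the separated model server idea (endpoint path and model choice are illustrative, not Danswer's actual model-server API; the Dockerfile added in this diff runs `uvicorn model_server.main:app` on port 9000): the NLP models sit behind their own HTTP service, which can be deployed on a GPU instance while the API server calls it over the network.

```python
from fastapi import FastAPI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

app = FastAPI()
model = SentenceTransformer("intfloat/e5-base-v2")  # illustrative embedding model


class EmbedRequest(BaseModel):
    texts: list[str]


class EmbedResponse(BaseModel):
    embeddings: list[list[float]]


@app.post("/encode", response_model=EmbedResponse)
def encode(request: EmbedRequest) -> EmbedResponse:
    # Runs on whatever hardware hosts this process, e.g. a GPU instance.
    vectors = model.encode(request.texts).tolist()
    return EmbedResponse(embeddings=vectors)


# Run with: uvicorn model_server_sketch:app --host 0.0.0.0 --port 9000
```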
Weves
fe938b6fc6 Add experimental checkpointing 2023-11-04 14:51:28 -07:00
dependabot[bot]
2db029672b Bump pypdf from 3.16.4 to 3.17.0 in /backend/requirements (#667)
Bumps [pypdf](https://github.com/py-pdf/pypdf) from 3.16.4 to 3.17.0.
- [Release notes](https://github.com/py-pdf/pypdf/releases)
- [Changelog](https://github.com/py-pdf/pypdf/blob/main/CHANGELOG.md)
- [Commits](https://github.com/py-pdf/pypdf/compare/3.16.4...3.17.0)

---
updated-dependencies:
- dependency-name: pypdf
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-03 18:54:29 -07:00
Yuhong Sun
602f9c4a0a Default Version to 0.2-dev (#690) 2023-11-03 18:37:01 -07:00
Bradley
551705ad62 Implemented Danswer versioning system. (#649)
* Web & API server versioning system. Displayed on UI.

* Remove some debugging code.

* Integrated backend version into GitHub Action & Docker build workflow using env variables.

* Fixed web container environment variable name.

* Revise Dockerfiles for GitHub Actions workflow.

* Added system information page to admin panel with version info. Updated github workflows to include tagged version, and corresponding changes in the dockerfiles and codebases for web&backend to use env variables if present. Changed to 'dev' naming scheme if no env var is present to indicate local setup. Removed version from admin panel header.

* Added missing systeminfo dir to remote repo.
2023-11-03 18:02:39 -07:00
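In practice the versioning scheme described above reduces to reading a build-time environment variable with a dev fallback, as the `backend/danswer/__init__.py` added later in this diff shows:

```python
import os

# DANSWER_VERSION is passed as a Docker build arg by the tagged-release
# workflows; local setups fall back to the "-dev" label.
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"
```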
Weves
d9581ce0ae Fix Notion recursive search for non-shared database 2023-11-03 15:46:23 -07:00
Yuhong Sun
e27800d501 Formatting 2023-11-02 23:31:19 -07:00
Yuhong Sun
927dffecb5 Prompt Layer Rework (#688) 2023-11-02 23:26:47 -07:00
Weves
68b23b6339 Enable database reading in recursive notion crawl 2023-11-02 23:14:54 -07:00
Weves
174f54473e Fix notion recursive search for blocks with children 2023-11-02 22:21:55 -07:00
Weves
329824ab22 Address issue with links for Google Sites connector 2023-11-02 22:01:08 -07:00
Yuhong Sun
b0f76b97ef Guru and Productboard Time Updated (#683) 2023-11-02 14:27:06 -07:00
414 changed files with 25977 additions and 10293 deletions

View File

@@ -0,0 +1,15 @@
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
- type: textarea
id: description
attributes:
label: Details
description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
placeholder: |
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
Bugs: The bug might be in <FILE>. Here are the logs: ...
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
Refactors: We are migrating this function to ... version because ...

View File

@@ -1,4 +1,4 @@
-name: Build and Push Backend Images on Tagging
+name: Build and Push Backend Image on Tag
on:
push:
@@ -32,3 +32,11 @@ jobs:
tags: |
danswer/danswer-backend:${{ github.ref_name }}
danswer/danswer-backend:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/danswer/danswer-backend:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -0,0 +1,42 @@
name: Build and Push Model Server Image on Tag
on:
push:
tags:
- '*'
jobs:
build-and-push:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-model-server:${{ github.ref_name }}
danswer/danswer-model-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -1,4 +1,4 @@
-name: Build and Push Web Images on Tagging
+name: Build and Push Web Image on Tag
on:
push:
@@ -32,3 +32,11 @@ jobs:
tags: |
danswer/danswer-web-server:${{ github.ref_name }}
danswer/danswer-web-server:latest
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
with:
image-ref: docker.io/danswer/danswer-web-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

3
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.env
.DS_store
.venv
.venv
.mypy_cache

View File

@@ -1,3 +1,5 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/CONTRIBUTING.md"} -->
# Contributing to Danswer
Hey there! We are so excited that you're interested in Danswer.
@@ -86,7 +88,12 @@ Once the above is done, navigate to `danswer/web` run:
npm i
```
-Install Playwright (required by the Web Connector), with the python venv active, run:
+Install Playwright (required by the Web Connector)
+> Note: If you have just done the pip install, open a new terminal and source the python virtual-env again.
+This will update the path to include playwright
+Then install Playwright by running:
```bash
playwright install
```
@@ -113,7 +120,7 @@ npm run dev
Package the Vespa schema. This will only need to be done when the Vespa schema is updated locally.
-Nagivate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
+Navigate to `danswer/backend/danswer/document_index/vespa/app_config` and run:
```bash
zip -r ../vespa-app.zip .
```

View File

@@ -1,3 +1,5 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
</h2>
@@ -9,7 +11,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1u5ycen3o-6SJbWfivLWP5LPyp_jftuw" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1u3h3ke3b-VGh1idW19R8oiNRiKBYv2w" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -27,7 +29,11 @@
Danswer provides a fully-featured web UI:
-https://github.com/danswer-ai/danswer/assets/25087905/619607a1-4ad2-41a0-9728-351752acc26e
+https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
Or, if you prefer, you can plug Danswer into your existing Slack workflows (more integrations to come 😁):
@@ -45,37 +51,43 @@ Danswer can easily be tested locally or deployed on a virtual machine with a sin
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
## 💃 Features
-* Direct QA powered by Generative AI models with answers backed by quotes and source links.
-* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs.
-* An AI Helper backed by a custom Deep Learning model to interpret user intent.
+* Direct QA + Chat powered by Generative AI models with answers backed by quotes and source links.
+* Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models.
+* Automatic time/source filter extraction from natural language + custom model to identify user intent.
* User authentication and document level access management.
-* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.)
-* Management Dashboard to manage connectors and set up features such as live update fetching.
+* Support for LLMs of your choice (GPT-4, Mixstral, Llama2, etc.)
+* Management Dashboards to manage connectors and set up features such as live update fetching.
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.
## 🔌 Connectors
-Danswer currently syncs documents (every 10 minutes) from:
+Efficiently pulls the latest changes from:
* Slack
* GitHub
* Google Drive
* Confluence
* Jira
* Zendesk
* Notion
* Gong
* Slab
* Linear
* Productboard
* Guru
* Zulip
* Bookstack
* Document360
* Request Tracker
* Hubspot
* Local Files
* Websites
* With more to come...
## 🚧 Roadmap
* Chat/Conversation support.
* Organizational understanding.
-* Ability to locate and suggest experts.
+* Ability to locate and suggest experts from your team.
* Code Search
* Structured Query Languages (SQL, Excel formulas, etc.)
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

17
backend/.dockerignore Normal file
View File

@@ -0,0 +1,17 @@
**/__pycache__
venv/
env/
*.egg-info
.cache
.git/
.svn/
.vscode/
.idea/
*.log
log/
.env
secrets.yaml
build/
dist/
.coverage
htmlcov/

1
backend/.gitignore vendored
View File

@@ -1,4 +1,5 @@
__pycache__/
.mypy_cache
.idea/
site_crawls/
.ipynb_checkpoints/

View File

@@ -1,10 +1,18 @@
-FROM python:3.11.4-slim-bookworm
+FROM python:3.11.7-slim-bookworm
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
# curl included just for users' convenience
# zip for Vespa step futher down
# ca-certificates for HTTPS
RUN apt-get update && \
-apt-get install -y git cmake pkg-config libprotobuf-c-dev protobuf-compiler \
-libprotobuf-dev libgoogle-perftools-dev libpq-dev build-essential cron curl \
-supervisor zip ca-certificates gnupg && \
+apt-get install -y cmake curl zip ca-certificates && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
@@ -13,27 +21,15 @@ RUN apt-get update && \
COPY ./requirements/default.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium
# install nodejs and replace nodejs packaged with playwright (18.17.0) with the one installed below
# based on the instructions found here:
# https://nodejs.org/en/download/package-manager#debian-and-ubuntu-based-linux-distributions
# this is temporarily needed until playwright updates their packaged node version to
# 20.5.1+
RUN mkdir -p /etc/apt/keyrings && \
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" | tee /etc/apt/sources.list.d/nodesource.list && \
apt-get update && \
apt-get install -y nodejs && \
cp /usr/bin/node /usr/local/lib/python3.11/site-packages/playwright/driver/node && \
apt-get remove -y nodejs
playwright install chromium && playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
# Remove tornado test key to placate vulnerability scanners
# More details can be found here:
# https://github.com/tornadoweb/tornado/issues/3107
RUN apt-get remove -y linux-libc-dev && \
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cmake libldap-2.5-0 libldap-2.5-0 && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm /usr/local/lib/python3.11/site-packages/tornado/test/test.key
@@ -41,18 +37,16 @@ RUN apt-get remove -y linux-libc-dev && \
# Set up application files
WORKDIR /app
COPY ./danswer /app/danswer
COPY ./shared_models /app/shared_models
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
-COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
+COPY supervisord.conf /usr/etc/supervisord.conf
# Create Vespa app zip
WORKDIR /app/danswer/document_index/vespa/app_config
RUN zip -r /app/danswer/vespa-app.zip .
WORKDIR /app
# TODO: remove this once all users have migrated
COPY ./scripts/migrate_vespa_to_acl.py /app/migrate_vespa_to_acl.py
ENV PYTHONPATH /app
# Default command which does nothing

View File

@@ -0,0 +1,39 @@
FROM python:3.11.7-slim-bookworm
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
WORKDIR /app
# Needed for model configs and defaults
COPY ./danswer/configs /app/danswer/configs
COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs
# Utils used by model server
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
# Place to fetch version information
COPY ./danswer/__init__.py /app/danswer/__init__.py
# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
# Request/Response models
COPY ./shared_models /app/shared_models
# Model Server main code
COPY ./model_server /app/model_server
ENV PYTHONPATH /app
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -1,4 +1,8 @@
-Generic single-database configuration with an async dbapi.
+<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/alembic/README.md"} -->
+# Alembic DB Migrations
+These files are for creating/updating the tables in the Relational DB (Postgres).
+Danswer migrations use a generic single-database configuration with an async dbapi.
## To generate new migrations:
run from danswer/backend:
@@ -7,7 +11,6 @@ run from danswer/backend:
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html
## Running migrations
To run all un-applied migrations:
`alembic upgrade head`

View File

@@ -0,0 +1,37 @@
"""Introduce Danswer APIs
Revision ID: 15326fcec57e
Revises: 77d07dffae64
Create Date: 2023-11-11 20:51:24.228999
"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "15326fcec57e"
down_revision = "77d07dffae64"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("credential", "is_admin", new_column_name="admin_public")
op.add_column(
"document",
sa.Column("from_ingestion_api", sa.Boolean(), nullable=True),
)
op.alter_column(
"connector",
"source",
type_=sa.String(length=50),
existing_type=sa.Enum(DocumentSource, native_enum=False),
existing_nullable=False,
)
def downgrade() -> None:
op.drop_column("document", "from_ingestion_api")
op.alter_column("credential", "admin_public", new_column_name="is_admin")

View File

@@ -0,0 +1,28 @@
"""Add additional retrieval controls to Persona
Revision ID: 50b683a8295c
Revises: 7da0ae5ad583
Create Date: 2023-11-27 17:23:29.668422
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "50b683a8295c"
down_revision = "7da0ae5ad583"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("persona", sa.Column("num_chunks", sa.Integer(), nullable=True))
op.add_column(
"persona",
sa.Column("apply_llm_relevance_filter", sa.Boolean(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "apply_llm_relevance_filter")
op.drop_column("persona", "num_chunks")

View File

@@ -0,0 +1,32 @@
"""CC-Pair Name not Unique
Revision ID: 76b60d407dfb
Revises: b156fa702355
Create Date: 2023-12-22 21:42:10.018804
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "76b60d407dfb"
down_revision = "b156fa702355"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute("DELETE FROM connector_credential_pair WHERE name IS NULL")
op.drop_constraint(
"connector_credential_pair__name__key",
"connector_credential_pair",
type_="unique",
)
op.alter_column(
"connector_credential_pair", "name", existing_type=sa.String(), nullable=False
)
def downgrade() -> None:
# This wasn't really required by the code either, no good reason to make it unique again
pass

View File

@@ -0,0 +1,23 @@
"""Add description to persona
Revision ID: 7da0ae5ad583
Revises: e86866a9c78a
Create Date: 2023-11-27 00:16:19.959414
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7da0ae5ad583"
down_revision = "e86866a9c78a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("persona", sa.Column("description", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("persona", "description")

View File

@@ -0,0 +1,36 @@
"""Add chat session to query_event
Revision ID: 80696cf850ae
Revises: 15326fcec57e
Create Date: 2023-11-26 02:38:35.008070
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "80696cf850ae"
down_revision = "15326fcec57e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"query_event",
sa.Column("chat_session_id", sa.Integer(), nullable=True),
)
op.create_foreign_key(
"fk_query_event_chat_session_id",
"query_event",
"chat_session",
["chat_session_id"],
["id"],
)
def downgrade() -> None:
op.drop_constraint(
"fk_query_event_chat_session_id", "query_event", type_="foreignkey"
)
op.drop_column("query_event", "chat_session_id")

View File

@@ -0,0 +1,34 @@
"""Add is_visible to Persona
Revision ID: 891cd83c87a8
Revises: 76b60d407dfb
Create Date: 2023-12-21 11:55:54.132279
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "891cd83c87a8"
down_revision = "76b60d407dfb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("is_visible", sa.Boolean(), nullable=True),
)
op.execute("UPDATE persona SET is_visible = true")
op.alter_column("persona", "is_visible", nullable=False)
op.add_column(
"persona",
sa.Column("display_priority", sa.Integer(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "is_visible")
op.drop_column("persona", "display_priority")

View File

@@ -0,0 +1,61 @@
"""Tags
Revision ID: 904e5138fffb
Revises: 891cd83c87a8
Create Date: 2024-01-01 10:44:43.733974
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "904e5138fffb"
down_revision = "891cd83c87a8"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"tag",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("tag_key", sa.String(), nullable=False),
sa.Column("tag_value", sa.String(), nullable=False),
sa.Column("source", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint(
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
),
)
op.create_table(
"document__tag",
sa.Column("document_id", sa.String(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
),
sa.ForeignKeyConstraint(
["tag_id"],
["tag.id"],
),
sa.PrimaryKeyConstraint("document_id", "tag_id"),
)
op.add_column(
"search_doc",
sa.Column(
"doc_metadata",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.execute("UPDATE search_doc SET doc_metadata = '{}' WHERE doc_metadata IS NULL")
op.alter_column("search_doc", "doc_metadata", nullable=False)
def downgrade() -> None:
op.drop_table("document__tag")
op.drop_table("tag")
op.drop_column("search_doc", "doc_metadata")

View File

@@ -0,0 +1,520 @@
"""Chat Reworked
Revision ID: b156fa702355
Revises: baf71f781b9e
Create Date: 2023-12-12 00:57:41.823371
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
revision = "b156fa702355"
down_revision = "baf71f781b9e"
branch_labels = None
depends_on = None
searchtype_enum = ENUM(
"KEYWORD", "SEMANTIC", "HYBRID", name="searchtype", create_type=True
)
recencybiassetting_enum = ENUM(
"FAVOR_RECENT",
"BASE_DECAY",
"NO_DECAY",
"AUTO",
name="recencybiassetting",
create_type=True,
)
def upgrade() -> None:
bind = op.get_bind()
searchtype_enum.create(bind)
recencybiassetting_enum.create(bind)
# This is irrecoverable, whatever
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.create_table(
"search_doc",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.Column("chunk_ind", sa.Integer(), nullable=False),
sa.Column("semantic_id", sa.String(), nullable=False),
sa.Column("link", sa.String(), nullable=True),
sa.Column("blurb", sa.String(), nullable=False),
sa.Column("boost", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.Enum(DocumentSource, native=False),
nullable=False,
),
sa.Column("hidden", sa.Boolean(), nullable=False),
sa.Column("score", sa.Float(), nullable=False),
sa.Column("match_highlights", postgresql.ARRAY(sa.String()), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("primary_owners", postgresql.ARRAY(sa.String()), nullable=True),
sa.Column("secondary_owners", postgresql.ARRAY(sa.String()), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"prompt",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=False),
sa.Column("system_prompt", sa.Text(), nullable=False),
sa.Column("task_prompt", sa.Text(), nullable=False),
sa.Column("include_citations", sa.Boolean(), nullable=False),
sa.Column("datetime_aware", sa.Boolean(), nullable=False),
sa.Column("default_prompt", sa.Boolean(), nullable=False),
sa.Column("deleted", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"persona__prompt",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("prompt_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["prompt_id"],
["prompt.id"],
),
sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
)
# Changes to persona first so chat_sessions can have the right persona
# The empty persona will be overwritten on server startup
op.add_column(
"persona",
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
)
op.add_column(
"persona",
sa.Column(
"search_type",
searchtype_enum,
nullable=True,
),
)
op.execute("UPDATE persona SET search_type = 'HYBRID'")
op.alter_column("persona", "search_type", nullable=False)
op.add_column(
"persona",
sa.Column("llm_relevance_filter", sa.Boolean(), nullable=True),
)
op.execute("UPDATE persona SET llm_relevance_filter = TRUE")
op.alter_column("persona", "llm_relevance_filter", nullable=False)
op.add_column(
"persona",
sa.Column("llm_filter_extraction", sa.Boolean(), nullable=True),
)
op.execute("UPDATE persona SET llm_filter_extraction = TRUE")
op.alter_column("persona", "llm_filter_extraction", nullable=False)
op.add_column(
"persona",
sa.Column(
"recency_bias",
recencybiassetting_enum,
nullable=True,
),
)
op.execute("UPDATE persona SET recency_bias = 'BASE_DECAY'")
op.alter_column("persona", "recency_bias", nullable=False)
op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
op.execute("UPDATE persona SET description = ''")
op.alter_column("persona", "description", nullable=False)
op.create_foreign_key("persona__user_fk", "persona", "user", ["user_id"], ["id"])
op.drop_column("persona", "datetime_aware")
op.drop_column("persona", "tools")
op.drop_column("persona", "hint_text")
op.drop_column("persona", "apply_llm_relevance_filter")
op.drop_column("persona", "retrieval_enabled")
op.drop_column("persona", "system_text")
# Need to create a persona row so fk can work
result = bind.execute(sa.text("SELECT 1 FROM persona WHERE id = 0"))
exists = result.fetchone()
if not exists:
op.execute(
sa.text(
"""
INSERT INTO persona (
id, user_id, name, description, search_type, num_chunks,
llm_relevance_filter, llm_filter_extraction, recency_bias,
llm_model_version_override, default_persona, deleted
) VALUES (
0, NULL, '', '', 'HYBRID', NULL,
TRUE, TRUE, 'BASE_DECAY', NULL, TRUE, FALSE
)
"""
)
)
delete_statement = sa.text(
"""
DELETE FROM persona
WHERE name = 'Danswer' AND default_persona = TRUE AND id != 0
"""
)
bind.execute(delete_statement)
op.add_column(
"chat_feedback",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
)
op.drop_constraint(
"chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
"chat_feedback",
type_="foreignkey",
)
op.drop_column("chat_feedback", "chat_message_edit_number")
op.drop_column("chat_feedback", "chat_message_chat_session_id")
op.drop_column("chat_feedback", "chat_message_message_number")
op.add_column(
"chat_message",
sa.Column(
"id",
sa.Integer(),
primary_key=True,
autoincrement=True,
nullable=False,
unique=True,
),
)
op.add_column(
"chat_message",
sa.Column("parent_message", sa.Integer(), nullable=True),
)
op.add_column(
"chat_message",
sa.Column("latest_child_message", sa.Integer(), nullable=True),
)
op.add_column(
"chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
)
op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
op.add_column(
"chat_message",
sa.Column("citations", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
)
op.add_column("chat_message", sa.Column("error", sa.Text(), nullable=True))
op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
op.create_foreign_key(
"chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
)
op.drop_column("chat_message", "parent_edit_number")
op.drop_column("chat_message", "persona_id")
op.drop_column("chat_message", "reference_docs")
op.drop_column("chat_message", "edit_number")
op.drop_column("chat_message", "latest")
op.drop_column("chat_message", "message_number")
op.add_column("chat_session", sa.Column("one_shot", sa.Boolean(), nullable=True))
op.execute("UPDATE chat_session SET one_shot = TRUE")
op.alter_column("chat_session", "one_shot", nullable=False)
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=True,
)
op.execute("UPDATE chat_session SET persona_id = 0")
op.alter_column("chat_session", "persona_id", nullable=False)
op.add_column(
"document_retrieval_feedback",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
)
op.drop_constraint(
"document_retrieval_feedback_qa_event_id_fkey",
"document_retrieval_feedback",
type_="foreignkey",
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],
)
op.drop_column("document_retrieval_feedback", "qa_event_id")
# Relation table must be created after the other tables are correct
op.create_table(
"chat_message__search_doc",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.Column("search_doc_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_message.id"],
),
sa.ForeignKeyConstraint(
["search_doc_id"],
["search_doc.id"],
),
sa.PrimaryKeyConstraint("chat_message_id", "search_doc_id"),
)
# Needs to be created after chat_message id field is added
op.create_foreign_key(
"chat_feedback__chat_message_fk",
"chat_feedback",
"chat_message",
["chat_message_id"],
["id"],
)
op.drop_table("query_event")
def downgrade() -> None:
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)
op.drop_constraint(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval_feedback",
type_="foreignkey",
)
op.drop_constraint("persona__user_fk", "persona", type_="foreignkey")
op.drop_constraint("chat_message__prompt_fk", "chat_message", type_="foreignkey")
op.drop_constraint(
"chat_message__search_doc_chat_message_id_fkey",
"chat_message__search_doc",
type_="foreignkey",
)
op.add_column(
"persona",
sa.Column("system_text", sa.TEXT(), autoincrement=False, nullable=True),
)
op.add_column(
"persona",
sa.Column(
"retrieval_enabled",
sa.BOOLEAN(),
autoincrement=False,
nullable=True,
),
)
op.execute("UPDATE persona SET retrieval_enabled = TRUE")
op.alter_column("persona", "retrieval_enabled", nullable=False)
op.add_column(
"persona",
sa.Column(
"apply_llm_relevance_filter",
sa.BOOLEAN(),
autoincrement=False,
nullable=True,
),
)
op.add_column(
"persona",
sa.Column("hint_text", sa.TEXT(), autoincrement=False, nullable=True),
)
op.add_column(
"persona",
sa.Column(
"tools",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=True,
),
)
op.add_column(
"persona",
sa.Column("datetime_aware", sa.BOOLEAN(), autoincrement=False, nullable=True),
)
op.execute("UPDATE persona SET datetime_aware = TRUE")
op.alter_column("persona", "datetime_aware", nullable=False)
op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
op.drop_column("persona", "recency_bias")
op.drop_column("persona", "llm_filter_extraction")
op.drop_column("persona", "llm_relevance_filter")
op.drop_column("persona", "search_type")
op.drop_column("persona", "user_id")
op.add_column(
"document_retrieval_feedback",
sa.Column("qa_event_id", sa.INTEGER(), autoincrement=False, nullable=False),
)
op.drop_column("document_retrieval_feedback", "chat_message_id")
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
op.drop_column("chat_session", "one_shot")
op.add_column(
"chat_message",
sa.Column(
"message_number",
sa.INTEGER(),
autoincrement=False,
nullable=False,
primary_key=True,
),
)
op.add_column(
"chat_message",
sa.Column("latest", sa.BOOLEAN(), autoincrement=False, nullable=False),
)
op.add_column(
"chat_message",
sa.Column(
"edit_number",
sa.INTEGER(),
autoincrement=False,
nullable=False,
primary_key=True,
),
)
op.add_column(
"chat_message",
sa.Column(
"reference_docs",
postgresql.JSONB(astext_type=sa.Text()),
autoincrement=False,
nullable=True,
),
)
op.add_column(
"chat_message",
sa.Column("persona_id", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.add_column(
"chat_message",
sa.Column(
"parent_edit_number",
sa.INTEGER(),
autoincrement=False,
nullable=True,
),
)
op.create_foreign_key(
"fk_chat_message_persona_id",
"chat_message",
"persona",
["persona_id"],
["id"],
)
op.drop_column("chat_message", "error")
op.drop_column("chat_message", "citations")
op.drop_column("chat_message", "prompt_id")
op.drop_column("chat_message", "rephrased_query")
op.drop_column("chat_message", "latest_child_message")
op.drop_column("chat_message", "parent_message")
op.drop_column("chat_message", "id")
op.add_column(
"chat_feedback",
sa.Column(
"chat_message_message_number",
sa.INTEGER(),
autoincrement=False,
nullable=False,
),
)
op.add_column(
"chat_feedback",
sa.Column(
"chat_message_chat_session_id",
sa.INTEGER(),
autoincrement=False,
nullable=False,
primary_key=True,
),
)
op.add_column(
"chat_feedback",
sa.Column(
"chat_message_edit_number",
sa.INTEGER(),
autoincrement=False,
nullable=False,
),
)
op.drop_column("chat_feedback", "chat_message_id")
op.create_table(
"query_event",
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("query", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column(
"selected_search_flow",
sa.VARCHAR(),
autoincrement=False,
nullable=True,
),
sa.Column("llm_answer", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("feedback", sa.VARCHAR(), autoincrement=False, nullable=True),
sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True),
sa.Column(
"time_created",
postgresql.TIMESTAMP(timezone=True),
server_default=sa.text("now()"),
autoincrement=False,
nullable=False,
),
sa.Column(
"retrieved_document_ids",
postgresql.ARRAY(sa.VARCHAR()),
autoincrement=False,
nullable=True,
),
sa.Column("chat_session_id", sa.INTEGER(), autoincrement=False, nullable=True),
sa.ForeignKeyConstraint(
["chat_session_id"],
["chat_session.id"],
name="fk_query_event_chat_session_id",
),
sa.ForeignKeyConstraint(
["user_id"], ["user.id"], name="query_event_user_id_fkey"
),
sa.PrimaryKeyConstraint("id", name="query_event_pkey"),
)
op.drop_table("chat_message__search_doc")
op.drop_table("persona__prompt")
op.drop_table("prompt")
op.drop_table("search_doc")
op.create_unique_constraint(
"uq_chat_message_combination",
"chat_message",
["chat_session_id", "message_number", "edit_number"],
)
op.create_foreign_key(
"chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
"chat_feedback",
"chat_message",
[
"chat_message_chat_session_id",
"chat_message_message_number",
"chat_message_edit_number",
],
["chat_session_id", "message_number", "edit_number"],
)
op.create_foreign_key(
"document_retrieval_feedback_qa_event_id_fkey",
"document_retrieval_feedback",
"query_event",
["qa_event_id"],
["id"],
)
op.execute("DROP TYPE IF EXISTS searchtype")
op.execute("DROP TYPE IF EXISTS recencybiassetting")
op.execute("DROP TYPE IF EXISTS documentsource")

View File

@@ -0,0 +1,26 @@
"""Add llm_model_version_override to Persona
Revision ID: baf71f781b9e
Revises: 50b683a8295c
Create Date: 2023-12-06 21:56:50.286158
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "baf71f781b9e"
down_revision = "50b683a8295c"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("llm_model_version_override", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "llm_model_version_override")

View File

@@ -0,0 +1,27 @@
"""Add persona to chat_session
Revision ID: e86866a9c78a
Revises: 80696cf850ae
Create Date: 2023-11-26 02:51:47.657357
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e86866a9c78a"
down_revision = "80696cf850ae"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("chat_session", sa.Column("persona_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_chat_session_persona_id", "chat_session", "persona", ["persona_id"], ["id"]
)
def downgrade() -> None:
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
op.drop_column("chat_session", "persona_id")

View File

@@ -0,0 +1,3 @@
import os
__version__ = os.environ.get("DANSWER_VERSION", "") or "0.3-dev"

View File

@@ -4,7 +4,7 @@ from danswer.access.models import DocumentAccess
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
-from danswer.server.models import ConnectorCredentialPairIdentifier
+from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.variable_functionality import fetch_versioned_implementation

View File

@@ -48,6 +48,8 @@ from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
@@ -66,6 +68,12 @@ def verify_auth_setting() -> None:
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def get_user_whitelist() -> list[str]:
global _user_whitelist
if _user_whitelist is None:
@@ -102,10 +110,9 @@ def verify_email_domain(email: str) -> None:
def send_user_verification_email(user_email: str, token: str) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Danswer Email Verification"
msg["From"] = "no-reply@danswer.dev"
msg["To"] = user_email
link = f"{WEB_DOMAIN}/verify-email?token={token}"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = MIMEText(f"Click the following link to verify your email address: {link}")
msg.attach(body)
@@ -170,6 +177,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self, user: User, request: Optional[Request] = None
) -> None:
logger.info(f"User {user.id} has registered.")
optional_telemetry(record_type=RecordType.SIGN_UP, data={"user": "create"})
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
@@ -253,9 +261,11 @@ fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
)
optional_valid_user = fastapi_users.current_user(
active=True, verified=REQUIRE_EMAIL_VERIFICATION, optional=True
)
# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
# take care of that in `double_check_user` ourself. This is needed, since
# we want the /me endpoint to still return a user even if they are not
# yet verified, so that the frontend knows they exist
optional_valid_user = fastapi_users.current_user(active=True, optional=True)
async def double_check_user(
@@ -273,6 +283,12 @@ async def double_check_user(
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not verified.",
)
return user
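The comments above spell out the split: `optional_valid_user` no longer enforces verification so `/me` can still return an unverified user, while `double_check_user` raises a 403 whenever `user_needs_to_be_verified()` applies. A hedged sketch of protecting a route with the stricter dependency, assuming these helpers live at `danswer.auth.users`:

```python
from fastapi import Depends, FastAPI

from danswer.auth.users import double_check_user  # assumed import path
from danswer.db.models import User

app = FastAPI()

# Unverified users hit the HTTP 403 raised inside double_check_user;
# verified (or basic-auth) users get through.
@app.get("/protected")
def protected_route(user: User = Depends(double_check_user)) -> dict[str, str]:
    return {"email": user.email}
```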

View File

@@ -36,8 +36,9 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
celery_broker_url = "sqla+" + build_connection_string(db_api=SYNC_DB_API)
celery_backend_url = "db+" + build_connection_string(db_api=SYNC_DB_API)
connection_string = build_connection_string(db_api=SYNC_DB_API)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@@ -208,8 +209,10 @@ def clean_old_temp_files_task(
Currently handled async of the indexing job"""
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
if file_age_in_hours(file) > age_threshold_in_hours:
os.remove(Path(base_path) / file)
full_file_path = Path(base_path) / file
if file_age_in_hours(full_file_path) > age_threshold_in_hours:
logger.info(f"Cleaning up uploaded file: {full_file_path}")
os.remove(full_file_path)
#####
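The cleanup fix above passes the full path into the age check rather than the bare filename. A minimal sketch of why that matters, assuming `file_age_in_hours` is a simple mtime-based helper; the directory and 24-hour threshold are illustrative:

```python
import os
import time
from pathlib import Path

# Assumed stand-in for danswer's file_age_in_hours helper.
def file_age_in_hours(file_path: str | Path) -> float:
    return (time.time() - os.path.getmtime(file_path)) / (60 * 60)

base_path = "/tmp/danswer_uploads"  # illustrative directory
os.makedirs(base_path, exist_ok=True)
for file in os.listdir(base_path):
    full_file_path = Path(base_path) / file
    # Passing only `file` would resolve relative to the current working
    # directory and typically raise FileNotFoundError; the joined path is
    # what the mtime lookup actually needs.
    if file_age_in_hours(full_file_path) > 24:
        os.remove(full_file_path)
```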

View File

@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.db.tasks import get_latest_task
from danswer.server.models import DeletionAttemptSnapshot
from danswer.server.documents.models import DeletionAttemptSnapshot
def get_deletion_status(

View File

@@ -11,8 +11,6 @@ connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
import time
from collections.abc import Callable
from typing import cast
from sqlalchemy.orm import Session
@@ -35,9 +33,8 @@ from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@@ -173,14 +170,8 @@ def delete_connector_credential_pair(
# Clean up document sets / access information from Postgres
# and sync these updates to Vespa
cleanup_synced_entities__versioned = cast(
Callable[[ConnectorCredentialPair, Session], None],
fetch_versioned_implementation(
"danswer.background.connector_deletion",
"cleanup_synced_entities",
),
)
cleanup_synced_entities__versioned(cc_pair, db_session)
# TODO: add user group cleanup with `fetch_versioned_implementation`
cleanup_synced_entities(cc_pair, db_session)
# clean up the rest of the related Postgres entities
delete_index_attempts(

View File

@@ -0,0 +1,75 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime
from danswer.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
def _2010_dt() -> datetime.datetime:
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
def _2020_dt() -> datetime.datetime:
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
def _default_end_time(
last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
"""If year is before 2010, go to the beginning of 2010.
If year is 2010-2020, go in 5 year increments.
If year > 2020, then go in 180 day increments.
For connectors that don't support a `filter_by` and instead rely on `sort_by`
for polling, this will cause a massive duplication of fetches. For these
connectors, you may want to override this function to return a more reasonable
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
last_successful_run = (
datetime_to_utc(last_successful_run) if last_successful_run else None
)
if last_successful_run is None or last_successful_run < _2010_dt():
return _2010_dt()
if last_successful_run < _2020_dt():
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
return last_successful_run + datetime.timedelta(days=180)
def find_end_time_for_indexing_attempt(
last_successful_run: datetime.datetime | None, source_type: DocumentSource
) -> datetime.datetime | None:
# NOTE: source_type can be used to override the default for certain connectors
end_of_window = _default_end_time(last_successful_run)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if end_of_window < now:
return end_of_window
# None signals that we should index up to current time
return None
def get_time_windows_for_index_attempt(
last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
start_of_window: datetime.datetime | None = last_successful_run
while start_of_window:
end_of_window = find_end_time_for_indexing_attempt(
last_successful_run=start_of_window, source_type=source_type
)
time_windows.append(
(
start_of_window,
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
)
)
start_of_window = end_of_window
return time_windows
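To make the windowing rules in the docstring concrete, here is a standalone sketch that re-states them (independent of the `EXPERIMENTAL_CHECKPOINTING_ENABLED` flag) and prints the windows a mid-2018 last run would produce:

```python
import datetime

def _end_of_window(start: datetime.datetime) -> datetime.datetime:
    """Re-statement of the documented rules: pre-2010 jumps to 2010,
    2010-2020 moves in 5-year steps (capped at 2020), post-2020 moves
    in 180-day steps."""
    dt_2010 = datetime.datetime(2010, 1, 1, tzinfo=datetime.timezone.utc)
    dt_2020 = datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc)
    if start < dt_2010:
        return dt_2010
    if start < dt_2020:
        return min(start + datetime.timedelta(days=365 * 5), dt_2020)
    return start + datetime.timedelta(days=180)

now = datetime.datetime.now(tz=datetime.timezone.utc)
start = datetime.datetime(2018, 6, 1, tzinfo=datetime.timezone.utc)
windows: list[tuple[datetime.datetime, datetime.datetime]] = []
while start < now:
    end = min(_end_of_window(start), now)
    windows.append((start, end))
    start = end

# Roughly: 2018-06 -> 2020-01-01, then successive 180-day windows up to now.
print(windows)
```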

View File

@@ -0,0 +1,33 @@
import asyncio
import psutil
from dask.distributed import WorkerPlugin
from distributed import Worker
from danswer.utils.logger import setup_logger
logger = setup_logger()
class ResourceLogger(WorkerPlugin):
def __init__(self, log_interval: int = 60 * 5):
self.log_interval = log_interval
def setup(self, worker: Worker) -> None:
"""This method will be called when the plugin is attached to a worker."""
self.worker = worker
worker.loop.add_callback(self.log_resources)
async def log_resources(self) -> None:
"""Periodically log CPU and memory usage.
NOTE: must be async or else will clog up the worker indefinitely due to the fact that
Dask uses Tornado under the hood (which is async)"""
while True:
cpu_percent = psutil.cpu_percent(interval=None)
memory_available_gb = psutil.virtual_memory().available / (1024.0**3)
# You can now log these values or send them to a monitoring service
logger.debug(
f"Worker {self.worker.address}: CPU usage {cpu_percent}%, Memory available {memory_available_gb}GB"
)
await asyncio.sleep(self.log_interval)
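A short usage sketch: the plugin only does something once it is registered on a Dask client, which is the same wiring `update.py` does below when LOG_LEVEL is debug. The cluster sizing and interval here are illustrative:

```python
from dask.distributed import Client, LocalCluster

from danswer.background.indexing.dask_utils import ResourceLogger

if __name__ == "__main__":
    cluster = LocalCluster(n_workers=1, threads_per_worker=1)
    client = Client(cluster)
    # Every worker now logs CPU / available-memory stats once a minute.
    client.register_worker_plugin(ResourceLogger(log_interval=60))
```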

View File

@@ -4,12 +4,13 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from typing import Literal
from torch import multiprocessing
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -94,7 +95,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = multiprocessing.Process(target=func, args=args)
process = multiprocessing.Process(target=func, args=args, daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()
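The one functional change here is `daemon=True`. A minimal sketch of what that buys, using the same `torch.multiprocessing` import the file now relies on; the worker function is a placeholder:

```python
from torch import multiprocessing

def _fake_indexing_job() -> None:
    # Placeholder for the real indexing entrypoint.
    ...

if __name__ == "__main__":
    # daemon=True means the child is killed automatically if the parent
    # background process exits, instead of lingering as an orphan.
    process = multiprocessing.Process(target=_fake_indexing_job, daemon=True)
    process.start()
    process.join()
```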

View File

@@ -0,0 +1,260 @@
import time
from datetime import datetime
from datetime import timezone
import torch
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_document_generator(
db_session: Session,
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
) -> GenerateDocumentsOutput:
"""NOTE: `start_time` and `end_time` are only used for poll connectors"""
task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
attempt.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
raise e
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
logger.info(f"Polling for updates between {start_time} and {end_time}")
doc_batch_generator = runnable_connector.poll_source(
start=start_time.timestamp(), end=end_time.timestamp()
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
# mark as started
mark_attempt_in_progress(index_attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
indexing_pipeline = build_indexing_pipeline()
db_connector = index_attempt.connector
db_credential = index_attempt.credential
last_successful_index_time = get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
db_session=db_session,
)
net_doc_change = 0
document_count = 0
chunk_count = 0
run_end_dt = None
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
),
source_type=db_connector.source,
)
):
doc_batch_generator = _get_document_generator(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
)
try:
for doc_batch in doc_batch_generator:
# check if connector is disabled mid run and stop if so
db_session.refresh(db_connector)
if db_connector.disabled:
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
)
run_end_dt = window_end
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.info(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if ind == 0 or db_connector.disabled:
mark_attempt_failed(index_attempt, db_session, failure_reason=str(e))
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
)
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
mark_attempt_succeeded(index_attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
logger.info(
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
)
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")

View File

@@ -1,7 +1,6 @@
import logging
import time
from datetime import datetime
from datetime import timezone
import dask
import torch
@@ -10,23 +9,19 @@ from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import create_index_attempt
@@ -35,15 +30,10 @@ from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.search.search_nlp_models import warm_up_models
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -57,6 +47,9 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
)
"""Util funcs"""
def _get_num_threads() -> int:
"""Get # of "threads" to use for ML models in an indexing job. By default uses
the torch implementation, which returns the # of physical cores on the machine.
@@ -64,19 +57,34 @@ def _get_num_threads() -> int:
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
def should_create_new_indexing(
def _should_create_new_indexing(
connector: Connector, last_index: IndexAttempt | None, db_session: Session
) -> bool:
if connector.refresh_freq is None:
return False
if not last_index:
return True
# only one scheduled job per connector at a time
if last_index.status == IndexingStatus.NOT_STARTED:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
return time_since_index.total_seconds() >= connector.refresh_freq
def mark_run_failed(
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
if index_attempt is None:
return False
return (
index_attempt.status == IndexingStatus.FAILED
or index_attempt.status == IndexingStatus.SUCCESS
)
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
"""Marks the `index_attempt` row as failed + updates the `
@@ -102,342 +110,141 @@ def mark_run_failed(
)
def create_indexing_jobs(
db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
) -> None:
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
2. `refresh_frequency` time has passed since the last indexing run for this pair
3. There is not already an ongoing indexing attempt for this pair
"""
ongoing_pairs: set[tuple[int | None, int | None]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(db_session=db_session, index_attempt_id=attempt_id)
if attempt is None:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
"indexing jobs"
with Session(get_sqlalchemy_engine()) as db_session:
ongoing_pairs: set[tuple[int | None, int | None]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
continue
ongoing_pairs.add((attempt.connector_id, attempt.credential_id))
enabled_connectors = fetch_connectors(db_session, disabled_status=False)
for connector in enabled_connectors:
for association in connector.credentials:
credential = association.credential
# check if there is an ogoing indexing attempt for this connector + credential pair
if (connector.id, credential.id) in ongoing_pairs:
if attempt is None:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when creating "
"indexing jobs"
)
continue
ongoing_pairs.add((attempt.connector_id, attempt.credential_id))
last_attempt = get_last_attempt(connector.id, credential.id, db_session)
if not should_create_new_indexing(connector, last_attempt, db_session):
continue
create_index_attempt(connector.id, credential.id, db_session)
enabled_connectors = fetch_connectors(db_session, disabled_status=False)
for connector in enabled_connectors:
for association in connector.credentials:
credential = association.credential
update_connector_credential_pair(
db_session=db_session,
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
)
# check if there is an ongoing indexing attempt for this connector + credential pair
if (connector.id, credential.id) in ongoing_pairs:
continue
last_attempt = get_last_attempt(connector.id, credential.id, db_session)
if not _should_create_new_indexing(connector, last_attempt, db_session):
continue
create_index_attempt(connector.id, credential.id, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=connector.id,
credential_id=credential.id,
attempt_status=IndexingStatus.NOT_STARTED,
)
def cleanup_indexing_jobs(
db_session: Session, existing_jobs: dict[int, Future | SimpleJob]
existing_jobs: dict[int, Future | SimpleJob],
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
for attempt_id, job in existing_jobs.items():
# do nothing for ongoing jobs
if not job.done():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
del existing_jobs_copy[attempt_id]
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
if not index_attempt:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
"up indexing jobs"
)
continue
if index_attempt.status == IndexingStatus.IN_PROGRESS or job.status == "error":
mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
)
# clean up in-progress jobs that were never completed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# check to see if the job has been updated in the 3 hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60:
existing_jobs[index_attempt.id].cancel()
mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in an hour. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
mark_run_failed(
# do nothing for ongoing jobs that haven't been stopped
if not job.done() and not _is_indexing_job_marked_as_finished(
index_attempt
):
continue
if job.status == "error":
logger.error(job.exception())
job.release()
del existing_jobs_copy[attempt_id]
if not index_attempt:
logger.error(
f"Unable to find IndexAttempt for ID '{attempt_id}' when cleaning "
"up indexing jobs"
)
continue
if (
index_attempt.status == IndexingStatus.IN_PROGRESS
or job.status == "error"
):
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
# clean up in-progress jobs that were never completed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# check to see if the job has been updated in the last n hours; if not,
# assume it is frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in an hour. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
return existing_jobs_copy
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
def _get_document_generator(
db_session: Session, attempt: IndexAttempt
) -> tuple[GenerateDocumentsOutput, float]:
# "official" timestamp for this run
# used for setting time bounds when fetching updates from apps and
# is stored in the DB as the last successful run time if this run succeeds
run_time = time.time()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
run_time_str = run_dt.strftime("%Y-%m-%d %H:%M:%S")
task = attempt.connector.input_type
try:
runnable_connector, new_credential_json = instantiate_connector(
attempt.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential.credential_json,
)
if new_credential_json is not None:
backend_update_credential_json(
attempt.credential, new_credential_json, db_session
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
raise e
if task == InputType.LOAD_STATE:
assert isinstance(runnable_connector, LoadConnector)
doc_batch_generator = runnable_connector.load_from_state()
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
)
last_run_time = get_last_successful_attempt_time(
attempt.connector_id, attempt.credential_id, db_session
)
last_run_time_str = datetime.fromtimestamp(
last_run_time, tz=timezone.utc
).strftime("%Y-%m-%d %H:%M:%S")
logger.info(
f"Polling for updates between {last_run_time_str} and {run_time_str}"
)
doc_batch_generator = runnable_connector.poll_source(
start=last_run_time, end=run_time
)
else:
# Event types cannot be handled by a background type
raise RuntimeError(f"Invalid task type: {task}")
return doc_batch_generator, run_time
doc_batch_generator, run_time = _get_document_generator(db_session, index_attempt)
def _index(
db_session: Session,
attempt: IndexAttempt,
doc_batch_generator: GenerateDocumentsOutput,
run_time: float,
) -> None:
indexing_pipeline = build_indexing_pipeline()
run_dt = datetime.fromtimestamp(run_time, tz=timezone.utc)
db_connector = attempt.connector
db_credential = attempt.credential
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
run_dt=run_dt,
)
net_doc_change = 0
document_count = 0
chunk_count = 0
try:
for doc_batch in doc_batch_generator:
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
index_attempt=attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
)
# check if connector is disabled mid run and stop if so
db_session.refresh(db_connector)
if db_connector.disabled:
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
mark_attempt_succeeded(attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
attempt_status=IndexingStatus.SUCCESS,
net_docs=net_doc_change,
run_dt=run_dt,
)
logger.info(
f"Indexed or updated {document_count} total documents for a total of {chunk_count} chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - run_time} seconds"
)
except Exception as e:
logger.info(
f"Failed connector elapsed time: {time.time() - run_time} seconds"
)
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
# The last attempt won't be marked failed until the next cycle's check for still in-progress attempts
# The connector_credential_pair is marked failed here though to reflect correctly in UI asap
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector.id,
credential_id=attempt.credential.id,
attempt_status=IndexingStatus.FAILED,
net_docs=net_doc_change,
run_dt=run_dt,
)
raise e
_index(db_session, index_attempt, doc_batch_generator, run_time)
def _run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
mark_attempt_in_progress(attempt, db_session)
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector.id,
credential_id=attempt.credential.id,
attempt_status=IndexingStatus.IN_PROGRESS,
)
_run_indexing(
db_session=db_session,
index_attempt=attempt,
)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
def kickoff_indexing_jobs(
db_session: Session,
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine()
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
new_indexing_attempts = [
attempt
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
with Session(engine) as db_session:
new_indexing_attempts = [
attempt
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
@@ -449,19 +256,23 @@ def kickoff_indexing_jobs(
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
mark_attempt_failed(attempt, db_session, failure_reason="Connector is null")
with Session(engine) as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
mark_attempt_failed(
attempt, db_session, failure_reason="Credential is null"
)
with Session(engine) as db_session:
mark_attempt_failed(
attempt, db_session, failure_reason="Credential is null"
)
continue
run = client.submit(
_run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
)
if run:
logger.info(
@@ -476,9 +287,7 @@ def kickoff_indexing_jobs(
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
client: Client | SimpleJobClient
if EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED:
client = SimpleJobClient(n_workers=num_workers)
else:
if DASK_JOB_CLIENT_ENABLED:
cluster = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
@@ -489,6 +298,10 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
silence_logs=logging.ERROR,
)
client = Client(cluster)
if LOG_LEVEL.lower() == "debug":
client.register_worker_plugin(ResourceLogger())
else:
client = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
engine = get_sqlalchemy_engine()
@@ -502,15 +315,20 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.info(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
try:
with Session(engine, expire_on_commit=False) as db_session:
existing_jobs = cleanup_indexing_jobs(
db_session=db_session, existing_jobs=existing_jobs
)
create_indexing_jobs(db_session=db_session, existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
db_session=db_session, existing_jobs=existing_jobs, client=client
)
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
create_indexing_jobs(existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs, client=client
)
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
@@ -518,8 +336,19 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
time.sleep(sleep_time)
if __name__ == "__main__":
logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True)
def update__main() -> None:
# needed for CUDA to work with multiprocessing
# NOTE: needs to be done on application startup
# before any other torch code has been run
if not DASK_JOB_CLIENT_ENABLED:
torch.multiprocessing.set_start_method("spawn")
if not MODEL_SERVER_HOST:
logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True, skip_cross_encoders=True)
logger.info("Starting Indexing Loop")
update_loop()
if __name__ == "__main__":
update__main()
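A hedged sketch of starting the loop manually with the Dask client disabled, so the `SimpleJobClient` path is taken. The env values are illustrative, must be set before the Danswer config modules are imported, and the module path is assumed:

```python
import os

# Falsy value -> SimpleJobClient; any truthy value would pick the Dask LocalCluster path.
os.environ["DASK_JOB_CLIENT_ENABLED"] = ""
os.environ["LOG_LEVEL"] = "info"

from danswer.background.update import update__main  # assumed module path

update__main()
```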

View File

@@ -1,581 +0,0 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from danswer.chat.chat_prompts import build_combined_query
from danswer.chat.chat_prompts import DANSWER_TOOL_NAME
from danswer.chat.chat_prompts import form_require_search_text
from danswer.chat.chat_prompts import form_tool_followup_text
from danswer.chat.chat_prompts import form_tool_less_followup_text
from danswer.chat.chat_prompts import form_tool_section_text
from danswer.chat.chat_prompts import form_user_prompt_text
from danswer.chat.chat_prompts import format_danswer_chunks_for_chat
from danswer.chat.chat_prompts import REQUIRE_DANSWER_SYSTEM_MSG
from danswer.chat.chat_prompts import YES_SEARCH
from danswer.chat.personas import build_system_text_from_persona
from danswer.chat.tools import call_tool
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_CHAT
from danswer.configs.chat_configs import FORCE_TOOL_PROMPT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerChatModelOut
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import IndexFilters
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import search_chunks
from danswer.server.models import RetrievalDocs
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.text_processing import has_unescaped_quote
logger = setup_logger()
LLM_CHAT_FAILURE_MSG = "The large-language-model failed to generate a valid response."
def _parse_embedded_json_streamed_response(
tokens: Iterator[str],
) -> Iterator[DanswerAnswerPiece | DanswerChatModelOut]:
final_answer = False
just_start_stream = False
model_output = ""
hold = ""
finding_end = 0
for token in tokens:
model_output += token
hold += token
if (
final_answer is False
and '"action":"finalanswer",' in model_output.lower().replace(" ", "")
):
final_answer = True
if final_answer and '"actioninput":"' in model_output.lower().replace(
" ", ""
).replace("_", ""):
if not just_start_stream:
just_start_stream = True
hold = ""
if has_unescaped_quote(hold):
finding_end += 1
hold = hold[: hold.find('"')]
if finding_end <= 1:
if finding_end == 1:
finding_end += 1
yield DanswerAnswerPiece(answer_piece=hold)
hold = ""
model_final = extract_embedded_json(model_output)
if "action" not in model_final or "action_input" not in model_final:
raise ValueError("Model did not provide all required action values")
yield DanswerChatModelOut(
model_raw=model_output,
action=model_final["action"],
action_input=model_final["action_input"],
)
return
def _find_last_index(
lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i]
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
raise ValueError("Last message alone is too large!")
return last_ind
def danswer_chat_retrieval(
query_message: ChatMessage,
history: list[ChatMessage],
llm: LLM,
filters: IndexFilters,
) -> list[InferenceChunk]:
if history:
query_combination_msgs = build_combined_query(query_message, history)
reworded_query = llm.invoke(query_combination_msgs)
else:
reworded_query = query_message.message
search_query = SearchQuery(
query=reworded_query,
search_type=SearchType.HYBRID,
filters=filters,
favor_recent=False,
)
# Good Debug/Breakpoint
ranked_chunks, unranked_chunks = search_chunks(
query=search_query, document_index=get_default_document_index()
)
if not ranked_chunks:
return []
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_CHAT,
)
return usable_chunks
def _drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible and the latest user input which is inserted in the
prompt template must be included"""
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")
prompt: list[BaseMessage] = []
# Start dropping from the history if necessary
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
ind_prev_msg_start = _find_last_index(all_tokens)
if system_msg and ind_prev_msg_start <= len(history_msgs):
prompt.append(system_msg)
prompt.extend(history_msgs[ind_prev_msg_start:])
prompt.append(final_msg)
return prompt
def extract_citations_from_stream(
tokens: Iterator[str], links: list[str | None]
) -> Iterator[str]:
if not links:
yield from tokens
return
max_citation_num = len(links) + 1 # LLM is prompted to 1 index these
curr_segment = ""
prepend_bracket = False
for token in tokens:
# Special case of [1][ where ][ is a single token
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False
curr_segment += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
if citation_found:
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
link = links[numerical_value - 1]
if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if possible_citation_found:
continue
# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True
yield curr_segment
curr_segment = ""
if curr_segment:
if prepend_bracket:
yield "[" + curr_segment
else:
yield curr_segment
def llm_contextless_chat_answer(
messages: list[ChatMessage],
system_text: str | None = None,
tokenizer: Callable | None = None,
) -> Iterator[DanswerAnswerPiece | StreamingError]:
try:
prompt_msgs = [translate_danswer_msg_to_langchain(msg) for msg in messages]
if system_text:
tokenizer = tokenizer or get_default_llm_tokenizer()
system_tokens = len(tokenizer(system_text))
system_msg = SystemMessage(content=system_text)
message_tokens = [msg.token_count for msg in messages] + [system_tokens]
else:
message_tokens = [msg.token_count for msg in messages]
last_msg_ind = _find_last_index(message_tokens)
remaining_user_msgs = prompt_msgs[last_msg_ind:]
if not remaining_user_msgs:
raise ValueError("Last user message is too long!")
if system_text:
all_msgs = [system_msg] + remaining_user_msgs
else:
all_msgs = remaining_user_msgs
for token in get_default_llm().stream(all_msgs):
yield DanswerAnswerPiece(answer_piece=token)
except Exception as e:
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def llm_contextual_chat_answer(
messages: list[ChatMessage],
persona: Persona,
user: User | None,
tokenizer: Callable,
db_session: Session,
run_search_system_text: str = REQUIRE_DANSWER_SYSTEM_MSG,
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
last_message = messages[-1]
final_query_text = last_message.message
previous_messages = messages[:-1]
previous_msgs_as_basemessage = [
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
]
try:
llm = get_default_llm()
if not final_query_text:
raise ValueError("User chat message is empty.")
# Determine if a search is necessary to answer the user query
user_req_search_text = form_require_search_text(last_message)
last_user_msg = HumanMessage(content=user_req_search_text)
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
danswer_system_tokens = len(tokenizer(run_search_system_text))
last_user_msg_tokens = len(tokenizer(user_req_search_text))
need_search_prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=run_search_system_text),
system_token_count=danswer_system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
# Good Debug/Breakpoint
model_out = llm.invoke(need_search_prompt)
# Model will output "Yes Search" if search is useful
# Be a little forgiving though, if we match yes, it's good enough
retrieved_chunks: list[InferenceChunk] = []
if (YES_SEARCH.split()[0] + " ").lower() in model_out.lower():
user_acl_filters = build_access_filters_for_user(user, db_session)
doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None
final_filters = IndexFilters(
source_type=None,
document_set=doc_set_filter,
time_cutoff=None,
access_control_list=user_acl_filters,
)
retrieved_chunks = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
filters=final_filters,
)
yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
last_user_msg_text = form_tool_less_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=persona.hint_text,
)
last_user_msg_tokens = len(tokenizer(last_user_msg_text))
last_user_msg = HumanMessage(content=last_user_msg_text)
else:
last_user_msg_tokens = len(tokenizer(final_query_text))
last_user_msg = HumanMessage(content=final_query_text)
system_text = build_system_text_from_persona(persona)
system_msg = SystemMessage(content=system_text) if system_text else None
system_tokens = len(tokenizer(system_text)) if system_text else 0
prompt = _drop_messages_history_overflow(
system_msg=system_msg,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
links = [
chunk.source_links[0] if chunk.source_links else None
for chunk in retrieved_chunks
]
for segment in extract_citations_from_stream(tokens, links):
yield DanswerAnswerPiece(answer_piece=segment)
except Exception as e:
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def llm_tools_enabled_chat_answer(
messages: list[ChatMessage],
persona: Persona,
user: User | None,
tokenizer: Callable,
db_session: Session,
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
retrieval_enabled = persona.retrieval_enabled
system_text = build_system_text_from_persona(persona)
hint_text = persona.hint_text
tool_text = form_tool_section_text(persona.tools, persona.retrieval_enabled)
last_message = messages[-1]
previous_messages = messages[:-1]
previous_msgs_as_basemessage = [
translate_danswer_msg_to_langchain(msg) for msg in previous_messages
]
# Failure reasons include:
# - Invalid LLM output, wrong format or wrong/missing keys
# - No "Final Answer" from model after tool calling
# - LLM times out or is otherwise unavailable
# - Calling invalid tool or tool call fails
# - Last message has more tokens than model is set to accept
# - Missing user input
try:
if not last_message.message:
raise ValueError("User chat message is empty.")
# Build the prompt using the last user message
user_text = form_user_prompt_text(
query=last_message.message,
tool_text=tool_text,
hint_text=hint_text,
)
last_user_msg = HumanMessage(content=user_text)
# Count tokens once to reuse
previous_msg_token_counts = [msg.token_count for msg in previous_messages]
system_tokens = len(tokenizer(system_text)) if system_text else 0
last_user_msg_tokens = len(tokenizer(user_text))
prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=system_text) if system_text else None,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage,
history_token_counts=previous_msg_token_counts,
final_msg=last_user_msg,
final_msg_token_count=last_user_msg_tokens,
)
llm = get_default_llm()
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
final_result: DanswerChatModelOut | None = None
final_answer_streamed = False
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result
final_answer_streamed = True
if isinstance(result, DanswerChatModelOut):
final_result = result
break
if final_answer_streamed:
return
if final_result is None:
raise RuntimeError("Model output finished without final output parsing.")
if (
retrieval_enabled
and final_result.action.lower() == DANSWER_TOOL_NAME.lower()
):
user_acl_filters = build_access_filters_for_user(user, db_session)
doc_set_filter = [doc_set.name for doc_set in persona.document_sets] or None
final_filters = IndexFilters(
source_type=None,
document_set=doc_set_filter,
time_cutoff=None,
access_control_list=user_acl_filters,
)
retrieved_chunks = danswer_chat_retrieval(
query_message=last_message,
history=previous_messages,
llm=llm,
filters=final_filters,
)
yield RetrievalDocs(top_documents=chunks_to_search_docs(retrieved_chunks))
tool_result_str = format_danswer_chunks_for_chat(retrieved_chunks)
else:
tool_result_str = call_tool(final_result)
# The AI's tool calling message
tool_call_msg_text = final_result.model_raw
tool_call_msg_token_count = len(tokenizer(tool_call_msg_text))
# Create the new message to use the results of the tool call
tool_followup_text = form_tool_followup_text(
tool_output=tool_result_str,
query=last_message.message,
hint_text=hint_text,
)
tool_followup_msg = HumanMessage(content=tool_followup_text)
tool_followup_tokens = len(tokenizer(tool_followup_text))
# Drop previous messages, the drop order goes: previous messages in the history,
# the last user prompt and generated intermediate messages from this recent prompt,
# the system message, then finally the tool message that was the last thing generated
follow_up_prompt = _drop_messages_history_overflow(
system_msg=SystemMessage(content=system_text) if system_text else None,
system_token_count=system_tokens,
history_msgs=previous_msgs_as_basemessage
+ [last_user_msg, AIMessage(content=tool_call_msg_text)],
history_token_counts=previous_msg_token_counts
+ [last_user_msg_tokens, tool_call_msg_token_count],
final_msg=tool_followup_msg,
final_msg_token_count=tool_followup_tokens,
)
# Good Debug/Breakpoint
tokens = llm.stream(follow_up_prompt)
for result in _parse_embedded_json_streamed_response(tokens):
if isinstance(result, DanswerAnswerPiece) and result.answer_piece:
yield result
final_answer_streamed = True
if final_answer_streamed is False:
raise RuntimeError("LLM did not to produce a Final Answer after tool call")
except Exception as e:
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def llm_chat_answer(
messages: list[ChatMessage],
persona: Persona | None,
tokenizer: Callable,
user: User | None,
db_session: Session,
) -> Iterator[DanswerAnswerPiece | RetrievalDocs | StreamingError]:
# Common error cases to keep in mind:
# - User asks question about something long ago, due to context limit, the message is dropped
# - Tool use gives wrong/irrelevant results, model gets confused by the noise
# - Model is too weak of an LLM, fails to follow instructions
# - Bad persona design leads to confusing instructions to the model
# - Bad configurations, too small token limit, mismatched tokenizer to LLM, etc.
# No setting/persona available therefore no retrieval and no additional tools
if persona is None:
return llm_contextless_chat_answer(messages)
# Persona is configured but with retrieval off and no tools
# therefore cannot retrieve any context so contextless
elif persona.retrieval_enabled is False and not persona.tools:
return llm_contextless_chat_answer(
messages, system_text=persona.system_text, tokenizer=tokenizer
)
# No additional tools outside of Danswer retrieval, can use a more basic prompt
# Doesn't require tool calling output format (all LLM outputs are therefore valid)
elif persona.retrieval_enabled and not persona.tools and not FORCE_TOOL_PROMPT:
return llm_contextual_chat_answer(
messages=messages,
persona=persona,
tokenizer=tokenizer,
user=user,
db_session=db_session,
)
# Use most flexible/complex prompt format that allows arbitrary tool calls
# that are configured in the persona file
# WARNING: this flow does not work well with weaker LLMs (anything below GPT-4)
return llm_tools_enabled_chat_answer(
messages=messages,
persona=persona,
tokenizer=tokenizer,
user=user,
db_session=db_session,
)
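Even though this module is being deleted, the citation rewriting it contained is worth a concrete trace. A small sketch of how `extract_citations_from_stream` (as defined above) turns bracketed citation tokens into markdown links; the link and tokens are made up:

```python
# Assumes extract_citations_from_stream from the removed module above is in scope.
links = ["https://docs.example.com/setup", None]  # illustrative source links
tokens = ["The setup steps ", "are in the docs ", "[1]", " and the FAQ ", "[2]", "."]

for segment in extract_citations_from_stream(iter(tokens), links):
    print(segment, end="")

# Prints roughly:
#   The setup steps are in the docs [[1]](https://docs.example.com/setup) and the FAQ [2].
# "[1]" gains a markdown link; "[2]" is left alone because its link is None.
```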

View File

@@ -1,274 +0,0 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.configs.constants import CODE_BLOCK_PAT
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ToolInfo
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import translate_danswer_msg_to_langchain
DANSWER_TOOL_NAME = "Current Search"
DANSWER_TOOL_DESCRIPTION = (
"A search tool that can find information on any topic "
"including up to date and proprietary knowledge."
)
DANSWER_SYSTEM_MSG = (
"Given a conversation (between Human and Assistant) and a final message from Human, "
"rewrite the last message to be a standalone question which captures required/relevant context "
"from previous messages. This question must be useful for a semantic search engine. "
"It is used for a natural language search."
)
YES_SEARCH = "Yes Search"
NO_SEARCH = "No Search"
REQUIRE_DANSWER_SYSTEM_MSG = (
"You are a large language model whose only job is to determine if the system should call an external search tool "
"to be able to answer the user's last message.\n"
f'\nRespond with "{NO_SEARCH}" if:\n'
f"- there is sufficient information in chat history to fully answer the user query\n"
f"- there is enough knowledge in the LLM to fully answer the user query\n"
f"- the user query does not rely on any specific knowledge\n"
f'\nRespond with "{YES_SEARCH}" if:\n'
"- additional knowledge about entities, processes, problems, or anything else could lead to a better answer.\n"
"- there is some uncertainty what the user is referring to\n\n"
f'Respond with EXACTLY and ONLY "{YES_SEARCH}" or "{NO_SEARCH}"'
)
TOOL_TEMPLATE = """
TOOLS
------
You can use tools to look up information that may be helpful in answering the user's \
original question. The available tools are:
{tool_overviews}
RESPONSE FORMAT INSTRUCTIONS
----------------------------
When responding to me, please output a response in one of two formats:
**Option 1:**
Use this if you want to use a tool. Markdown code snippet formatted in the following schema:
```json
{{
"action": string, \\ The action to take. {tool_names}
"action_input": string \\ The input to the action
}}
```
**Option #2:**
Use this if you want to respond directly to the user. Markdown code snippet formatted in the following schema:
```json
{{
"action": "Final Answer",
"action_input": string \\ You should put what you want to return to use here
}}
```
"""
TOOL_LESS_PROMPT = """
Respond with a markdown code snippet in the following schema:
```json
{{
"action": "Final Answer",
"action_input": string \\ You should put what you want to return to use here
}}
```
"""
USER_INPUT = """
USER'S INPUT
--------------------
Here is the user's input \
(remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else):
{user_input}
"""
TOOL_FOLLOWUP = """
TOOL RESPONSE:
---------------------
{tool_output}
USER'S INPUT
--------------------
Okay, so what is the response to my last comment? If using information obtained from the tools you must \
mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES!
If the tool response is not useful, ignore it completely.
{optional_reminder}{hint}
IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else.
"""
TOOL_LESS_FOLLOWUP = """
Refer to the following documents when responding to my final query. Ignore any documents that are not relevant.
CONTEXT DOCUMENTS:
---------------------
{context_str}
FINAL QUERY:
--------------------
{user_query}
{hint_text}
"""
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def format_danswer_chunks_for_chat(chunks: list[InferenceChunk]) -> str:
if not chunks:
return "No Results Found"
return "\n".join(
f"DOCUMENT {ind}:{CODE_BLOCK_PAT.format(chunk.content)}"
for ind, chunk in enumerate(chunks, start=1)
)
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def build_combined_query(
query_message: ChatMessage,
history: list[ChatMessage],
) -> list[BaseMessage]:
user_query = query_message.message
combined_query_msgs: list[BaseMessage] = []
if not user_query:
raise ValueError("Can't rephrase/search an empty query")
combined_query_msgs.append(SystemMessage(content=DANSWER_SYSTEM_MSG))
combined_query_msgs.extend(
[translate_danswer_msg_to_langchain(msg) for msg in history]
)
combined_query_msgs.append(
HumanMessage(
content=(
"Help me rewrite this final message into a standalone query that takes into consideration the "
f"past messages of the conversation if relevant. This query is used with a semantic search engine to "
f"retrieve documents. You must ONLY return the rewritten query and nothing else. "
f"Remember, the search engine does not have access to the conversation history!"
f"\n\nQuery:\n{query_message.message}"
)
)
)
return combined_query_msgs
def form_require_search_single_msg_text(
query_message: ChatMessage,
history: list[ChatMessage],
) -> str:
prompt = "MESSAGE_HISTORY\n---------------\n" if history else ""
for msg in history:
if msg.message_type == MessageType.ASSISTANT:
prefix = "AI"
else:
prefix = "User"
prompt += f"{prefix}:\n```\n{msg.message}\n```\n\n"
prompt += f"\nFINAL QUERY:\n---------------\n{query_message.message}"
return prompt
def form_require_search_text(query_message: ChatMessage) -> str:
return (
query_message.message
+ f"\n\nHint: respond with EXACTLY {YES_SEARCH} or {NO_SEARCH}"
)
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -0,0 +1,479 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from functools import lru_cache
from typing import cast
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.configs.model_configs import GEN_AI_MAX_INPUT_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
CONNECTOR_NAME_MAP = {
"web": "Website",
"requesttracker": "Request Tracker",
"github": "GitHub",
"file": "File Upload",
}
def clean_up_source(source_str: str) -> str:
if source_str in CONNECTOR_NAME_MAP:
return CONNECTOR_NAME_MAP[source_str]
return source_str.replace("_", " ").title()
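# A quick illustration of the mapping above (made-up inputs; sources not present in
# CONNECTOR_NAME_MAP just get underscores replaced and title-casing applied):
#   clean_up_source("github")       -> "GitHub"
#   clean_up_source("web")          -> "Website"
#   clean_up_source("google_drive") -> "Google Drive"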
def build_context_str(
context_docs: list[LlmDoc | InferenceChunk],
include_metadata: bool = True,
) -> str:
context_str = ""
for ind, doc in enumerate(context_docs, start=1):
if include_metadata:
context_str += f"DOCUMENT {ind}: {doc.semantic_identifier}\n"
context_str += f"Source: {clean_up_source(doc.source_type)}\n"
if doc.updated_at:
update_str = doc.updated_at.strftime("%B %d, %Y %H:%M")
context_str += f"Updated: {update_str}\n"
context_str += f"{CODE_BLOCK_PAT.format(doc.content.strip())}\n\n\n"
return context_str.strip()
@lru_cache()
def build_chat_system_message(
prompt: Prompt,
context_exists: bool,
llm_tokenizer: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if system_prompt:
system_prompt += (
f"\n\nAdditional Information:\n\t- {get_current_llm_day_time()}."
)
else:
system_prompt = get_current_llm_day_time()
if not system_prompt:
return None, 0
token_count = len(llm_tokenizer(system_prompt))
system_msg = SystemMessage(content=system_prompt)
return system_msg, token_count
def build_task_prompt_reminders(
prompt: Prompt,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
citation_str: str = CITATION_REMINDER,
language_hint_str: str = LANGUAGE_HINT,
) -> str:
base_task = prompt.task_prompt
citation_or_nothing = citation_str if prompt.include_citations else ""
language_hint_or_nothing = language_hint_str.lstrip() if use_language_hint else ""
return base_task + citation_or_nothing + language_hint_or_nothing
def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
content=inf_chunk.content,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,
updated_at=inf_chunk.updated_at,
link=inf_chunk.source_links[0] if inf_chunk.source_links else None,
)
def map_document_id_order(
chunks: list[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> dict[str, int]:
order_mapping = {}
current = 1 if one_indexed else 0
for chunk in chunks:
if chunk.document_id not in order_mapping:
order_mapping[chunk.document_id] = current
current += 1
return order_mapping
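# Illustrative sketch of the rank mapping (stand-in objects are used purely for their
# document_id attribute; real callers pass InferenceChunk / LlmDoc instances):
#   from collections import namedtuple
#   _Doc = namedtuple("_Doc", ["document_id"])
#   map_document_id_order([_Doc("a"), _Doc("b"), _Doc("a"), _Doc("c")])
#   -> {"a": 1, "b": 2, "c": 3}  # the repeated "a" keeps its first-seen rank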
def build_chat_user_message(
chat_message: ChatMessage,
prompt: Prompt,
context_docs: list[LlmDoc],
llm_tokenizer: Callable,
all_doc_useful: bool,
user_prompt_template: str = CHAT_USER_PROMPT,
context_free_template: str = CHAT_USER_CONTEXT_FREE_PROMPT,
ignore_str: str = DEFAULT_IGNORE_STATEMENT,
) -> tuple[HumanMessage, int]:
user_query = chat_message.message
if not context_docs:
# Simpler prompt for cases where there is no context
user_prompt = (
context_free_template.format(
task_prompt=prompt.task_prompt, user_query=user_query
)
if prompt.task_prompt
else user_query
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
context_docs_str = build_context_str(
cast(list[LlmDoc | InferenceChunk], context_docs)
)
optional_ignore = "" if all_doc_useful else ignore_str
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
user_prompt = user_prompt_template.format(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=user_query,
)
user_prompt = user_prompt.strip()
token_count = len(llm_tokenizer(user_prompt))
user_msg = HumanMessage(content=user_prompt)
return user_msg, token_count
def _get_usable_chunks(
chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
total_token_count = 0
usable_chunks = []
for chunk in chunks:
chunk_token_count = check_number_of_tokens(chunk.content)
if total_token_count + chunk_token_count > token_limit:
break
total_token_count += chunk_token_count
usable_chunks.append(chunk)
# try and return at least one chunk if possible. This chunk will
# get truncated later on in the pipeline. This would only occur if
# the first chunk is larger than the token limit (usually due to character
# count -> token count mismatches caused by special characters / non-ascii
# languages)
if not usable_chunks and chunks:
usable_chunks = [chunks[0]]
return usable_chunks
def get_usable_chunks(
chunks: list[InferenceChunk],
token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
usable_chunks: list[InferenceChunk] = []
for _ in range(max(offset + 1, 1)): # go through this process at least once
if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
raise ValueError(
"Chunks offset too large, should not retry this many times"
)
usable_chunks = _get_usable_chunks(
chunks=chunks[offset_into_chunks:], token_limit=token_limit
)
offset_into_chunks += len(usable_chunks)
return usable_chunks
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: float | None = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
batch_offset: int = 0,
) -> list[int]:
"""
Gives back indices of chunks to pass into the LLM for Q&A.
Only selects chunks viable for Q&A, within the token limit, and prioritizes those selected
by the LLM in a separate flow (this can be turned off).
Note: the batch_offset calculation has to count the batches from the beginning each time, as
there's no way to know which chunks were included in the prior batches without recounting.
This is somewhat slow since it requires tokenizing all the chunks again.
"""
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0
# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
for selection_target in [True, False]:
for ind, chunk in enumerate(chunks):
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
IGNORE_FOR_QA
):
continue
# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
# Add 50 as a slight overestimate of the number of tokens taken up by the chunk's metadata
token_count += chunk_token + 50
# Always use at least 1 chunk
if (
token_limit is None
or token_count <= token_limit
or not latest_batch_indices
):
latest_batch_indices.append(ind)
current_chunk_unused = False
else:
current_chunk_unused = True
if token_limit is not None and token_count >= token_limit:
if batch_index < batch_offset:
batch_index += 1
if current_chunk_unused:
latest_batch_indices = [ind]
token_count = chunk_token
else:
latest_batch_indices = []
token_count = 0
else:
return latest_batch_indices
return latest_batch_indices
def create_chat_chain(
chat_session_id: int,
db_session: Session,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
all_chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session_id,
user_id=None,
db_session=db_session,
skip_permission_check=True,
)
id_to_msg = {msg.id: msg for msg in all_chat_messages}
if not all_chat_messages:
raise ValueError("No messages in Chat Session")
root_message = all_chat_messages[0]
if root_message.parent_message is not None:
raise RuntimeError(
"Invalid root message, unable to fetch valid chat message sequence"
)
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
if not child_msg:
break
current_message = id_to_msg.get(child_msg)
if current_message is None:
raise RuntimeError(
"Invalid message chain,"
"could not find next message in the same session"
)
mainline_messages.append(current_message)
if not mainline_messages:
raise RuntimeError("Could not trace chat message history")
return mainline_messages[-1], mainline_messages[:-1]
def combine_message_chain(
messages: list[ChatMessage],
msg_limit: int | None = 10,
token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
) -> str:
"""Used for secondary LLM flows that require the chat history"""
message_strs: list[str] = []
total_token_count = 0
if msg_limit is not None:
messages = messages[-msg_limit:]
for message in reversed(messages):
message_token_count = message.token_count
if (
token_limit is not None
and total_token_count + message_token_count > token_limit
):
break
role = message.message_type.value.upper()
message_strs.insert(0, f"{role}:\n{message.message}")
total_token_count += message_token_count
return "\n\n".join(message_strs)
def find_last_index(
lst: list[int], max_prompt_tokens: int = GEN_AI_MAX_INPUT_TOKENS
) -> int:
"""From the back, find the index of the last element to include
before the list exceeds the maximum"""
running_sum = 0
last_ind = 0
for i in range(len(lst) - 1, -1, -1):
running_sum += lst[i]
if running_sum > max_prompt_tokens:
last_ind = i + 1
break
if last_ind >= len(lst):
raise ValueError("Last message alone is too large!")
return last_ind
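# Worked example with hypothetical per-message token counts (oldest first):
#   find_last_index([400, 300, 800, 1200], max_prompt_tokens=2000) -> 2
#   Only the last two entries (800 + 1200 = 2000) fit the budget, so callers keep
#   messages from index 2 onwards and drop the two oldest ones.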
def drop_messages_history_overflow(
system_msg: BaseMessage | None,
system_token_count: int,
history_msgs: list[BaseMessage],
history_token_counts: list[int],
final_msg: BaseMessage,
final_msg_token_count: int,
) -> list[BaseMessage]:
"""As message history grows, messages need to be dropped starting from the furthest in the past.
The System message should be kept if at all possible and the latest user input which is inserted in the
prompt template must be included"""
if len(history_msgs) != len(history_token_counts):
# This should never happen
raise ValueError("Need exactly 1 token count per message for tracking overflow")
prompt: list[BaseMessage] = []
# Start dropping from the history if necessary
all_tokens = history_token_counts + [system_token_count, final_msg_token_count]
ind_prev_msg_start = find_last_index(all_tokens)
if system_msg and ind_prev_msg_start <= len(history_msgs):
prompt.append(system_msg)
prompt.extend(history_msgs[ind_prev_msg_start:])
prompt.append(final_msg)
return prompt
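# Minimal sketch of a call (token counts are made up and would normally come from the
# active LLM tokenizer; AIMessage comes from the same langchain module as the imports above):
#   from langchain.schema.messages import AIMessage
#   prompt = drop_messages_history_overflow(
#       system_msg=SystemMessage(content="You are a helpful assistant."),
#       system_token_count=8,
#       history_msgs=[HumanMessage(content="Hi"), AIMessage(content="Hello! How can I help?")],
#       history_token_counts=[3, 9],
#       final_msg=HumanMessage(content="Summarize our chat so far."),
#       final_msg_token_count=10,
#   )
#   Under the default GEN_AI_MAX_INPUT_TOKENS budget all four messages survive; once the
#   budget is exceeded, the oldest history messages are dropped first while the system
#   message and the final user message are kept.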
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
max_citation_num = len(context_docs)
curr_segment = ""
prepend_bracket = False
cited_inds = set()
for token in tokens:
# Special case of [1][ where ][ is a single token
# This is where the model attempts to do consecutive citations like [1][2]
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False
curr_segment += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
if citation_found:
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[
numerical_value - 1
] # remove 1 index offset
link = context_llm_doc.link
target_citation_num = doc_id_to_rank_map[context_llm_doc.document_id]
# Use the citation number for the document's rank in
# the search (or selected docs) results
curr_segment = re.sub(
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
)
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if possible_citation_found:
continue
# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True
yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""
if curr_segment:
if prepend_bracket:
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)
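A usage sketch for the citation rewriting above; the document contents are invented for illustration and `DocumentSource.WEB` is assumed to be a valid member, while the import paths match those used elsewhere in this changeset:

```python
from danswer.chat.chat_utils import extract_citations_from_stream, map_document_id_order
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource

doc = LlmDoc(
    document_id="doc-a",
    content="Danswer streams citations alongside answer tokens.",
    semantic_identifier="Citations design doc",
    source_type=DocumentSource.WEB,  # assumed member, purely for illustration
    updated_at=None,
    link="https://example.com/doc-a",
)

tokens = iter(["Citations are ", "streamed", " [1]", "."])
for packet in extract_citations_from_stream(
    tokens, context_docs=[doc], doc_id_to_rank_map=map_document_id_order([doc])
):
    # Yields a CitationInfo for doc-a plus DanswerAnswerPiece chunks in which "[1]"
    # is rewritten to a markdown link: "[[1]](https://example.com/doc-a)"
    print(packet)
```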

View File

@@ -0,0 +1,106 @@
from typing import cast
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.chat import get_prompt_by_name
from danswer.db.chat import upsert_persona
from danswer.db.chat import upsert_prompt
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Prompt as PromptDBModel
from danswer.search.models import RecencyBiasSetting
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user_id=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
shared=True,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
prompt_set_names = persona["prompts"]
if not prompt_set_names:
prompts: list[PromptDBModel | None] | None = None
else:
prompts = [
get_prompt_by_name(
prompt_name, user_id=None, shared=True, db_session=db_session
)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if not prompts:
prompts = None
upsert_persona(
user_id=None,
persona_id=persona.get("id"),
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_version_override=None,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompts=cast(list[PromptDBModel] | None, prompts),
document_sets=doc_sets,
default_persona=True,
shared=True,
db_session=db_session,
)
def load_chat_yamls(
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)

View File

@@ -0,0 +1,100 @@
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.search.models import QueryFlow
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.search.models import SearchType
class LlmDoc(BaseModel):
"""This contains the minimal set information for the LLM portion including citations"""
document_id: str
content: str
semantic_identifier: str
source_type: DocumentSource
updated_at: datetime | None
link: str | None
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
rephrased_query: str | None = None
predicted_flow: QueryFlow | None
predicted_search: SearchType | None
applied_source_filters: list[DocumentSource] | None
applied_time_cutoff: datetime | None
recency_bias_multiplier: float
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore
initial_dict["applied_time_cutoff"] = (
self.applied_time_cutoff.isoformat() if self.applied_time_cutoff else None
)
return initial_dict
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
# An intermediate representation of citations, later translated into
# a mapping of the citation [n] number to SearchDoc
class CitationInfo(BaseModel):
citation_num: int
document_id: str
class StreamingError(BaseModel):
error: str
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerAnswer(BaseModel):
answer: str | None
class QAResponse(SearchResponse, DanswerAnswer):
quotes: list[DanswerQuote] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_chunks_indices: list[int] | None = None
error_msg: str | None = None
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
AnswerQuestionStreamReturn = Iterator[
DanswerAnswerPiece | DanswerQuotes | StreamingError
]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int

View File

@@ -1,81 +0,0 @@
from datetime import datetime
from typing import Any
import yaml
from sqlalchemy.orm import Session
from danswer.configs.app_configs import PERSONAS_YAML
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import ToolInfo
def build_system_text_from_persona(persona: Persona) -> str | None:
text = (persona.system_text or "").strip()
if persona.datetime_aware:
current_datetime = datetime.now()
# Format looks like: "October 16, 2023 14:30"
formatted_datetime = current_datetime.strftime("%B %d, %Y %H:%M")
text += (
"\n\nAdditional Information:\n"
f"\t- The current date and time is {formatted_datetime}."
)
return text or None
def validate_tool_info(item: Any) -> ToolInfo:
if not (
isinstance(item, dict)
and "name" in item
and isinstance(item["name"], str)
and "description" in item
and isinstance(item["description"], str)
):
raise ValueError(
"Invalid Persona configuration yaml Found, not all tools have name/description"
)
return ToolInfo(name=item["name"], description=item["description"])
def load_personas_from_yaml(personas_yaml: str = PERSONAS_YAML) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
for persona in all_personas:
tools = [validate_tool_info(tool) for tool in persona["tools"]]
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] | None = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
if not doc_sets:
doc_sets = None
upsert_persona(
name=persona["name"],
retrieval_enabled=persona.get("retrieval_enabled", True),
# Default to knowing the date/time if not specified, however if there is no
# system prompt, do not interfere with the flow by adding a
# system prompt that is ONLY the date info, this would likely not be useful
datetime_aware=persona.get(
"datetime_aware", bool(persona.get("system"))
),
system_text=persona.get("system"),
tools=tools,
hint_text=persona.get("hint"),
default_persona=True,
document_sets=doc_sets,
db_session=db_session,
)

View File

@@ -1,12 +1,34 @@
# Currently in the UI, each Persona only has one prompt, which is why there are 3 very similar personas defined below.
personas:
- name: "Danswer"
system: |
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide accurate and detailed answers to diverse queries.
Your responses are as INFORMATIVE and DETAILED as possible.
Cite relevant statements using the format [1], [2], etc to reference the document number, do not provide any links following the citation.
# Document Sets that this persona has access to, specified as a list of names here.
# If left empty, the persona has access to all and only public docs
# This id field can be left blank for other default personas, however an id 0 persona must exist
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Default"
description: >
Default Danswer Question Answering functionality.
# Default Prompt objects attached to the persona, see prompts.yaml
prompts:
- "Answer-Question"
# Default number of chunks to include as context, set to 0 to disable retrieval
# Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
# If selecting documents, user can bypass this up until NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
# Each chunk is 512 tokens long
num_chunks: 5
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overridden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
llm_relevance_filter: true
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
llm_filter_extraction: true
# Decay documents priority as they age, options are:
# - favor_recent (2x base by default, configurable)
# - base_decay
# - no_decay
# - auto (model chooses between favor_recent and base_decay based on user query)
recency_bias: "auto"
# Default Document Sets for this persona, specified as a list of names here.
# If the document set by the name exists, it will be attached to the persona
# If the document set by the name does not exist, it will be created as an empty document set with no connectors
# The admin can then use the UI to add new connectors to the document set
@@ -16,19 +38,28 @@ personas:
# - "Engineer Onboarding"
# - "Benefits"
document_sets: []
# Danswer custom tool flow, "Current Search" tool name is reserved if this is enabled.
retrieval_enabled: true
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Personas can be given tools for Agentifying Danswer, however the tool call must be implemented in the code
# Once implemented, it can be given to personas via the config.
# Example of adding tools, it must follow this structure:
# tools:
# - name: "Calculator"
# description: "Use this tool to accurately process math equations, counting, etc."
# - name: "Current Weather"
# description: "Call this to get the current weather info."
tools: []
# Short tip to pass near the end of the prompt to emphasize some requirement
hint: "Try to be as informative as possible!"
- name: "Summarize"
description: >
A less creative assistant which summarizes relevant documents but does not try to
extrapolate any answers for you.
prompts:
- "Summarize"
num_chunks: 5
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
- name: "Paraphrase"
description: >
The least creative default assistant that only provides quotes from the documents.
prompts:
- "Paraphrase"
num_chunks: 5
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
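A small sketch (assuming PyYAML is installed and the process runs from the backend root so the relative `PERSONAS_YAML` path resolves) for checking what this file declares, mirroring the loading done by `load_personas_from_yaml` earlier in this changeset:

```python
import yaml

from danswer.configs.chat_configs import PERSONAS_YAML

with open(PERSONAS_YAML, "r") as f:
    personas = yaml.safe_load(f).get("personas", [])

for persona in personas:
    print(persona["name"], "| prompts:", persona.get("prompts"), "| num_chunks:", persona.get("num_chunks"))
```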

View File

@@ -0,0 +1,471 @@
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import build_chat_system_message
from danswer.chat.chat_utils import build_chat_user_message
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import drop_messages_history_overflow
from danswer.chat.chat_utils import extract_citations_from_stream
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import map_document_id_order
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import CHUNK_SIZE
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.models import ChatMessage
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_token_encode
from danswer.llm.utils import translate_history_to_basemessages
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import full_chunk_search_generator
from danswer.search.search_runner import inference_documents_from_ids
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
def generate_ai_chat_response(
query_message: ChatMessage,
history: list[ChatMessage],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
llm: LLM | None,
llm_tokenizer: Callable,
all_doc_useful: bool,
) -> Iterator[DanswerAnswerPiece | CitationInfo | StreamingError]:
if llm is None:
try:
llm = get_default_llm()
except GenAIDisabledException:
# Not an error if it's a user configuration
yield DanswerAnswerPiece(answer_piece=DISABLED_GEN_AI_MSG)
return
if query_message.prompt is None:
raise RuntimeError("No prompt received for generating Gen AI answer.")
try:
context_exists = len(context_docs) > 0
system_message_or_none, system_tokens = build_chat_system_message(
prompt=query_message.prompt,
context_exists=context_exists,
llm_tokenizer=llm_tokenizer,
)
history_basemessages, history_token_counts = translate_history_to_basemessages(
history
)
# Be sure the context_docs passed to build_chat_user_message
# is the same as the one passed in later for extracting citations
user_message, user_tokens = build_chat_user_message(
chat_message=query_message,
prompt=query_message.prompt,
context_docs=context_docs,
llm_tokenizer=llm_tokenizer,
all_doc_useful=all_doc_useful,
)
prompt = drop_messages_history_overflow(
system_msg=system_message_or_none,
system_token_count=system_tokens,
history_msgs=history_basemessages,
history_token_counts=history_token_counts,
final_msg=user_message,
final_msg_token_count=user_tokens,
)
# Good Debug/Breakpoint
tokens = llm.stream(prompt)
yield from extract_citations_from_stream(
tokens, context_docs, doc_id_to_rank_map
)
except Exception as e:
logger.exception(f"LLM failed to produce valid chat message, error: {e}")
yield StreamingError(error=str(e))
def translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
for db_doc in db_docs:
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
citation_to_saved_doc_id_map: dict[int, int] = {}
for citation in citations_list:
if citation.citation_num not in citation_to_saved_doc_id_map:
citation_to_saved_doc_id_map[
citation.citation_num
] = doc_id_to_saved_doc_id_map[citation.document_id]
return citation_to_saved_doc_id_map
@log_generator_function_time()
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = CHUNK_SIZE,
) -> Iterator[str]:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
)
message_text = new_msg_req.message
chat_session_id = new_msg_req.chat_session_id
parent_id = new_msg_req.parent_message_id
prompt_id = new_msg_req.prompt_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
persona = chat_session.persona
query_override = new_msg_req.query_override
if reference_doc_ids is None and retrieval_options is None:
raise RuntimeError(
"Must specify a set of documents for chat or specify search options"
)
try:
llm = get_default_llm()
except GenAIDisabledException:
llm = None
llm_tokenizer = get_default_llm_token_encode()
document_index = get_default_document_index()
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
chat_session_id=chat_session_id, db_session=db_session
)
if parent_id is not None:
parent_message = get_chat_message(
chat_message_id=parent_id,
user_id=user_id,
db_session=db_session,
)
else:
parent_message = root_message
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
new_user_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=parent_message,
prompt_id=prompt_id,
message=message_text,
token_count=len(llm_tokenizer(message_text)),
message_type=MessageType.USER,
db_session=db_session,
commit=False,
)
# Create linear history of messages
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.id != new_user_message.id:
db_session.rollback()
raise RuntimeError(
"The new message was not on the mainline. "
"Be sure to update the chat pointers before calling this."
)
# Save now to save the latest chat message
db_session.commit()
run_search = False
# Retrieval options are only None if reference_doc_ids are provided
if retrieval_options is not None and persona.num_chunks != 0:
if retrieval_options.run_search == OptionalSearchSetting.ALWAYS:
run_search = True
elif retrieval_options.run_search == OptionalSearchSetting.NEVER:
run_search = False
else:
run_search = check_if_need_search(
query_message=final_msg, history=history_msgs, llm=llm
)
rephrased_query = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
chat_session=chat_session,
user_id=user_id,
db_session=db_session,
)
# Generates full documents currently
# May extend to include chunk ranges
llm_docs: list[LlmDoc] = inference_documents_from_ids(
doc_identifiers=identifier_tuples,
document_index=get_default_document_index(),
)
doc_id_to_rank_map = map_document_id_order(
cast(list[InferenceChunk | LlmDoc], llm_docs)
)
# In case the search doc is deleted, just don't include it
# though this should never happen
db_search_docs_or_none = [
get_db_search_doc_by_id(doc_id=doc_id, db_session=db_session)
for doc_id in reference_doc_ids
]
reference_db_search_docs = [
db_sd for db_sd in db_search_docs_or_none if db_sd
]
elif run_search:
rephrased_query = (
history_based_query_rephrase(
query_message=final_msg, history=history_msgs, llm=llm
)
if query_override is None
else query_override
)
(
retrieval_request,
predicted_search_type,
predicted_flow,
) = retrieval_preprocessing(
query=rephrased_query,
retrieval_details=cast(RetrievalDetails, retrieval_options),
persona=persona,
user=user,
db_session=db_session,
)
documents_generator = full_chunk_search_generator(
search_query=retrieval_request,
document_index=document_index,
)
time_cutoff = retrieval_request.filters.time_cutoff
recency_bias_multiplier = retrieval_request.recency_bias_multiplier
run_llm_chunk_filter = not retrieval_request.skip_llm_chunk_filter
# First fetch and return the top chunks to the UI so the user can
# immediately see some results
top_chunks = cast(list[InferenceChunk], next(documents_generator))
# Get ranking of the documents for citation purposes later
doc_id_to_rank_map = map_document_id_order(
cast(list[InferenceChunk | LlmDoc], top_chunks)
)
top_docs = chunks_to_search_docs(top_chunks)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
for top_doc in top_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
initial_response = QADocsResponse(
rephrased_query=rephrased_query,
top_documents=response_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search_type,
applied_source_filters=retrieval_request.filters.source_type,
applied_time_cutoff=time_cutoff,
recency_bias_multiplier=recency_bias_multiplier,
).dict()
yield get_json_line(initial_response)
# Get the final ordering of chunks for the LLM call
llm_chunk_selection = cast(list[bool], next(documents_generator))
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=[
index for index, value in enumerate(llm_chunk_selection) if value
]
if run_llm_chunk_filter
else []
).dict()
yield get_json_line(llm_relevance_filtering_response)
# Prep chunks to pass to LLM
num_llm_chunks = (
persona.num_chunks
if persona.num_chunks is not None
else default_num_chunks
)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
token_limit=num_llm_chunks * default_chunk_size,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
else:
llm_docs = []
doc_id_to_rank_map = {}
reference_db_search_docs = None
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=new_user_message,
prompt_id=prompt_id,
# message=,
rephrased_query=rephrased_query,
# token_count=,
message_type=MessageType.ASSISTANT,
# error=,
reference_docs=reference_db_search_docs,
db_session=db_session,
commit=True,
)
# If no prompt is provided, this is interpreted as not wanting an AI Answer
# Simply provide/save the retrieval results
if final_msg.prompt is None:
gen_ai_response_message = partial_response(
message="",
token_count=0,
citations=None,
error=None,
)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield get_json_line(msg_detail_response.dict())
# Stop here after saving message details; the above still needs to be sent so that the
# message id is available for the next follow-up message
return
# LLM prompt building, response capturing, etc.
response_packets = generate_ai_chat_response(
query_message=final_msg,
history=history_msgs,
context_docs=llm_docs,
doc_id_to_rank_map=doc_id_to_rank_map,
llm=llm,
llm_tokenizer=llm_tokenizer,
all_doc_useful=reference_doc_ids is not None,
)
# Capture outputs and errors
llm_output = ""
error: str | None = None
citations: list[CitationInfo] = []
for packet in response_packets:
if isinstance(packet, DanswerAnswerPiece):
token = packet.answer_piece
if token:
llm_output += token
elif isinstance(packet, StreamingError):
error = packet.error
elif isinstance(packet, CitationInfo):
citations.append(packet)
continue
yield get_json_line(packet.dict())
except Exception as e:
logger.exception(e)
# Frontend will erase whatever answer and show this instead
# This will be the issue 99% of the time
error_packet = StreamingError(
error="LLM failed to respond, have you set your API key?"
)
yield get_json_line(error_packet.dict())
return
# Post-LLM answer processing
try:
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
gen_ai_response_message = partial_response(
message=llm_output,
token_count=len(llm_tokenizer(llm_output)),
citations=db_citations,
error=error,
)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield get_json_line(msg_detail_response.dict())
except Exception as e:
logger.exception(e)
# Frontend will erase whatever answer and show this instead
error_packet = StreamingError(error="Failed to parse LLM output")
yield get_json_line(error_packet.dict())

View File

@@ -0,0 +1,68 @@
prompts:
# This id field can be left blank for other default prompts, however an id 0 prompt must exist
# This is to act as a default
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Answer-Question"
description: "Answers user questions using retrieved context!"
# System Prompt (as shown in UI)
system: >
You are a question answering system that is constantly learning and improving.
You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
You always clearly communicate ANY UNCERTAINTY in your answer.
# Task Prompt (as shown in UI)
task: >
Answer my query based on the documents provided.
The documents may not all be relevant, ignore any documents that are not directly relevant
to the most recent user query.
I have not read or seen any of the documents and do not want to read them.
If there are no relevant documents, refer to the chat history and existing knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Prompts the LLM to include citations in the form [1], [2], etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "Summarize"
description: "Summarize relevant information from retrieved context!"
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
IMPORTANT: YOU ONLY SUMMARIZE THE IMPORTANT INFORMATION FROM THE PROVIDED DOCUMENTS,
NEVER USE YOUR OWN KNOWLEDGE.
task: >
Summarize the documents provided in relation to the query below.
NEVER refer to the documents by number, I do not have them in the same order as you.
Do not make up any facts, only use what is in the documents.
datetime_aware: true
include_citations: true
- name: "Paraphrase"
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
Quote and cite relevant information from provided context based on the user query.
You only provide quotes that are EXACT substrings from provided documents!
If there are no documents provided,
simply tell the user that there are no documents to reference.
You NEVER generate new text or phrases outside of the citation.
DO NOT explain your responses, only provide the quotes and NOTHING ELSE.
task: >
Provide EXACT quotes from the provided documents above. Do not generate any new text that is not
directly from the documents.
datetime_aware: true
include_citations: true

View File

@@ -1,7 +1,115 @@
from danswer.direct_qa.interfaces import DanswerChatModelOut
from typing import TypedDict
from pydantic import BaseModel
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
from danswer.prompts.chat_tools import TOOL_TEMPLATE
from danswer.prompts.chat_tools import USER_INPUT
class ToolInfo(TypedDict):
name: str
description: str
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
def call_tool(
model_actions: DanswerChatModelOut,
) -> str:
raise NotImplementedError("There are no additional tool integrations right now")
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -3,11 +3,16 @@ import os
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentIndexType
#####
# App Configs
#####
APP_HOST = "0.0.0.0"
APP_PORT = 8080
# API_PREFIX is used to prepend a base path for all API routes
# generally used if using a reverse proxy which doesn't support stripping the `/api`
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
#####
@@ -15,10 +20,9 @@ APP_PORT = 8080
#####
BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb
GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
# DISABLE_GENERATIVE_AI will turn off the question answering part of Danswer.
# Use this if you want to use Danswer as a search engine only without the LLM capabilities
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
#####
# Web Configs
#####
@@ -39,7 +43,7 @@ MASK_CREDENTIAL_PREFIX = (
SECRET = os.environ.get("SECRET", "")
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS", 86400)
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400
) # 1 day
# set `VALID_EMAIL_DOMAINS` to a comma-separated list of domains in order to
@@ -56,7 +60,6 @@ VALID_EMAIL_DOMAINS = (
if _VALID_EMAIL_DOMAINS_STR
else []
)
# OAuth Login Flow
# Used for both Google OAuth2 and OIDC flows
OAUTH_CLIENT_ID = (
@@ -67,12 +70,12 @@ OAUTH_CLIENT_SECRET = (
or ""
)
# The following Basic Auth configs are not supported by the frontend UI
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
)
SMTP_SERVER = os.environ.get("SMTP_SERVER", "smtp.gmail.com")
SMTP_PORT = int(os.environ.get("SMTP_PORT", "587"))
SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
@@ -80,7 +83,7 @@ SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
#####
# DB Configs
#####
DOCUMENT_INDEX_NAME = "danswer_index" # Shared by vector/keyword indices
DOCUMENT_INDEX_NAME = "danswer_index"
# Vespa is now the default document index store for both keyword and vector
DOCUMENT_INDEX_TYPE = os.environ.get(
"DOCUMENT_INDEX_TYPE", DocumentIndexType.COMBINED.value
@@ -93,7 +96,10 @@ VESPA_DEPLOYMENT_ZIP = (
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
)
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = 16
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
@@ -140,80 +146,17 @@ CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")
EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED = (
os.environ.get("EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED", "").lower() == "true"
)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
#####
# Query Configs
#####
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit.
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
# may be smaller which could result in passing in more total chunks
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (512 * 3)
)
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default
)
FAVOR_RECENT_DECAY_MULTIPLIER = 2
DISABLE_TIME_FILTER_EXTRACTION = (
os.environ.get("DISABLE_TIME_FILTER_EXTRACTION", "").lower() == "true"
)
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
# Keyword Search Drop Stopwords
# If the user has changed the default model, it is most likely to use a multilingual model;
# the stopwords are NLTK English stopwords, so in that case we would not want to drop them from the query
if os.environ.get("EDIT_KEYWORD_QUERY"):
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
#####
# Text Processing Configs
# Indexing Configs
#####
CHUNK_SIZE = 512 # Tokens by embedding model
CHUNK_OVERLAP = int(CHUNK_SIZE * 0.05) # 5% overlap
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
#####
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
#####
BI_ENCODER_HOST = "localhost"
BI_ENCODER_PORT = 9000
CROSS_ENCODER_HOST = "localhost"
CROSS_ENCODER_PORT = 9000
#####
# Miscellaneous
#####
PERSONAS_YAML = "./danswer/chat/personas.yaml"
DYNAMIC_CONFIG_STORE = os.environ.get(
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
)
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
# notset, debug, info, warning, error, or critical
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
# only handles some failures (Confluence = handles API call failures, Google
# Drive = handles failures pulling files / parsing them)
@@ -225,8 +168,57 @@ CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
CHUNK_OVERLAP = 0
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 1))
#####
# Model Server Configs
#####
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
# specify this env variable directly to have a different model server for the background
# indexing job vs the api server so that background indexing does not affect query-time
# performance
INDEXING_MODEL_SERVER_HOST = (
os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
#####
# Miscellaneous
#####
DYNAMIC_CONFIG_STORE = os.environ.get(
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
)
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# used to allow the background indexing jobs to use a different embedding
# model server than the API server
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
)
# Logs every model prompt and output, mostly used for development or exploration purposes
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
# notset, debug, info, warning, error, or critical
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")

View File

@@ -1,3 +1,75 @@
import os
FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true"
from danswer.configs.model_configs import CHUNK_SIZE
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit.
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
# significantly smaller which could result in passing in more total chunks.
# There is also a slight bit of overhead, not accounted for here such as separator patterns
# between the docs, metadata for the docs, etc.
# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
# model token limit
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (CHUNK_SIZE * 5)
)
DEFAULT_NUM_CHUNKS_FED_TO_CHAT: float = (
float(NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL) / CHUNK_SIZE
)
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (CHUNK_SIZE * 3)
)
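A quick worked example of how these defaults relate (a sketch only; the numbers hold when none of the env variables above are set and CHUNK_SIZE keeps its default of 512):

# Assumes the defaults above: CHUNK_SIZE = 512 and no env overrides.
CHUNK_SIZE = 512
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = CHUNK_SIZE * 5  # 2560 tokens
DEFAULT_NUM_CHUNKS_FED_TO_CHAT = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL / CHUNK_SIZE  # 5.0
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = CHUNK_SIZE * 3  # 1536 tokens
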
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default
)
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# Whether the LLM should be used to decide if a search would help given the chat history
DISABLE_LLM_CHOOSE_SEARCH = (
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
)
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If the user has changed the default model, it is most likely to use a multilingual model;
# the stopwords are NLTK English stopwords, so in that case we would not want to drop them from the query
if os.environ.get("EDIT_KEYWORD_QUERY"):
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.66)))
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but that may not be clear
# if the Title is scored separately. Title is more of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
)
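Both weights above are clamped into [0, 1] by the max(0, min(1, ...)) pattern; a small hedged illustration (the raw values below are made up, not defaults):

def clamp_unit_interval(raw: float) -> float:
    # mirrors max(0, min(1, raw)) as used for HYBRID_ALPHA and TITLE_CONTENT_RATIO
    return max(0, min(1, raw))

clamp_unit_interval(0.66)  # -> 0.66, used as-is
clamp_unit_interval(1.7)   # -> 1, clamped from above
clamp_unit_interval(-0.2)  # -> 0, clamped from below
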
# A list of languages passed to the LLM to rephrase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
# The backend logic for this being True isn't fully supported yet
HARD_DELETE_CHATS = False

View File

@@ -11,11 +11,13 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
TITLE = "title"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
TITLE_EMBEDDING = "title_embedding"
ALLOWED_USERS = "allowed_users"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -35,26 +37,31 @@ SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
# Prompt building constants:
GENERAL_SEP_PAT = "\n-----\n"
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
DOC_SEP_PAT = "---NEW DOCUMENT---"
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
QUESTION_PAT = "Query:"
THOUGHT_PAT = "Thought:"
ANSWER_PAT = "Answer:"
FINAL_ANSWER_PAT = "Final Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"
QUOTES_PAT_PLURAL = "Quotes:"
INVALID_PAT = "Invalid:"
# For chunking/processing chunks
TITLE_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
INDEX_SEPARATOR = "==="
# Messages
DISABLED_GEN_AI_MSG = (
"Your System Admin has disabled the Generative AI functionalities of Danswer.\n"
"Please contact them if you wish to have this enabled.\n"
"You can still use Danswer as a search engine."
)
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
INGESTION_API = "ingestion_api"
SLACK = "slack"
WEB = "web"
GOOGLE_DRIVE = "google_drive"
REQUESTTRACKER = "requesttracker"
GITHUB = "github"
GURU = "guru"
BOOKSTACK = "bookstack"
@@ -86,11 +93,6 @@ class AuthType(str, Enum):
SAML = "saml"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
class SearchFeedbackType(str, Enum):
ENDORSE = "endorse" # boost this document for all future queries
REJECT = "reject" # down-boost this document for all future queries
@@ -100,7 +102,7 @@ class SearchFeedbackType(str, Enum):
class MessageType(str, Enum):
# Using OpenAI standards, Langchain equivalent shown in comment
# System message is always constructed on the fly, not saved
SYSTEM = "system" # SystemMessage
USER = "user" # HumanMessage
ASSISTANT = "assistant" # AIMessage
DANSWER = "danswer" # FunctionMessage

View File

@@ -41,11 +41,10 @@ DISABLE_DANSWER_BOT_FILTER_DETECT = (
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the bot is configured individually
# This is the default for all DanswerBot channels unless the channel is configured individually
# Set/unset by "Hide Non Answers"
ENABLE_DANSWERBOT_REFLEXION = (
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
)
# Add the per document feedback blocks that affect the document rankings via boosting
ENABLE_SLACK_DOC_FEEDBACK = (
os.environ.get("ENABLE_SLACK_DOC_FEEDBACK", "").lower() == "true"
)
# Currently does not support chain of thought, will probably add it back later
DANSWER_BOT_DISABLE_COT = True

View File

@@ -3,11 +3,11 @@ import os
#####
# Embedding/Reranking Model Configs
#####
CHUNK_SIZE = 512
# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
# Models used must be MIT or Apache license
# Inference/Indexing speed
# https://huggingface.co/DOCUMENT_ENCODER_MODEL
# The usable models configured below must be SentenceTransformer compatible
DOCUMENT_ENCODER_MODEL = (
@@ -21,6 +21,7 @@ NORMALIZE_EMBEDDINGS = (
os.environ.get("NORMALIZE_EMBEDDINGS") or "False"
).lower() == "true"
# These are only used if reranking is turned off, to normalize the direct retrieval scores for display
# Currently unused
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
@@ -34,7 +35,12 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
# Cross Encoder Settings
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
CROSS_ENCODER_MODEL_ENSEMBLE = [
"cross-encoder/ms-marco-MiniLM-L-4-v2",
@@ -70,6 +76,11 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = (
os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
)
# If the Generative AI model requires an API key for access, otherwise can leave blank
GEN_AI_API_KEY = (
@@ -80,9 +91,14 @@ GEN_AI_API_KEY = (
GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
# API Version, such as (for Azure): 2023-09-15-preview
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This next restriction is only used for chat ATM, used to expire old messages as needed
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
# History for secondary LLM flows, not the primary chat flow; generally we don't need to
# include as much history as possible since that just bumps up the cost unnecessarily
GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0)

View File

@@ -1,3 +1,5 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/backend/danswer/connectors/README.md"} -->
# Writing a new Danswer Connector
This README covers how to contribute a new Connector for Danswer. It includes an overview of the design, interfaces,
and required changes.

View File

@@ -8,6 +8,7 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.bookstack.client import BookStackApiClient
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -72,13 +73,21 @@ class BookstackConnector(LoadConnector, PollConnector):
bookstack_client: BookStackApiClient, book: dict[str, Any]
) -> Document:
url = bookstack_client.build_app_url("/books/" + str(book.get("slug")))
title = str(book.get("name", ""))
text = book.get("name", "") + "\n" + book.get("description", "")
updated_at_str = (
str(book.get("updated_at")) if book.get("updated_at") is not None else None
)
return Document(
id="book:" + str(book.get("id")),
id="book__" + str(book.get("id")),
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Book: " + str(book.get("name")),
metadata={"type": "book", "updated_at": str(book.get("updated_at"))},
semantic_identifier="Book: " + title,
title=title,
doc_updated_at=time_str_to_utc(updated_at_str)
if updated_at_str is not None
else None,
metadata={"type": "book"},
)
@staticmethod
@@ -91,13 +100,23 @@ class BookstackConnector(LoadConnector, PollConnector):
+ "/chapter/"
+ str(chapter.get("slug"))
)
title = str(chapter.get("name", ""))
text = chapter.get("name", "") + "\n" + chapter.get("description", "")
updated_at_str = (
str(chapter.get("updated_at"))
if chapter.get("updated_at") is not None
else None
)
return Document(
id="chapter:" + str(chapter.get("id")),
id="chapter__" + str(chapter.get("id")),
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Chapter: " + str(chapter.get("name")),
metadata={"type": "chapter", "updated_at": str(chapter.get("updated_at"))},
semantic_identifier="Chapter: " + title,
title=title,
doc_updated_at=time_str_to_utc(updated_at_str)
if updated_at_str is not None
else None,
metadata={"type": "chapter"},
)
@staticmethod
@@ -105,13 +124,23 @@ class BookstackConnector(LoadConnector, PollConnector):
bookstack_client: BookStackApiClient, shelf: dict[str, Any]
) -> Document:
url = bookstack_client.build_app_url("/shelves/" + str(shelf.get("slug")))
title = str(shelf.get("name", ""))
text = shelf.get("name", "") + "\n" + shelf.get("description", "")
updated_at_str = (
str(shelf.get("updated_at"))
if shelf.get("updated_at") is not None
else None
)
return Document(
id="shelf:" + str(shelf.get("id")),
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Shelf: " + str(shelf.get("name")),
metadata={"type": "shelf", "updated_at": shelf.get("updated_at")},
semantic_identifier="Shelf: " + title,
title=title,
doc_updated_at=time_str_to_utc(updated_at_str)
if updated_at_str is not None
else None,
metadata={"type": "shelf"},
)
@staticmethod
@@ -119,7 +148,7 @@ class BookstackConnector(LoadConnector, PollConnector):
bookstack_client: BookStackApiClient, page: dict[str, Any]
) -> Document:
page_id = str(page.get("id"))
page_name = str(page.get("name"))
title = str(page.get("name", ""))
page_data = bookstack_client.get("/pages/" + page_id, {})
url = bookstack_client.build_app_url(
"/books/"
@@ -127,17 +156,24 @@ class BookstackConnector(LoadConnector, PollConnector):
+ "/page/"
+ str(page_data.get("slug"))
)
page_html = (
"<h1>" + html.escape(page_name) + "</h1>" + str(page_data.get("html"))
)
page_html = "<h1>" + html.escape(title) + "</h1>" + str(page_data.get("html"))
text = parse_html_page_basic(page_html)
updated_at_str = (
str(page_data.get("updated_at"))
if page_data.get("updated_at") is not None
else None
)
time.sleep(0.1)
return Document(
id="page:" + page_id,
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Page: " + str(page_name),
metadata={"type": "page", "updated_at": page_data.get("updated_at")},
semantic_identifier="Page: " + str(title),
title=str(title),
doc_updated_at=time_str_to_utc(updated_at_str)
if updated_at_str is not None
else None,
metadata={"type": "page"},
)
def load_from_state(self) -> GenerateDocumentsOutput:

View File

@@ -2,10 +2,12 @@ from collections.abc import Callable
from collections.abc import Collection
from datetime import datetime
from datetime import timezone
from functools import lru_cache
from typing import Any
from typing import cast
from urllib.parse import urlparse
import bs4
from atlassian import Confluence # type:ignore
from requests import HTTPError
@@ -13,11 +15,12 @@ from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
from danswer.connectors.cross_connector_utils.html_utils import format_document_soup
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
@@ -84,6 +87,53 @@ def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, bool]:
return wiki_base, space, is_confluence_cloud
@lru_cache()
def _get_user(user_id: str, confluence_client: Confluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e. the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name. 'Unknown User' if the user is deactivated or not found
"""
user_not_found = "Unknown User"
try:
return confluence_client.get_user_details_by_accountid(user_id).get(
"displayName", user_not_found
)
except Exception as e:
logger.warning(
f"Unable to get the User Display Name with the id: '{user_id}' - {e}"
)
return user_not_found
def parse_html_page(text: str, confluence_client: Confluence) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
text (str): The page content
confluence_client (Confluence): Confluence client
Returns:
str: loaded and formatted Confluence page
"""
soup = bs4.BeautifulSoup(text, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.attrs["ri:userkey"]
)
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(user_id, confluence_client))
return format_document_soup(soup)
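A minimal, hedged sketch of the mention replacement above, with the Confluence lookup stubbed out by a hardcoded mapping and soup.get_text() standing in for format_document_soup; the account id and name are hypothetical:

import bs4

_FAKE_DIRECTORY = {"abc123": "Jane Doe"}  # hypothetical account-id -> display name

def replace_user_mentions(text: str) -> str:
    soup = bs4.BeautifulSoup(text, "html.parser")
    for user in soup.findAll("ri:user"):
        user_id = user.attrs.get("ri:account-id") or user.attrs.get("ri:userkey")
        # Include @ sign for tagging, as in the connector code above
        user.replaceWith("@" + _FAKE_DIRECTORY.get(user_id, "Unknown User"))
    return soup.get_text()

replace_user_mentions('<p>Reviewed by <ri:user ri:account-id="abc123"></ri:user></p>')
# -> "Reviewed by @Jane Doe"
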
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
@@ -91,7 +141,9 @@ def _comment_dfs(
) -> str:
for comment_page in comment_pages:
comment_html = comment_page["body"]["storage"]["value"]
comments_str += "\nComment:\n" + parse_html_page_basic(comment_html)
comments_str += "\nComment:\n" + parse_html_page(
comment_html, confluence_client
)
child_comment_pages = confluence_client.get_page_child_by_type(
comment_page["id"],
type="comment",
@@ -281,9 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if not page_html:
logger.debug("Page is empty, skipping: %s", page_url)
continue
page_text = (
page.get("title", "") + "\n" + parse_html_page_basic(page_html)
)
page_text = parse_html_page(page_html, self.confluence_client)
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
@@ -294,7 +344,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
doc_updated_at=last_modified,
primary_owners=[author] if author else None,
primary_owners=[BasicExpertInfo(email=author)]
if author
else None,
metadata={
"Wiki Space Name": self.space,
},

View File

@@ -1,45 +1,71 @@
import json
import os
import re
import zipfile
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import IO
import chardet
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
from danswer.utils.logger import setup_logger
logger = setup_logger()
_METADATA_FLAG = "#DANSWER_METADATA="
def extract_metadata(line: str) -> dict | None:
html_comment_pattern = r"<!--\s*DANSWER_METADATA=\{(.*?)\}\s*-->"
hashtag_pattern = r"#DANSWER_METADATA=\{(.*?)\}"
html_comment_match = re.search(html_comment_pattern, line)
hashtag_match = re.search(hashtag_pattern, line)
if html_comment_match:
json_str = html_comment_match.group(1)
elif hashtag_match:
json_str = hashtag_match.group(1)
else:
return None
try:
return json.loads("{" + json_str + "}")
except json.JSONDecodeError:
return None
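Usage sketch for the two metadata markers handled above (the link and owner values are hypothetical):

extract_metadata('<!-- DANSWER_METADATA={"link": "https://example.com/doc"} -->')
# -> {"link": "https://example.com/doc"}
extract_metadata('#DANSWER_METADATA={"link": "https://example.com/doc", "primary_owners": ["a@example.com"]}')
# -> {"link": "https://example.com/doc", "primary_owners": ["a@example.com"]}
extract_metadata("just an ordinary first line")
# -> None
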
def read_pdf_file(file: IO[Any], file_name: str, pdf_pass: str | None = None) -> str:
pdf_reader = PdfReader(file)
# if marked as encrypted and a password is provided, try to decrypt
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
if pdf_pass is not None:
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error(f"Unable to decrypt pdf {file_name}")
else:
logger.info(f"No Password available to to decrypt pdf {file_name}")
if not decrypt_success:
# By user request, keep files that are unreadable just so they
# can be discoverable by title.
return ""
try:
pdf_reader = PdfReader(file)
# If marked as encrypted and a password is provided, try to decrypt
if pdf_reader.is_encrypted and pdf_pass is not None:
decrypt_success = False
if pdf_pass is not None:
try:
decrypt_success = pdf_reader.decrypt(pdf_pass) != 0
except Exception:
logger.error(f"Unable to decrypt pdf {file_name}")
else:
logger.info(f"No Password available to to decrypt pdf {file_name}")
if not decrypt_success:
# By user request, keep files that are unreadable just so they
# can be discoverable by title.
return ""
return "\n".join(page.extract_text() for page in pdf_reader.pages)
except PdfStreamError:
logger.exception(f"PDF file {file_name} is not a valid PDF")
except Exception:
logger.exception(f"Failed to read PDF {file_name}")
return ""
# File is still discoverable by title
# but the contents are not included as they cannot be parsed
return ""
def is_macos_resource_fork_file(file_name: str) -> bool:
@@ -66,16 +92,33 @@ def load_files_from_zip(
yield file_info, file
def read_file(file_reader: IO[Any]) -> tuple[str, dict[str, Any]]:
def detect_encoding(file_path: str | Path) -> str:
with open(file_path, "rb") as file:
raw_data = file.read(50000) # Read a portion of the file to guess encoding
return chardet.detect(raw_data)["encoding"] or "utf-8"
def read_file(
file_reader: IO[Any], encoding: str = "utf-8", errors: str = "replace"
) -> tuple[str, dict]:
metadata = {}
file_content_raw = ""
for ind, line in enumerate(file_reader):
if isinstance(line, bytes):
line = line.decode("utf-8")
line = str(line)
try:
line = line.decode(encoding) if isinstance(line, bytes) else line
except UnicodeDecodeError:
line = (
line.decode(encoding, errors=errors)
if isinstance(line, bytes)
else line
)
if ind == 0 and line.startswith(_METADATA_FLAG):
metadata = json.loads(line.replace(_METADATA_FLAG, "", 1).strip())
if ind == 0:
metadata_or_none = extract_metadata(line)
if metadata_or_none is not None:
metadata = metadata_or_none
else:
file_content_raw += line
else:
file_content_raw += line
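A hedged usage sketch tying detect_encoding and read_file together, mirroring how the file connector calls them; the path is hypothetical:

from danswer.connectors.cross_connector_utils.file_utils import detect_encoding
from danswer.connectors.cross_connector_utils.file_utils import read_file

path = "docs/notes.md"
encoding = detect_encoding(path)  # chardet guess, falling back to "utf-8"
with open(path, "rb") as f:
    file_content_raw, metadata = read_file(f, encoding=encoding)
# metadata is {} unless the first line carried a DANSWER_METADATA marker
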

View File

@@ -0,0 +1,45 @@
from datetime import datetime
from datetime import timezone
from dateutil.parser import parse
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.text_processing import is_valid_email
def datetime_to_utc(dt: datetime) -> datetime:
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def time_str_to_utc(datetime_str: str) -> datetime:
dt = parse(datetime_str)
return datetime_to_utc(dt)
def basic_expert_info_representation(info: BasicExpertInfo) -> str | None:
if info.first_name and info.last_name:
return f"{info.first_name} {info.middle_initial} {info.last_name}"
if info.display_name:
return info.display_name
if info.email and is_valid_email(info.email):
return info.email
if info.first_name:
return info.first_name
return None
def get_experts_stores_representations(
experts: list[BasicExpertInfo] | None,
) -> list[str] | None:
if not experts:
return None
reps = [basic_expert_info_representation(owner) for owner in experts]
return [owner for owner in reps if owner is not None]
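A hedged usage sketch of the helpers above; the timestamp and people are made up:

from danswer.connectors.models import BasicExpertInfo

time_str_to_utc("2023-12-25T10:00:00-05:00")
# -> 2023-12-25 15:00:00+00:00 (converted to UTC)

owners = [
    BasicExpertInfo(display_name="Jane Doe"),
    BasicExpertInfo(email="not-an-email"),  # fails is_valid_email and has no other fields
]
get_experts_stores_representations(owners)  # -> ["Jane Doe"]
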

View File

@@ -3,16 +3,17 @@ from datetime import timezone
from typing import Any
from urllib.parse import urlparse
from dateutil.parser import parse
from jira import JIRA
from jira.resources import Issue
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
@@ -60,26 +61,32 @@ def fetch_jira_issues_batch(
logger.warning(f"Found Jira object not of type Issue {jira}")
continue
ticket_updated_time = parse(jira.fields.updated)
semantic_rep = (
f"Jira Ticket Summary: {jira.fields.summary}\n"
f"Description: {jira.fields.description}\n"
+ "\n".join(
[f"Comment: {comment.body}" for comment in jira.fields.comment.comments]
)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment.body}" for comment in jira.fields.comment.comments]
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
author = None
try:
author = BasicExpertInfo(
display_name=jira.fields.creator.displayName,
email=jira.fields.creator.emailAddress,
)
except Exception:
# Author should exist but if not, doesn't matter
pass
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=semantic_rep)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=ticket_updated_time.astimezone(timezone.utc),
metadata={},
doc_updated_at=time_str_to_utc(jira.fields.updated),
primary_owners=[author] if author is not None else None,
# TODO add secondary_owners if needed
metadata={"label": jira.fields.labels} if jira.fields.labels else {},
)
)
return doc_batch, len(batch)

View File

@@ -140,11 +140,7 @@ class Document360Connector(LoadConnector, PollConnector):
html_content = article_details["html_content"]
article_content = parse_html_page_basic(html_content)
doc_text = (
f"workspace: {self.workspace}\n"
f"category: {article['category_name']}\n"
f"article: {article_details['title']} - "
f"{article_details.get('description', '')}\n"
f"{article_content}"
f"{article_details.get('description', '')}\n{article_content}".strip()
)
document = Document(
@@ -154,7 +150,10 @@ class Document360Connector(LoadConnector, PollConnector):
semantic_identifier=article_details["title"],
doc_updated_at=updated_at,
primary_owners=authors,
metadata={},
metadata={
"workspace": self.workspace,
"category": article["category_name"],
},
)
doc_batch.append(document)
@@ -190,8 +189,8 @@ if __name__ == "__main__":
)
current = time.time()
one_day_ago = current - 24 * 60 * 60 * 360 # 1 year
latest_docs = document360_connector.poll_source(one_day_ago, current)
one_year_ago = current - 24 * 60 * 60 * 360
latest_docs = document360_connector.poll_source(one_year_ago, current)
for doc in latest_docs:
print(doc)

View File

@@ -21,6 +21,7 @@ from danswer.connectors.linear.connector import LinearConnector
from danswer.connectors.models import InputType
from danswer.connectors.notion.connector import NotionConnector
from danswer.connectors.productboard.connector import ProductboardConnector
from danswer.connectors.requesttracker.connector import RequestTrackerConnector
from danswer.connectors.slab.connector import SlabConnector
from danswer.connectors.slack.connector import SlackLoadConnector
from danswer.connectors.slack.connector import SlackPollConnector
@@ -53,6 +54,7 @@ def identify_connector_class(
DocumentSource.SLAB: SlabConnector,
DocumentSource.NOTION: NotionConnector,
DocumentSource.ZULIP: ZulipConnector,
DocumentSource.REQUESTTRACKER: RequestTrackerConnector,
DocumentSource.GURU: GuruConnector,
DocumentSource.LINEAR: LinearConnector,
DocumentSource.HUBSPOT: HubSpotConnector,

View File

@@ -8,9 +8,11 @@ from typing import IO
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.file_utils import detect_encoding
from danswer.connectors.cross_connector_utils.file_utils import load_files_from_zip
from danswer.connectors.cross_connector_utils.file_utils import read_file
from danswer.connectors.cross_connector_utils.file_utils import read_pdf_file
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.file.utils import check_file_ext_is_valid
from danswer.connectors.file.utils import get_file_ext
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -31,11 +33,12 @@ def _open_files_at_location(
if extension == ".zip":
for file_info, file in load_files_from_zip(file_path, ignore_dirs=True):
yield file_info.filename, file
elif extension == ".txt" or extension == ".pdf":
mode = "r"
if extension == ".pdf":
mode = "rb"
with open(file_path, mode) as file:
elif extension in [".txt", ".md", ".mdx"]:
encoding = detect_encoding(file_path)
with open(file_path, "r", encoding=encoding, errors="replace") as file:
yield os.path.basename(file_path), file
elif extension == ".pdf":
with open(file_path, "rb") as file:
yield os.path.basename(file_path), file
else:
logger.warning(f"Skipping file '{file_path}' with extension '{extension}'")
@@ -61,13 +64,20 @@ def _process_file(
else:
file_content_raw, metadata = read_file(file)
dt_str = metadata.get("doc_updated_at")
final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated
return [
Document(
id=file_name,
sections=[Section(link=metadata.get("link", ""), text=file_content_raw)],
sections=[
Section(link=metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
semantic_identifier=file_name,
doc_updated_at=time_updated,
doc_updated_at=final_time_updated,
primary_owners=metadata.get("primary_owners"),
secondary_owners=metadata.get("secondary_owners"),
metadata={},
)
]

View File

@@ -8,7 +8,7 @@ from typing import IO
from danswer.configs.app_configs import FILE_CONNECTOR_TMP_STORAGE_PATH
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf"]
_VALID_FILE_EXTENSIONS = [".txt", ".zip", ".pdf", ".md", ".mdx"]
def get_file_ext(file_path_or_name: str | Path) -> str:

View File

@@ -37,10 +37,9 @@ def _batch_github_objects(
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
full_context = f"Pull-Request {pull_request.title}\n{pull_request.body}"
return Document(
id=pull_request.html_url,
sections=[Section(link=pull_request.html_url, text=full_context)],
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -48,7 +47,7 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
# due to local time discrepancies with UTC
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
metadata={
"merged": pull_request.merged,
"merged": str(pull_request.merged),
"state": pull_request.state,
},
)
@@ -60,10 +59,9 @@ def _fetch_issue_comments(issue: Issue) -> str:
def _convert_issue_to_document(issue: Issue) -> Document:
full_context = f"Issue {issue.title}\n{issue.body}"
return Document(
id=issue.html_url,
sections=[Section(link=issue.html_url, text=full_context)],
sections=[Section(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware

View File

@@ -32,7 +32,6 @@ class GongConnector(LoadConnector, PollConnector):
self,
workspaces: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
use_end_time: bool = False,
continue_on_fail: bool = CONTINUE_ON_CONNECTOR_FAILURE,
hide_user_info: bool = False,
) -> None:
@@ -40,7 +39,6 @@ class GongConnector(LoadConnector, PollConnector):
self.batch_size: int = batch_size
self.continue_on_fail = continue_on_fail
self.auth_token_basic: str | None = None
self.use_end_time = use_end_time
self.hide_user_info = hide_user_info
def _get_auth_header(self) -> dict[str, str]:
@@ -102,7 +100,12 @@ class GongConnector(LoadConnector, PollConnector):
# If no calls in the range, just break out
if response.status_code == 404:
break
response.raise_for_status()
try:
response.raise_for_status()
except Exception:
logger.error(f"Error fetching transcripts: {response.text}")
raise
data = response.json()
call_transcripts = data.get("callTranscripts", [])
@@ -203,9 +206,6 @@ class GongConnector(LoadConnector, PollConnector):
speaker_to_name: dict[str, str] = {}
transcript_text = ""
if call_title:
transcript_text += f"Call Title: {call_title}\n\n"
call_purpose = call_metadata["purpose"]
if call_purpose:
transcript_text += f"Call Description: {call_purpose}\n\n"
@@ -231,6 +231,11 @@ class GongConnector(LoadConnector, PollConnector):
)
transcript_text += f"{speaker_name}: {monolog}\n\n"
metadata = {}
if call_metadata.get("system"):
metadata["client"] = call_metadata.get("system")
# TODO calls have a clientUniqueId field, can pull that in later
doc_batch.append(
Document(
id=call_id,
@@ -243,7 +248,7 @@ class GongConnector(LoadConnector, PollConnector):
doc_updated_at=datetime.fromisoformat(call_time_str).astimezone(
timezone.utc
),
metadata={},
metadata={"client": call_metadata.get("system")},
)
)
yield doc_batch
@@ -263,6 +268,8 @@ class GongConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
# if this env variable is set, don't start from a timestamp before the specified
# start time
# TODO: remove this once this is globally available
@@ -272,6 +279,10 @@ class GongConnector(LoadConnector, PollConnector):
else:
special_start_datetime = datetime.fromtimestamp(0, tz=timezone.utc)
# don't let the special start dt be past the end time; this causes issues with
# the Gong API (`filter.fromDateTime: must be before toDateTime`)
special_start_datetime = min(special_start_datetime, end_datetime)
start_datetime = max(
datetime.fromtimestamp(start, tz=timezone.utc), special_start_datetime
)
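A small, hedged worked example of the clamping above; all of the datetimes are hypothetical:

from datetime import datetime
from datetime import timezone

end_datetime = datetime(2023, 3, 1, tzinfo=timezone.utc)   # poll window end
special_start = datetime(2023, 6, 1, tzinfo=timezone.utc)  # GONG_CONNECTOR_START_TIME
special_start = min(special_start, end_datetime)           # clamped back to 2023-03-01
start_datetime = max(datetime(2023, 1, 1, tzinfo=timezone.utc), special_start)
# -> 2023-03-01, so fromDateTime can never land after toDateTime
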
@@ -280,11 +291,8 @@ class GongConnector(LoadConnector, PollConnector):
# so adding a 1 day buffer and fetching by default till current time
start_one_day_offset = start_datetime - timedelta(days=1)
start_time = start_one_day_offset.isoformat()
end_time = (
datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
if self.use_end_time
else None
)
end_time = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
logger.info(f"Fetching Gong calls between {start_time} and {end_time}")
return self._fetch_calls(start_time, end_time)
@@ -292,7 +300,6 @@ class GongConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
import os
import time
connector = GongConnector()
connector.load_credentials(
@@ -302,6 +309,5 @@ if __name__ == "__main__":
}
)
current = time.time()
latest_docs = connector.load_from_state()
print(next(latest_docs))

View File

@@ -62,7 +62,10 @@ class GDriveMimeType(str, Enum):
GoogleDriveFileType = dict[str, Any]
add_retries = retry_builder()
# Google Drive APIs are quite flaky and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
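The wrapper is applied by closing the flaky call in a lambda, decorating it, and invoking the result, as in this hedged sketch (the helper name is hypothetical; the files().get call mirrors the shortcut-following code below):

def _fetch_file_with_retries(service, file_id: str):  # hypothetical helper, not part of the connector
    return add_retries(
        lambda: service.files()
        .get(fileId=file_id, fields="mimeType, id, name, modifiedTime, webViewLink")
        .execute()
    )()
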
def _run_drive_file_query(
@@ -101,12 +104,18 @@ def _run_drive_file_query(
for file in files:
if follow_shortcuts and "shortcutDetails" in file:
try:
file = service.files().get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, shortcutDetails",
)
file = add_retries(lambda: file.execute())()
file_shortcut_points_to = add_retries(
lambda: (
service.files()
.get(
fileId=file["shortcutDetails"]["targetId"],
supportsAllDrives=include_shared,
fields="mimeType, id, name, modifiedTime, webViewLink, shortcutDetails",
)
.execute()
)
)()
yield file_shortcut_points_to
except HttpError:
logger.error(
f"Failed to follow shortcut with details: {file['shortcutDetails']}"
@@ -114,7 +123,8 @@ def _run_drive_file_query(
if continue_on_failure:
continue
raise
yield file
else:
yield file
def _get_folder_id(
@@ -456,24 +466,20 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
doc_batch = []
for file in files_batch:
try:
text_contents = extract_text(file, service)
if text_contents:
full_context = file["name"] + " - " + text_contents
else:
full_context = file["name"]
text_contents = extract_text(file, service) or ""
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[
Section(link=file["webViewLink"], text=full_context)
Section(link=file["webViewLink"], text=text_contents)
],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=datetime.fromisoformat(
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: True},
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
)
)
except Exception as e:

View File

@@ -25,9 +25,9 @@ from danswer.connectors.google_drive.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.server.models import CredentialBase
from danswer.server.models import GoogleAppCredentials
from danswer.server.models import GoogleServiceAccountKey
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import GoogleAppCredentials
from danswer.server.documents.models import GoogleServiceAccountKey
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -130,7 +130,7 @@ def build_service_account_creds(
return CredentialBase(
credential_json=credential_dict,
is_admin=True,
admin_public=True,
)

View File

@@ -1,5 +1,5 @@
import os
import urllib.parse
import re
from typing import Any
from typing import cast
@@ -20,42 +20,31 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def process_link(element: BeautifulSoup | Tag) -> str:
href = cast(str | None, element.get("href"))
if not href:
raise RuntimeError(f"Invalid link - {element}")
def a_tag_text_to_path(atag: Tag) -> str:
page_path = atag.text.strip().lower()
page_path = re.sub(r"[^a-zA-Z0-9\s]", "", page_path)
page_path = "-".join(page_path.split())
# cleanup href
href = urllib.parse.unquote(href)
href = href.rstrip(".html").lower()
href = href.replace("_", "")
href = href.replace(" ", "-")
return href
return page_path
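A hedged example of the slug this produces (the anchor text is made up):

from bs4 import BeautifulSoup

atag = BeautifulSoup('<a href="#">Getting Started (v2)!</a>', "html.parser").a
a_tag_text_to_path(atag)  # -> "getting-started-v2"
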
def find_google_sites_page_path_from_navbar(
element: BeautifulSoup | Tag, path: str, is_initial: bool
element: BeautifulSoup | Tag, path: str, depth: int
) -> str | None:
ul = cast(Tag | None, element.find("ul"))
if ul:
if not is_initial:
a = cast(Tag, element.find("a"))
new_path = f"{path}/{process_link(a)}"
if a.get("aria-selected") == "true":
return new_path
else:
new_path = ""
for li in ul.find_all("li", recursive=False):
found_link = find_google_sites_page_path_from_navbar(li, new_path, False)
if found_link:
return found_link
else:
a = cast(Tag, element.find("a"))
if a:
href = process_link(a)
if href and a.get("aria-selected") == "true":
return path + "/" + href
lis = cast(
list[Tag],
element.find_all("li", attrs={"data-nav-level": f"{depth}"}),
)
for li in lis:
a = cast(Tag, li.find("a"))
if a.get("aria-selected") == "true":
return f"{path}/{a_tag_text_to_path(a)}"
elif a.get("aria-expanded") == "true":
sub_path = find_google_sites_page_path_from_navbar(
element, f"{path}/{a_tag_text_to_path(a)}", depth + 1
)
if sub_path:
return sub_path
return None
@@ -79,6 +68,7 @@ class GoogleSitesConnector(LoadConnector):
# load the HTML files
files = load_files_from_zip(self.zip_path)
count = 0
for file_info, file_io in files:
# skip non-published files
if "/PUBLISHED/" not in file_info.filename:
@@ -94,13 +84,15 @@ class GoogleSitesConnector(LoadConnector):
# get the link out of the navbar
header = cast(Tag, soup.find("header"))
nav = cast(Tag, header.find("nav"))
path = find_google_sites_page_path_from_navbar(nav, "", True)
path = find_google_sites_page_path_from_navbar(nav, "", 1)
if not path:
count += 1
logger.error(
f"Could not find path for '{file_info.filename}'. "
+ "This page will not have a working link."
+ "This page will not have a working link.\n\n"
+ f"# of broken links so far - {count}"
)
logger.info(f"Path to page: {path}")
# cleanup the hidden `Skip to main content` and `Skip to navigation` that
# appears at the top of every page
for div in soup.find_all("div", attrs={"data-is-touch-wrapper": "true"}):

View File

@@ -8,6 +8,7 @@ import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -76,14 +77,26 @@ class GuruConnector(LoadConnector, PollConnector):
for card in cards:
title = card["preferredPhrase"]
link = GURU_CARDS_URL + card["slug"]
content_text = title + "\n" + parse_html_page_basic(card["content"])
content_text = parse_html_page_basic(card["content"])
last_updated = time_str_to_utc(card["lastModified"])
last_verified = (
time_str_to_utc(card.get("lastVerified"))
if card.get("lastVerified")
else None
)
# For Danswer, we decay document score over time, either last_updated or
# last_verified is a good enough signal for the document's recency
latest_time = (
max(last_verified, last_updated) if last_verified else last_updated
)
doc_batch.append(
Document(
id=card["id"],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.GURU,
semantic_identifier=title,
doc_updated_at=latest_time,
metadata={},
)
)
@@ -109,3 +122,18 @@ class GuruConnector(LoadConnector, PollConnector):
end_time = unixtime_to_guru_time_str(end)
return self._process_cards(start_time, end_time)
if __name__ == "__main__":
import os
connector = GuruConnector()
connector.load_credentials(
{
"guru_user": os.environ["GURU_USER"],
"guru_user_token": os.environ["GURU_USER_TOKEN"],
}
)
latest_docs = connector.load_from_state()
print(next(latest_docs))

View File

@@ -73,7 +73,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
title = ticket.properties["subject"]
link = self.ticket_base_url + ticket.id
content_text = title + "\n" + ticket.properties["content"]
content_text = ticket.properties["content"]
associated_emails: list[str] = []
associated_notes: list[str] = []

View File

@@ -8,6 +8,7 @@ import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -30,7 +31,6 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
"Content-Type": "application/json",
}
response: requests.Response | None = None
for i in range(_NUM_RETRIES):
try:
response = requests.post(
@@ -187,8 +187,8 @@ class LinearConnector(LoadConnector, PollConnector):
],
source=DocumentSource.LINEAR,
semantic_identifier=node["identifier"],
doc_updated_at=time_str_to_utc(node["updatedAt"]),
metadata={
"updated_at": node["updatedAt"],
"team": node["team"]["name"],
},
)

View File

@@ -1,9 +1,17 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import INDEX_SEPARATOR
from danswer.utils.text_processing import make_url_compatible
class InputType(str, Enum):
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
POLL = "poll" # e.g. calling an API to get all documents in the last hour
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
class ConnectorMissingCredentialError(PermissionError):
@@ -14,44 +22,93 @@ class ConnectorMissingCredentialError(PermissionError):
)
@dataclass
class Section:
link: str
class Section(BaseModel):
text: str
link: str | None
@dataclass
class Document:
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
class BasicExpertInfo(BaseModel):
"""Basic Information for the owner of a document, any of the fields can be left as None
Display fallback goes as follows:
- first_name + (optional middle_initial) + last_name
- display_name
- email
- first_name
"""
display_name: str | None = None
first_name: str | None = None
middle_initial: str | None = None
last_name: str | None = None
email: str | None = None
class DocumentBase(BaseModel):
"""Used for Danswer ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: list[Section]
source: DocumentSource
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
metadata: dict[str, Any]
metadata: dict[str, str | list[str]]
# UTC time
doc_updated_at: datetime | None = None
# Owner, creator, etc.
primary_owners: list[str] | None = None
primary_owners: list[BasicExpertInfo] | None = None
# Assignee, space owner, etc.
secondary_owners: list[str] | None = None
# `title` is used when computing best matches for a query
# if `None`, then we will use the `semantic_identifier` as the title in Vespa
secondary_owners: list[BasicExpertInfo] | None = None
# title is used for search whereas semantic_identifier is used for displaying in the UI
# different because Slack message may display as #general but general should not be part
# of the search, at least not in the same way as a document title would be for, say, Confluence
# The default title is semantic_identifier though unless otherwise specified
title: str | None = None
from_ingestion_api: bool = False
def get_title_for_document_index(self) -> str:
def get_title_for_document_index(self) -> str | None:
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None
return self.semantic_identifier if self.title is None else self.title
def get_metadata_str_attributes(self) -> list[str] | None:
if not self.metadata:
return None
# Combined string for the key/value for easy filtering
attributes: list[str] = []
for k, v in self.metadata.items():
if isinstance(v, list):
attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
else:
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
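A hedged illustration of the flattening above, using INDEX_SEPARATOR == "===" from constants.py; the metadata values are hypothetical:

doc = DocumentBase(
    sections=[Section(link=None, text="example body")],
    semantic_identifier="Example Doc",
    metadata={"type": "page", "label": ["bug", "backend"]},
)
doc.get_metadata_str_attributes()
# -> ["type===page", "label===bug", "label===backend"]
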
class Document(DocumentBase):
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
class InputType(str, Enum):
LOAD_STATE = "load_state" # e.g. loading a current full state or a save state, such as from a file
POLL = "poll" # e.g. calling an API to get all documents in the last hour
EVENT = "event" # e.g. registered an endpoint as a listener, and processing connector events
@classmethod
def from_base(cls, base: DocumentBase) -> "Document":
return cls(
id=make_url_compatible(base.id)
if base.id
else "ingestion_api_" + make_url_compatible(base.semantic_identifier),
sections=base.sections,
source=base.source or DocumentSource.INGESTION_API,
semantic_identifier=base.semantic_identifier,
metadata=base.metadata,
doc_updated_at=base.doc_updated_at,
primary_owners=base.primary_owners,
secondary_owners=base.secondary_owners,
title=base.title,
from_ingestion_api=base.from_ingestion_api,
)
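A hedged sketch of the ingestion-API path above: a DocumentBase with no id or source is given a URL-safe id derived from its semantic_identifier and the INGESTION_API source; the document itself is made up:

base = DocumentBase(
    sections=[Section(link=None, text="example body")],
    semantic_identifier="Q1 Planning Notes",
    metadata={},
    from_ingestion_api=True,
)
doc = Document.from_base(base)
# doc.id     -> "ingestion_api_" + make_url_compatible("Q1 Planning Notes")
# doc.source -> DocumentSource.INGESTION_API
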
@dataclass
class IndexAttemptMetadata:
class IndexAttemptMetadata(BaseModel):
connector_id: int
credential_id: int

View File

@@ -24,6 +24,8 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@dataclass
class NotionPage:
@@ -80,6 +82,7 @@ class NotionConnector(LoadConnector, PollConnector):
"Notion-Version": "2022-06-28",
}
self.indexed_pages: set[str] = set()
self.root_page_id = root_page_id
# if enabled, will recursively index child pages as they are found rather than
# relying entirely on the `search` API. We have received reports that the
# `search` API misses many pages - in those cases, this might need to be
@@ -87,8 +90,9 @@ class NotionConnector(LoadConnector, PollConnector):
# NOTE: this also removes all benefits of polling, since we need to traverse
# all pages regardless of if they are updated. If the notion workspace is
# very large, this may not be practical.
self.recursive_index_enabled = recursive_index_enabled
self.root_page_id = root_page_id
self.recursive_index_enabled = (
recursive_index_enabled or self.root_page_id is not None
)
@retry(tries=3, delay=1, backoff=2)
def _fetch_blocks(self, block_id: str, cursor: str | None = None) -> dict[str, Any]:
@@ -96,7 +100,12 @@ class NotionConnector(LoadConnector, PollConnector):
logger.debug(f"Fetching children of block with ID '{block_id}'")
block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
query_params = None if not cursor else {"start_cursor": cursor}
res = requests.get(block_url, headers=self.headers, params=query_params)
res = requests.get(
block_url,
headers=self.headers,
params=query_params,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
@@ -109,7 +118,11 @@ class NotionConnector(LoadConnector, PollConnector):
"""Fetch a page from it's ID via the Notion API."""
logger.debug(f"Fetching page for ID '{page_id}'")
block_url = f"https://api.notion.com/v1/pages/{page_id}"
res = requests.get(block_url, headers=self.headers)
res = requests.get(
block_url,
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
@@ -117,6 +130,64 @@ class NotionConnector(LoadConnector, PollConnector):
raise e
return NotionPage(**res.json())
@retry(tries=3, delay=1, backoff=2)
def _fetch_database(
self, database_id: str, cursor: str | None = None
) -> dict[str, Any]:
"""Fetch a database from it's ID via the Notion API."""
logger.debug(f"Fetching database for ID '{database_id}'")
block_url = f"https://api.notion.com/v1/databases/{database_id}/query"
body = None if not cursor else {"start_cursor": cursor}
res = requests.post(
block_url,
headers=self.headers,
json=body,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
if res.json().get("code") == "object_not_found":
# this happens when a database is not shared with the integration
# in this case, we should just ignore the database
logger.error(
f"Unable to access database with ID '{database_id}'. "
f"This is likely due to the database not being shared "
f"with the Danswer integration. Exact exception:\n{e}"
)
return {"results": [], "next_cursor": None}
logger.exception(f"Error fetching database - {res.json()}")
raise e
return res.json()
def _read_pages_from_database(self, database_id: str) -> list[str]:
"""Returns a list of all page IDs in the database"""
result_pages: list[str] = []
cursor = None
while True:
data = self._fetch_database(database_id, cursor)
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
)
result_pages.append(result["id"])
elif obj_type == "database":
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
result_pages.extend(self._read_pages_from_database(obj_id))
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_pages
def _read_blocks(
self, page_block_id: str
) -> tuple[list[tuple[str, str]], list[str]]:
@@ -141,8 +212,20 @@ class NotionConnector(LoadConnector, PollConnector):
text = rich_text["text"]["content"]
cur_result_text_arr.append(text)
if result["has_children"] and result_type == "child_page":
child_pages.append(result_block_id)
if result["has_children"]:
if result_type == "child_page":
child_pages.append(result_block_id)
else:
logger.debug(f"Entering sub-block: {result_block_id}")
subblock_result_lines, subblock_child_pages = self._read_blocks(
result_block_id
)
logger.debug(f"Finished sub-block: {result_block_id}")
result_lines.extend(subblock_result_lines)
child_pages.extend(subblock_child_pages)
if result_type == "child_database" and self.recursive_index_enabled:
child_pages.extend(self._read_pages_from_database(result_block_id))
cur_result_text = "\n".join(cur_result_text_arr)
if cur_result_text:
@@ -184,7 +267,8 @@ class NotionConnector(LoadConnector, PollConnector):
yield (
Document(
id=page.id,
sections=[Section(link=page.url, text=f"{page_title}\n")]
# Will add title to the first section later in processing
sections=[Section(link=page.url, text="")]
+ [
Section(
link=f"{page.url}#{block_id.replace('-', '')}",
@@ -221,6 +305,7 @@ class NotionConnector(LoadConnector, PollConnector):
"https://api.notion.com/v1/search",
headers=self.headers,
json=query_dict,
timeout=_NOTION_CALL_TIMEOUT,
)
res.raise_for_status()
return NotionSearchResponse(**res.json())
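The changes above add a hard timeout to every Notion API call and paginate database queries with a cursor. A minimal, self-contained sketch of that pattern, using the same `requests` and `retry` libraries the connector already imports (the endpoint URL, token handling, and function names here are illustrative, not taken from the repo):

import requests
from retry import retry

_CALL_TIMEOUT = 30  # seconds; mirrors _NOTION_CALL_TIMEOUT above


@retry(tries=3, delay=1, backoff=2)  # up to 3 attempts, waiting 1s then 2s between them
def query_database(token: str, database_id: str, cursor: str | None = None) -> dict:
    # Fetch one page of results; Notion keeps returning `next_cursor` until exhausted.
    res = requests.post(
        f"https://api.notion.com/v1/databases/{database_id}/query",
        headers={"Authorization": f"Bearer {token}", "Notion-Version": "2022-06-28"},
        json=None if cursor is None else {"start_cursor": cursor},
        timeout=_CALL_TIMEOUT,  # never hang indefinitely on a stuck connection
    )
    res.raise_for_status()
    return res.json()


def iter_all_results(token: str, database_id: str):
    # Cursor loop equivalent to _read_pages_from_database above.
    cursor = None
    while True:
        data = query_database(token, database_id, cursor)
        yield from data["results"]
        if data["next_cursor"] is None:
            break
        cursor = data["next_cursor"]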

View File

@@ -10,9 +10,11 @@ from retry import retry
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
@@ -93,26 +95,24 @@ class ProductboardConnector(PollConnector):
for feature in self._fetch_documents(
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/features"
):
owner = self._get_owner_email(feature)
experts = [BasicExpertInfo(email=owner)] if owner else None
yield Document(
id=feature["id"],
sections=[
Section(
link=feature["links"]["html"],
text=" - ".join(
(
feature["name"],
self._parse_description_html(feature["description"]),
)
),
text=self._parse_description_html(feature["description"]),
)
],
semantic_identifier=feature["name"],
source=DocumentSource.PRODUCTBOARD,
doc_updated_at=time_str_to_utc(feature["updatedAt"]),
primary_owners=experts,
metadata={
"productboard_entity_type": feature["type"],
"entity_type": feature["type"],
"status": feature["status"]["name"],
"owner": self._get_owner_email(feature),
"updated_at": feature["updatedAt"],
},
)
@@ -121,25 +121,23 @@ class ProductboardConnector(PollConnector):
for component in self._fetch_documents(
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/components"
):
owner = self._get_owner_email(component)
experts = [BasicExpertInfo(email=owner)] if owner else None
yield Document(
id=component["id"],
sections=[
Section(
link=component["links"]["html"],
text=" - ".join(
(
component["name"],
self._parse_description_html(component["description"]),
)
),
text=self._parse_description_html(component["description"]),
)
],
semantic_identifier=component["name"],
source=DocumentSource.PRODUCTBOARD,
doc_updated_at=time_str_to_utc(component["updatedAt"]),
primary_owners=experts,
metadata={
"productboard_entity_type": "component",
"owner": self._get_owner_email(component),
"updated_at": component["updatedAt"],
"entity_type": "component",
},
)
@@ -149,25 +147,23 @@ class ProductboardConnector(PollConnector):
for product in self._fetch_documents(
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/products"
):
owner = self._get_owner_email(product)
experts = [BasicExpertInfo(email=owner)] if owner else None
yield Document(
id=product["id"],
sections=[
Section(
link=product["links"]["html"],
text=" - ".join(
(
product["name"],
self._parse_description_html(product["description"]),
)
),
text=self._parse_description_html(product["description"]),
)
],
semantic_identifier=product["name"],
source=DocumentSource.PRODUCTBOARD,
doc_updated_at=time_str_to_utc(product["updatedAt"]),
primary_owners=experts,
metadata={
"productboard_entity_type": "product",
"owner": self._get_owner_email(product),
"updated_at": product["updatedAt"],
"entity_type": "product",
},
)
@@ -175,26 +171,24 @@ class ProductboardConnector(PollConnector):
for objective in self._fetch_documents(
initial_link=f"{_PRODUCT_BOARD_BASE_URL}/objectives"
):
owner = self._get_owner_email(objective)
experts = [BasicExpertInfo(email=owner)] if owner else None
yield Document(
id=objective["id"],
sections=[
Section(
link=objective["links"]["html"],
text=" - ".join(
(
objective["name"],
self._parse_description_html(objective["description"]),
)
),
text=self._parse_description_html(objective["description"]),
)
],
semantic_identifier=objective["name"],
source=DocumentSource.PRODUCTBOARD,
doc_updated_at=time_str_to_utc(objective["updatedAt"]),
primary_owners=experts,
metadata={
"productboard_entity_type": "release",
"entity_type": "release",
"state": objective["state"],
"owner": self._get_owner_email(objective),
"updated_at": objective["updatedAt"],
},
)
@@ -252,3 +246,20 @@ class ProductboardConnector(PollConnector):
if document_batch:
yield document_batch
if __name__ == "__main__":
import os
import time
connector = ProductboardConnector()
connector.load_credentials(
{
"productboard_access_token": os.environ["PRODUCTBOARD_ACCESS_TOKEN"],
}
)
current = time.time()
one_year_ago = current - 24 * 60 * 60 * 360
latest_docs = connector.poll_source(one_year_ago, current)
print(next(latest_docs))

View File

@@ -0,0 +1 @@
.env

View File

@@ -0,0 +1,153 @@
from datetime import datetime
from datetime import timezone
from logging import DEBUG as LOG_LVL_DEBUG
from typing import Any
from typing import List
from typing import Optional
from rt.rest1 import ALL_QUEUES
from rt.rest1 import Rt
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
class RequestTrackerError(Exception):
pass
class RequestTrackerConnector(PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.batch_size = batch_size
def txn_link(self, tid: int, txn: int) -> str:
return f"{self.rt_base_url}/Ticket/Display.html?id={tid}&txn={txn}"
def build_doc_sections_from_txn(
self, connection: Rt, ticket_id: int
) -> List[Section]:
Sections: List[Section] = []
get_history_resp = connection.get_history(ticket_id)
if get_history_resp is None:
raise RequestTrackerError(f"Ticket {ticket_id} cannot be found")
for tx in get_history_resp:
Sections.append(
Section(
link=self.txn_link(ticket_id, int(tx["id"])),
text="\n".join(
[
f"{k}:\n{v}\n" if k != "Attachments" else ""
for (k, v) in tx.items()
]
),
)
)
return Sections
def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]:
self.rt_username = credentials.get("requesttracker_username")
self.rt_password = credentials.get("requesttracker_password")
self.rt_base_url = credentials.get("requesttracker_base_url")
return None
# This does not include RT file attachments yet.
def _process_tickets(
self, start: datetime, end: datetime
) -> GenerateDocumentsOutput:
if not all([self.rt_username, self.rt_password, self.rt_base_url]):
raise ConnectorMissingCredentialError("requesttracker")
Rt0 = Rt(
f"{self.rt_base_url}/REST/1.0/",
self.rt_username,
self.rt_password,
)
Rt0.login()
d0 = start.strftime("%Y-%m-%d %H:%M:%S")
d1 = end.strftime("%Y-%m-%d %H:%M:%S")
tickets = Rt0.search(
Queue=ALL_QUEUES,
raw_query=f"Updated > '{d0}' AND Updated < '{d1}'",
)
doc_batch: List[Document] = []
for ticket in tickets:
ticket_keys_to_omit = ["id", "Subject"]
tid: int = int(ticket["numerical_id"])
ticketLink: str = f"{self.rt_base_url}/Ticket/Display.html?id={tid}"
logger.info(f"Processing ticket {tid}")
doc = Document(
id=ticket["id"],
# Will add title to the first section later in processing
sections=[Section(link=ticketLink, text="")]
+ self.build_doc_sections_from_txn(Rt0, tid),
source=DocumentSource.REQUESTTRACKER,
semantic_identifier=ticket["Subject"],
metadata={
key: value
for key, value in ticket.items()
if key not in ticket_keys_to_omit
},
)
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
# Keep query short, only look behind 1 day at maximum
one_day_ago: float = end - (24 * 60 * 60)
_start: float = start if start > one_day_ago else one_day_ago
start_datetime = datetime.fromtimestamp(_start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_tickets(start_datetime, end_datetime)
if __name__ == "__main__":
import time
import os
from dotenv import load_dotenv
load_dotenv()
logger.setLevel(LOG_LVL_DEBUG)
rt_connector = RequestTrackerConnector()
rt_connector.load_credentials(
{
"requesttracker_username": os.getenv("RT_USERNAME"),
"requesttracker_password": os.getenv("RT_PASSWORD"),
"requesttracker_base_url": os.getenv("RT_BASE_URL"),
}
)
current = time.time()
one_day_ago = current - (24 * 60 * 60) # 1 day
latest_docs = rt_connector.poll_source(one_day_ago, current)
for doc in latest_docs:
print(doc)
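The poll_source method above deliberately narrows the requested window so Request Tracker is never queried for more than one day of history. A standalone sketch of just that clamping logic (the timestamps are illustrative):

from datetime import datetime, timezone


def clamp_poll_window(start: float, end: float) -> tuple[datetime, datetime]:
    # Look back at most 24 hours from `end`, even if `start` is older than that.
    one_day_ago = end - 24 * 60 * 60
    effective_start = start if start > one_day_ago else one_day_ago
    return (
        datetime.fromtimestamp(effective_start, tz=timezone.utc),
        datetime.fromtimestamp(end, tz=timezone.utc),
    )


# A 7-day-old start gets clamped to exactly 1 day before `end`.
start_dt, end_dt = clamp_poll_window(1_700_000_000 - 7 * 86_400, 1_700_000_000)
print(start_dt, end_dt)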

View File

@@ -5,14 +5,35 @@ from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.html_utils import parse_html_page_basic
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
def _article_to_document(article: Article) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
return Document(
id=f"article:{article.id}",
sections=[
Section(link=article.html_url, text=parse_html_page_basic(article.body))
],
source=DocumentSource.ZENDESK,
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author],
metadata={"type": "article"},
)
class ZendeskClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Zendesk Client is not set up, was load_credentials called?")
@@ -34,18 +55,6 @@ class ZendeskConnector(LoadConnector, PollConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def _article_to_document(self, article: Article) -> Document:
return Document(
id=f"article:{article.id}",
sections=[Section(link=article.html_url, text=article.body)],
source=DocumentSource.ZENDESK,
semantic_identifier="Article: " + article.title,
metadata={
"type": "article",
"updated_at": article.updated_at,
},
)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
@@ -64,7 +73,10 @@ class ZendeskConnector(LoadConnector, PollConnector):
if article.body is None:
continue
doc_batch.append(self._article_to_document(article))
doc_batch.append(_article_to_document(article))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
if doc_batch:
yield doc_batch
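The polling loop above accumulates converted articles and flushes a batch every batch_size documents. A generic sketch of that batching-generator pattern (the names and the string item type are illustrative; the real code batches Document objects):

from collections.abc import Iterable, Iterator


def batch_items(items: Iterable[str], batch_size: int = 3) -> Iterator[list[str]]:
    batch: list[str] = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []  # start a fresh list instead of mutating the one just yielded
    if batch:  # flush the final, possibly short batch
        yield batch


# 7 items with batch_size=3 -> batches of 3, 3 and 1
print(list(batch_items([f"article-{i}" for i in range(7)])))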

View File

@@ -1,33 +1,37 @@
from datetime import datetime
import pytz
import timeago # type: ignore
from slack_sdk.models.blocks import ActionsBlock
from slack_sdk.models.blocks import Block
from slack_sdk.models.blocks import ButtonElement
from slack_sdk.models.blocks import ConfirmObject
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import HeaderBlock
from slack_sdk.models.blocks import Option
from slack_sdk.models.blocks import RadioButtonsElement
from slack_sdk.models.blocks import SectionBlock
from danswer.chat.models import DanswerQuote
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.configs.danswerbot_configs import ENABLE_SLACK_DOC_FEEDBACK
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.utils import build_feedback_block_id
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.server.models import SearchDoc
from danswer.search.models import SavedSearchDoc
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
_MAX_BLURB_LEN = 75
def build_qa_feedback_block(query_event_id: int) -> Block:
def build_qa_feedback_block(message_id: int) -> Block:
return ActionsBlock(
block_id=build_feedback_block_id(query_event_id),
block_id=build_feedback_id(message_id),
elements=[
ButtonElement(
action_id=LIKE_BLOCK_ACTION_ID,
@@ -43,33 +47,44 @@ def build_qa_feedback_block(query_event_id: int) -> Block:
)
def get_document_feedback_blocks() -> Block:
return SectionBlock(
text=(
"- 'Up-Boost' if this document is a good source of information and should be "
"shown more often.\n"
"- 'Down-boost' if this document is a poor source of information and should be "
"shown less often.\n"
"- 'Hide' if this document is deprecated and should never be shown anymore."
),
accessory=RadioButtonsElement(
options=[
Option(
text=":thumbsup: Up-Boost",
value=SearchFeedbackType.ENDORSE.value,
),
Option(
text=":thumbsdown: Down-Boost",
value=SearchFeedbackType.REJECT.value,
),
Option(
text=":x: Hide",
value=SearchFeedbackType.HIDE.value,
),
]
),
)
def build_doc_feedback_block(
query_event_id: int,
message_id: int,
document_id: str,
document_rank: int,
) -> Block:
return ActionsBlock(
block_id=build_feedback_block_id(query_event_id, document_id, document_rank),
elements=[
ButtonElement(
action_id=SearchFeedbackType.ENDORSE.value,
text="",
style="primary",
confirm=ConfirmObject(
title="Endorse this Document",
text="This is a good source of information and should be shown more often!",
),
),
ButtonElement(
action_id=SearchFeedbackType.REJECT.value,
text="",
style="danger",
confirm=ConfirmObject(
title="Reject this Document",
text="This is a bad source of information and should be shown less often.",
),
),
],
) -> ButtonElement:
feedback_id = build_feedback_id(message_id, document_id, document_rank)
return ButtonElement(
action_id=FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID,
value=feedback_id,
text="Give Feedback",
)
@@ -77,7 +92,7 @@ def get_restate_blocks(
msg: str,
is_bot_msg: bool,
) -> list[Block]:
# Only the slash command needs this context because the user doesnt see their own input
# Only the slash command needs this context because the user doesn't see their own input
if not is_bot_msg:
return []
@@ -88,13 +103,15 @@ def get_restate_blocks(
def build_documents_blocks(
documents: list[SearchDoc],
query_event_id: int,
documents: list[SavedSearchDoc],
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
include_feedback: bool = ENABLE_SLACK_DOC_FEEDBACK,
) -> list[Block]:
header_text = (
"Retrieved Documents" if DISABLE_GENERATIVE_AI else "Reference Documents"
)
seen_docs_identifiers = set()
section_blocks: list[Block] = [HeaderBlock(text="Reference Documents")]
section_blocks: list[Block] = [HeaderBlock(text=header_text)]
included_docs = 0
for rank, d in enumerate(documents):
if d.document_id in seen_docs_identifiers:
@@ -110,24 +127,32 @@ def build_documents_blocks(
included_docs += 1
header_line = f"{doc_sem_id}\n"
if d.link:
block_text = f"<{d.link}|{doc_sem_id}>:\n>{remove_slack_text_interactions(match_str)}"
else:
block_text = f"{doc_sem_id}:\n>{remove_slack_text_interactions(match_str)}"
header_line = f"<{d.link}|{doc_sem_id}>\n"
updated_at_line = ""
if d.updated_at is not None:
updated_at_line = (
f"_Updated {timeago.format(d.updated_at, datetime.now(pytz.utc))}_\n"
)
body_text = f">{remove_slack_text_interactions(match_str)}"
block_text = header_line + updated_at_line + body_text
feedback: ButtonElement | dict = {}
if message_id is not None:
feedback = build_doc_feedback_block(
message_id=message_id,
document_id=d.document_id,
document_rank=rank,
)
section_blocks.append(
SectionBlock(text=block_text),
SectionBlock(text=block_text, accessory=feedback),
)
if include_feedback:
section_blocks.append(
build_doc_feedback_block(
query_event_id=query_event_id,
document_id=d.document_id,
document_rank=rank,
),
)
section_blocks.append(DividerBlock())
if included_docs >= num_docs_to_display:
@@ -179,18 +204,29 @@ def build_quotes_block(
def build_qa_response_blocks(
query_event_id: int,
message_id: int | None,
answer: str | None,
quotes: list[DanswerQuote] | None,
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
skip_quotes: bool = False,
) -> list[Block]:
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
ai_answer_header = HeaderBlock(text="AI Answer")
filter_block: Block | None = None
if time_cutoff or favor_recent:
if time_cutoff or favor_recent or source_filters:
filter_text = "Filters: "
if source_filters:
sources_str = ", ".join([s.value for s in source_filters])
filter_text += f"`Sources in [{sources_str}]`"
if time_cutoff or favor_recent:
filter_text += " and "
if time_cutoff is not None:
time_str = time_cutoff.strftime("%b %d, %Y")
filter_text += f"`Docs Updated >= {time_str}` "
@@ -206,7 +242,8 @@ def build_qa_response_blocks(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
else:
answer_block = SectionBlock(text=remove_slack_text_interactions(answer))
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
answer_block = SectionBlock(text=answer_processed)
if quotes:
quotes_blocks = build_quotes_block(quotes)
@@ -218,15 +255,22 @@ def build_qa_response_blocks(
)
]
feedback_block = build_qa_feedback_block(query_event_id=query_event_id)
feedback_block = None
if message_id is not None:
feedback_block = build_qa_feedback_block(message_id=message_id)
response_blocks: list[Block] = [ai_answer_header]
if filter_block is not None:
response_blocks.append(filter_block)
response_blocks.extend(
[answer_block, feedback_block] + quotes_blocks + [DividerBlock()]
)
response_blocks.append(answer_block)
if feedback_block is not None:
response_blocks.append(feedback_block)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
response_blocks.append(DividerBlock())
return response_blocks
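These block builders return slack_sdk Block objects that are ultimately posted back into the originating thread. A minimal sketch of posting such blocks with the slack_sdk WebClient (the token env var, channel, and thread_ts are placeholders; the repo routes this through its own respond_in_thread helper):

import os

from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock, HeaderBlock, SectionBlock

client = WebClient(token=os.environ["SLACK_BOT_TOKEN"])  # placeholder env var

blocks = [
    HeaderBlock(text="AI Answer"),
    SectionBlock(text="Example answer text"),
    DividerBlock(),
]

client.chat_postMessage(
    channel="C0123456789",          # placeholder channel ID
    thread_ts="1700000000.000100",  # reply inside the original thread
    text="AI Answer",               # plain-text fallback for notifications
    blocks=[b.to_dict() for b in blocks],
)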

View File

@@ -1,3 +1,5 @@
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
SLACK_CHANNEL_ID = "channel_id"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"

View File

@@ -1,19 +1,60 @@
from slack_sdk import WebClient
from slack_sdk.models.views import View
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.danswerbot.slack.blocks import get_document_feedback_blocks
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.utils import decompose_block_id
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_feedback_id
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_chat_message_feedback
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import update_query_event_feedback
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger
logger_base = setup_logger()
def handle_doc_feedback_button(
req: SocketModeRequest,
client: SocketModeClient,
) -> None:
if not (actions := req.payload.get("actions")):
logger_base.error("Missing actions. Unable to build the source feedback view")
return
# Extracts the feedback_id coming from the 'source feedback' button
# and generates a new one for the View, to keep track of the doc info
query_event_id, doc_id, doc_rank = decompose_feedback_id(actions[0].get("value"))
external_id = build_feedback_id(query_event_id, doc_id, doc_rank)
channel_id = req.payload["container"]["channel_id"]
thread_ts = req.payload["container"]["thread_ts"]
data = View(
type="modal",
callback_id=VIEW_DOC_FEEDBACK_ID,
external_id=external_id,
# We use the private metadata to keep track of the channel id and thread ts
private_metadata=f"{channel_id}_{thread_ts}",
title="Give Feedback",
blocks=[get_document_feedback_blocks()],
submit="send",
close="cancel",
)
client.web_client.views_open(
trigger_id=req.payload["trigger_id"], view=data.to_dict()
)
def handle_slack_feedback(
block_id: str,
feedback_id: str,
feedback_type: str,
client: WebClient,
user_id_to_post_confirmation: str,
@@ -22,37 +63,43 @@ def handle_slack_feedback(
) -> None:
engine = get_sqlalchemy_engine()
query_id, doc_id, doc_rank = decompose_block_id(block_id)
message_id, doc_id, doc_rank = decompose_feedback_id(feedback_id)
with Session(engine) as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
update_query_event_feedback(
feedback=QAFeedbackType.LIKE
if feedback_type == LIKE_BLOCK_ACTION_ID
else QAFeedbackType.DISLIKE,
query_id=query_id,
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
feedback_text="",
chat_message_id=message_id,
user_id=None, # no "user" for Slack bot for now
db_session=db_session,
)
if feedback_type in [
elif feedback_type in [
SearchFeedbackType.ENDORSE.value,
SearchFeedbackType.REJECT.value,
SearchFeedbackType.HIDE.value,
]:
if doc_id is None or doc_rank is None:
raise ValueError("Missing information for Document Feedback")
if feedback_type == SearchFeedbackType.ENDORSE.value:
feedback = SearchFeedbackType.ENDORSE
elif feedback_type == SearchFeedbackType.REJECT.value:
feedback = SearchFeedbackType.REJECT
else:
feedback = SearchFeedbackType.HIDE
create_doc_retrieval_feedback(
qa_event_id=query_id,
message_id=message_id,
document_id=doc_id,
document_rank=doc_rank,
user_id=None,
document_index=get_default_document_index(),
db_session=db_session,
clicked=False, # Not tracking this for Slack
feedback=SearchFeedbackType.ENDORSE
if feedback_type == SearchFeedbackType.ENDORSE.value
else SearchFeedbackType.REJECT,
feedback=feedback,
)
else:
logger_base.error(f"Feedback type '{feedback_type}' not supported")
# post message to slack confirming that feedback was received
client.chat_postEphemeral(

View File

@@ -6,8 +6,8 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
@@ -25,11 +25,15 @@ from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.models import BaseFilters
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
logger_base = setup_logger()
@@ -75,6 +79,7 @@ def handle_message(
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -83,36 +88,56 @@ def handle_message(
Query thrown out by filters due to config does not count as a failure that should be notified
Danswer failing to answer/retrieve docs does count and should be notified
"""
msg = message_info.msg_content
channel = message_info.channel_to_respond
message_ts_to_respond_to = message_info.msg_to_respond
sender_id = message_info.sender
bipass_filters = message_info.bipass_filters
is_bot_msg = message_info.is_bot_msg
logger = cast(
logging.Logger,
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
)
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
sender_id = message_info.sender
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm
engine = get_sqlalchemy_engine()
document_set_names: list[str] | None = None
if channel_config and channel_config.persona:
persona = channel_config.persona if channel_config else None
prompt = None
if persona:
document_set_names = [
document_set.name for document_set in channel_config.persona.document_sets
document_set.name for document_set in persona.document_sets
]
prompt = persona.prompts[0] if persona.prompts else None
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# List of user id to send message to, if None, send to everyone in channel
send_to: list[str] | None = None
respond_tag_only = False
respond_team_member_list = None
bypass_acl = False
if (
channel_config
and channel_config.persona
and channel_config.persona.document_sets
):
# For Slack channels, use the full document set, admin will be warned when configuring it
# with non-public document sets
bypass_acl = True
if channel_config and channel_config.channel_config:
channel_conf = channel_config.channel_config
if not bipass_filters and "answer_filters" in channel_conf:
if not bypass_filters and "answer_filters" in channel_conf:
reflexion = "well_answered_postfilter" in channel_conf["answer_filters"]
if (
"questionmark_prefilter" in channel_conf["answer_filters"]
and "?" not in msg
and "?" not in messages[-1].message
):
logger.info(
"Skipping message since it does not contain a question mark"
@@ -128,7 +153,7 @@ def handle_message(
respond_tag_only = channel_conf.get("respond_tag_only") or False
respond_team_member_list = channel_conf.get("respond_team_member_list") or None
if respond_tag_only and not bipass_filters:
if respond_tag_only and not bypass_filters:
logger.info(
"Skipping message since the channel is configured such that "
"DanswerBot only responds to tags"
@@ -161,24 +186,34 @@ def handle_message(
backoff=2,
logger=logger,
)
def _get_answer(question: QuestionRequest) -> QAResponse:
engine = get_sqlalchemy_engine()
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse:
action = "slack_message"
if is_bot_msg:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
optional_telemetry(
record_type=RecordType.USAGE,
data={"action": action},
)
with Session(engine, expire_on_commit=False) as db_session:
# This also handles creating the query event in postgres
answer = answer_qa_query(
question=question,
answer = get_search_answer(
query_req=new_message_request,
user=None,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
real_time_flow=False,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
answer_failed = False
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters,
# the slack flow is allowed to extract filters out of the user query
@@ -188,19 +223,30 @@ def handle_message(
time_cutoff=None,
)
auto_detect_filters = (
persona.llm_filter_extraction if persona is not None else False
)
if disable_auto_detect_filters:
auto_detect_filters = False
retrieval_details = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
filters=filters,
enable_auto_detect_filters=auto_detect_filters,
)
# This includes throwing out answer via reflexion
answer = _get_answer(
QuestionRequest(
query=msg,
collection=DOCUMENT_INDEX_NAME,
enable_auto_detect_filters=not disable_auto_detect_filters,
filters=filters,
favor_recent=None,
offset=None,
DirectQARequest(
messages=messages,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
)
)
except Exception as e:
answer_failed = True
logger.exception(
f"Unable to process message - did not successfully answer "
f"in {num_retries} attempts"
@@ -216,15 +262,21 @@ def handle_message(
thread_ts=message_ts_to_respond_to,
)
# In case of failures, don't keep the reaction there permanently
try:
remove_react(message_info, client)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
return True
# Got an answer at this point, can remove reaction and give results
try:
remove_react(message_info, client)
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
if answer_failed:
return True
if answer.eval_res_valid is False:
if answer.answer_valid is False:
logger.info(
"Answer was evaluated to be invalid, throwing it away without responding."
)
@@ -232,10 +284,18 @@ def handle_message(
logger.debug(answer.answer)
return True
if not answer.top_ranked_docs:
logger.error(f"Unable to answer question: '{msg}' - no documents found")
# Optionally, respond in thread with the error message, Used primarily
# for debugging purposes
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
if should_respond_with_error_msgs:
respond_in_thread(
client=client,
@@ -254,18 +314,32 @@ def handle_message(
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(msg, is_bot_msg)
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
answer_blocks = build_qa_response_blocks(
query_event_id=answer.query_event_id,
message_id=answer.chat_message_id,
answer=answer.answer,
quotes=answer.quotes,
time_cutoff=answer.time_cutoff,
favor_recent=answer.favor_recent,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
skip_quotes=persona is not None, # currently Personas don't support quotes
)
document_blocks = build_documents_blocks(
documents=answer.top_ranked_docs, query_event_id=answer.query_event_id
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = (
build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
if priority_ordered_docs
else []
)
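The reordering above puts the chunks that were actually fed to the LLM first and appends the remaining retrieved documents afterwards. A tiny worked example of that ordering with toy values:

top_docs = ["doc-a", "doc-b", "doc-c", "doc-d"]
llm_doc_inds = [2, 0]  # e.g. answer.llm_chunks_indices

llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds]
priority_ordered_docs = llm_docs + remaining_docs
print(priority_ordered_docs)  # ['doc-c', 'doc-a', 'doc-b', 'doc-d']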
try:

View File

@@ -1,5 +1,5 @@
import re
import time
from threading import Event
from typing import Any
from typing import cast
@@ -9,44 +9,39 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.handlers.handle_feedback import handle_doc_feedback_button
from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
from danswer.danswerbot.slack.handlers.handle_message import handle_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import decompose_block_id
from danswer.danswerbot.slack.utils import decompose_feedback_id
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
from danswer.danswerbot.slack.utils import get_view_values
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
logger = setup_logger()
class MissingTokensException(Exception):
pass
def _get_socket_client() -> SocketModeClient:
# For more info on how to set this up, check out the docs:
# https://docs.danswer.dev/slack_bot_setup
try:
slack_bot_tokens = fetch_tokens()
except ConfigNotFoundError:
raise MissingTokensException("Slack tokens not found")
return SocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
)
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
@@ -77,7 +72,7 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
return False
if event_type == "message":
bot_tag_id = client.web_client.auth_test().get("user_id")
bot_tag_id = get_danswer_bot_app_id(client.web_client)
# DMs with the bot don't pick up the @DanswerBot so we have to keep the
# caught events_api
if bot_tag_id and bot_tag_id in msg and event.get("channel_type") != "im":
@@ -101,8 +96,14 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
# Pick the root of the thread (if a thread exists)
if thread_ts and message_ts != thread_ts:
channel_specific_logger.info(
# Can respond in thread if it's an "im" directly to Danswer or @DanswerBot is tagged
if (
thread_ts
and message_ts != thread_ts
and event_type != "app_mention"
and event.get("channel_type") != "im"
):
channel_specific_logger.debug(
"Skipping message since it is not the root of a thread"
)
return False
@@ -135,28 +136,41 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
actions = req.payload.get("actions")
if not actions:
logger.error("Unable to process block actions - no actions found")
# Answer feedback
if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0])
feedback_type = cast(str, action.get("action_id"))
feedback_id = cast(str, action.get("block_id"))
channel_id = cast(str, req.payload["container"]["channel_id"])
thread_ts = cast(str, req.payload["container"]["thread_ts"])
# Doc feedback
elif view := req.payload.get("view"):
view_values = get_view_values(view["state"]["values"])
private_metadata = view.get("private_metadata").split("_")
if not view_values:
logger.error("Unable to process feedback. Missing view values")
return
feedback_type = [x for x in view_values.values()][0]
feedback_id = cast(str, view.get("external_id"))
channel_id = private_metadata[0]
thread_ts = private_metadata[1]
else:
logger.error("Unable to process feedback. Actions or View not found")
return
action = cast(dict[str, Any], actions[0])
action_id = cast(str, action.get("action_id"))
block_id = cast(str, action.get("block_id"))
user_id = cast(str, req.payload["user"]["id"])
channel_id = cast(str, req.payload["container"]["channel_id"])
thread_ts = cast(str, req.payload["container"]["thread_ts"])
handle_slack_feedback(
block_id=block_id,
feedback_type=action_id,
feedback_id=feedback_id,
feedback_type=feedback_type,
client=client.web_client,
user_id_to_post_confirmation=user_id,
channel_id_to_post_confirmation=channel_id,
thread_ts_to_post_confirmation=thread_ts,
)
query_event_id, _, _ = decompose_block_id(block_id)
query_event_id, _, _ = decompose_feedback_id(feedback_id)
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
@@ -170,21 +184,29 @@ def build_request_details(
tagged = event.get("type") == "app_mention"
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
bot_tag_id = client.web_client.auth_test().get("user_id")
# Might exist even if not tagged, specifically in the case of @DanswerBot
# in DanswerBot DM channel
msg = re.sub(rf"<@{bot_tag_id}>\s", "", msg)
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if tagged:
logger.info("User tagged DanswerBot")
if thread_ts != message_ts and thread_ts is not None:
thread_messages = read_slack_thread(
channel=channel, thread=thread_ts, client=client.web_client
)
else:
thread_messages = [
ThreadMessage(message=msg, sender=None, role=MessageType.USER)
]
return SlackMessageInfo(
msg_content=msg,
thread_messages=thread_messages,
channel_to_respond=channel,
msg_to_respond=cast(str, thread_ts or message_ts),
msg_to_respond=cast(str, message_ts or thread_ts),
sender=event.get("user") or None,
bipass_filters=tagged,
bypass_filters=tagged,
is_bot_msg=False,
is_bot_dm=event.get("channel_type") == "im",
)
elif req.type == "slash_commands":
@@ -192,13 +214,16 @@ def build_request_details(
msg = req.payload["text"]
sender = req.payload["user_id"]
single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER)
return SlackMessageInfo(
msg_content=msg,
thread_messages=[single_msg],
channel_to_respond=channel,
msg_to_respond=None,
sender=sender,
bipass_filters=True,
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
)
raise RuntimeError("Programming fault, this should never happen.")
@@ -247,8 +272,9 @@ def process_message(
and not respond_every_channel
# Can't have configs for DMs so don't toss them out
and not is_dm
# If @DanswerBot or /DanswerBot, always respond with the default configs
and not (details.is_bot_msg or details.bipass_filters)
# If /DanswerBot (is_bot_msg) or @DanswerBot (bypass_filters)
# always respond with the default configs
and not (details.is_bot_msg or details.bypass_filters)
):
return
@@ -268,21 +294,59 @@ def acknowledge_message(req: SocketModeRequest, client: SocketModeClient) -> Non
client.send_socket_mode_response(response)
def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0])
if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]:
# AI Answer feedback
return process_feedback(req, client)
elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID:
# Activation of the "source feedback" button
return handle_doc_feedback_button(req, client)
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
if view := req.payload.get("view"):
if view["callback_id"] == VIEW_DOC_FEEDBACK_ID:
return process_feedback(req, client)
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
try:
if req.type == "interactive" and req.payload.get("type") == "block_actions":
return process_feedback(req, client)
if req.type == "interactive":
if req.payload.get("type") == "block_actions":
return action_routing(req, client)
elif req.payload.get("type") == "view_submission":
return view_routing(req, client)
elif req.type == "events_api" or req.type == "slash_commands":
return process_message(req, client)
except Exception:
logger.exception("Failed to process slack event")
def _get_socket_client(slack_bot_tokens: SlackBotTokens) -> SocketModeClient:
# For more info on how to set this up, check out the docs:
# https://docs.danswer.dev/slack_bot_setup
return SocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
)
def _initialize_socket_client(socket_client: SocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info("Listening for messages from Slack...")
socket_client.connect()
# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
# the slack bot in your workspace, and then add the bot to any channels you want to
# try and answer questions for. Running this file will setup Danswer to listen to all
@@ -293,21 +357,37 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
try:
socket_client = _get_socket_client()
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW)
# Establish a WebSocket connection to the Socket Mode servers
logger.info("Listening for messages from Slack...")
socket_client.connect()
slack_bot_tokens: SlackBotTokens | None = None
socket_client: SocketModeClient | None = None
while True:
try:
latest_slack_bot_tokens = fetch_tokens()
# Just not to stop this process
from threading import Event
if latest_slack_bot_tokens != slack_bot_tokens:
if slack_bot_tokens is not None:
logger.info("Slack Bot tokens have changed - reconnecting")
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if socket_client:
socket_client.close()
Event().wait()
except MissingTokensException:
# try again every 60 seconds. This is needed since the user may add tokens
# via the UI at any point in the program's lifecycle - if we just allow it to
# fail, then the user will need to restart the containers after adding tokens
logger.debug("Missing Slack Bot tokens - waiting 60 seconds and trying again")
time.sleep(60)
socket_client = _get_socket_client(slack_bot_tokens)
_initialize_socket_client(socket_client)
# Let the handlers run in the background + re-check for token updates every 60 seconds
Event().wait(timeout=60)
except ConfigNotFoundError:
# try again every 60 seconds. This is needed since the user may add tokens
# via the UI at any point in the program's lifecycle - if we just allow it to
# fail, then the user will need to restart the containers after adding tokens
logger.debug(
"Missing Slack Bot tokens - waiting 60 seconds and trying again"
)
if socket_client:
socket_client.disconnect()
time.sleep(60)
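The new main loop re-reads the Slack tokens every 60 seconds and only tears down and rebuilds the socket client when they actually change. A stripped-down sketch of that supervision loop (fetch_current_tokens, connect, and the LookupError placeholder stand in for the repo's fetch_tokens, socket client setup, and ConfigNotFoundError):

import time
from threading import Event


def run_forever(fetch_current_tokens, connect, poll_seconds: int = 60) -> None:
    tokens = None
    client = None
    while True:
        try:
            latest = fetch_current_tokens()
        except LookupError:  # stand-in for ConfigNotFoundError: tokens not configured yet
            time.sleep(poll_seconds)
            continue
        if latest != tokens:
            if client is not None:
                client.close()  # drop the old connection before reconnecting
            tokens = latest
            client = connect(tokens)
        # Let the handlers run in the background, then re-check for token updates.
        Event().wait(timeout=poll_seconds)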

View File

@@ -1,10 +1,13 @@
from pydantic import BaseModel
from danswer.one_shot_answer.models import ThreadMessage
class SlackMessageInfo(BaseModel):
msg_content: str
thread_messages: list[ThreadMessage]
channel_to_respond: str
msg_to_respond: str | None
sender: str | None
bipass_filters: bool
is_bot_msg: bool
bypass_filters: bool # User has tagged @DanswerBot
is_bot_msg: bool # User is using /DanswerBot
is_bot_dm: bool # User is direct messaging to DanswerBot

View File

@@ -2,7 +2,7 @@ import os
from typing import cast
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.server.models import SlackBotTokens
from danswer.server.manage.models import SlackBotTokens
_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"

View File

@@ -13,17 +13,34 @@ from slack_sdk.models.blocks import Block
from slack_sdk.models.metadata import Metadata
from danswer.configs.constants import ID_SEPARATOR
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import replace_whitespaces_w_space
logger = setup_logger()
DANSWER_BOT_APP_ID: str | None = None
def get_danswer_bot_app_id(web_client: WebClient) -> Any:
global DANSWER_BOT_APP_ID
if DANSWER_BOT_APP_ID is None:
DANSWER_BOT_APP_ID = web_client.auth_test().get("user_id")
return DANSWER_BOT_APP_ID
def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_danswer_bot_app_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
class ChannelIdAdapter(logging.LoggerAdapter):
"""This is used to add the channel ID to all log messages
emitted in this file"""
@@ -95,8 +112,8 @@ def respond_in_thread(
raise RuntimeError(f"Failed to post message: {response}")
def build_feedback_block_id(
query_event_id: int,
def build_feedback_id(
message_id: int,
document_id: str | None = None,
document_rank: int | None = None,
) -> str:
@@ -108,21 +125,21 @@ def build_feedback_block_id(
raise ValueError(
"Separator pattern should not already exist in document id"
)
block_id = ID_SEPARATOR.join(
[str(query_event_id), document_id, str(document_rank)]
feedback_id = ID_SEPARATOR.join(
[str(message_id), document_id, str(document_rank)]
)
else:
block_id = str(query_event_id)
feedback_id = str(message_id)
return unique_prefix + ID_SEPARATOR + block_id
return unique_prefix + ID_SEPARATOR + feedback_id
def decompose_block_id(block_id: str) -> tuple[int, str | None, int | None]:
def decompose_feedback_id(feedback_id: str) -> tuple[int, str | None, int | None]:
"""Decompose into query_id, document_id, document_rank, see above function"""
try:
components = block_id.split(ID_SEPARATOR)
components = feedback_id.split(ID_SEPARATOR)
if len(components) != 2 and len(components) != 4:
raise ValueError("Block ID does not contain right number of elements")
raise ValueError("Feedback ID does not contain right number of elements")
if len(components) == 2:
return int(components[-1]), None, None
@@ -131,7 +148,36 @@ def decompose_block_id(block_id: str) -> tuple[int, str | None, int | None]:
except Exception as e:
logger.error(e)
raise ValueError("Received invalid Feedback Block Identifier")
raise ValueError("Received invalid Feedback Identifier")
def get_view_values(state_values: dict[str, Any]) -> dict[str, str]:
"""Extract view values
Args:
state_values (dict): The Slack view-submission values
Returns:
dict: keys/values of the view state content
"""
view_values = {}
for _, view_data in state_values.items():
for k, v in view_data.items():
if (
"selected_option" in v
and isinstance(v["selected_option"], dict)
and "value" in v["selected_option"]
):
view_values[k] = v["selected_option"]["value"]
elif "selected_options" in v and isinstance(v["selected_options"], list):
view_values[k] = [
x["value"] for x in v["selected_options"] if "value" in x
]
elif "selected_date" in v:
view_values[k] = v["selected_date"]
elif "value" in v:
view_values[k] = v["value"]
return view_values
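get_view_values flattens Slack's nested view-submission state into a simple key/value mapping. A quick illustration with a hand-built state payload (the block and action IDs are made up):

state_values = {
    "feedback_block": {
        "doc_feedback_radio": {
            "type": "radio_buttons",
            "selected_option": {"text": ":thumbsup: Up-Boost", "value": "endorse"},
        }
    },
    "comment_block": {
        "comment_input": {"type": "plain_text_input", "value": "Very useful doc"},
    },
}

# Passing this to get_view_values above yields:
# {"doc_feedback_radio": "endorse", "comment_input": "Very useful doc"}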
def translate_vespa_highlight_to_slack(match_strs: list[str], used_chars: int) -> str:
@@ -201,3 +247,57 @@ def fetch_userids_from_emails(user_emails: list[str], client: WebClient) -> list
)
return user_ids
def fetch_user_semantic_id_from_id(user_id: str, client: WebClient) -> str | None:
response = client.users_info(user=user_id)
if not response["ok"]:
return None
user: dict = cast(dict[Any, dict], response.data).get("user", {})
return (
user.get("real_name")
or user.get("name")
or user.get("profile", {}).get("email")
)
def read_slack_thread(
channel: str, thread: str, client: WebClient
) -> list[ThreadMessage]:
thread_messages: list[ThreadMessage] = []
response = client.conversations_replies(channel=channel, ts=thread)
replies = cast(dict, response.data).get("messages", [])
for reply in replies:
if "user" in reply and "bot_id" not in reply:
message = remove_danswer_bot_tag(reply["text"], client=client)
user_sem_id = fetch_user_semantic_id_from_id(reply["user"], client)
message_type = MessageType.USER
else:
self_app_id = get_danswer_bot_app_id(client)
# Only include bot messages from Danswer, other bots are not taken in as context
if self_app_id != reply.get("user"):
continue
blocks = reply["blocks"]
if len(blocks) <= 1:
continue
# The useful block is the second one after the header block that says AI Answer
message = reply["blocks"][1]["text"]["text"]
if message.startswith("_Filters"):
if len(blocks) <= 2:
continue
message = reply["blocks"][2]["text"]["text"]
user_sem_id = "Assistant"
message_type = MessageType.ASSISTANT
thread_messages.append(
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
)
return thread_messages

View File

View File

@@ -1,30 +1,64 @@
from typing import Any
from collections.abc import Sequence
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import Session
from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import DocumentSet as DBDocumentSet
from danswer.db.models import Persona
from danswer.db.models import ToolInfo
from danswer.db.models import Prompt
from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.utils.logger import setup_logger
logger = setup_logger()
def fetch_chat_sessions_by_user(
def get_chat_session_by_id(
chat_session_id: int, user_id: UUID | None, db_session: Session
) -> ChatSession:
stmt = select(ChatSession).where(
ChatSession.id == chat_session_id, ChatSession.user_id == user_id
)
result = db_session.execute(stmt)
chat_session = result.scalar_one_or_none()
if not chat_session:
raise ValueError("Invalid Chat Session ID provided")
if chat_session.deleted:
raise ValueError("Chat session has been deleted")
return chat_session
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_one_shot: bool = False,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(False))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@@ -34,76 +68,18 @@ def fetch_chat_sessions_by_user(
return list(chat_sessions)
def fetch_chat_messages_by_session(
chat_session_id: int, db_session: Session
) -> list[ChatMessage]:
stmt = (
select(ChatMessage)
.where(ChatMessage.chat_session_id == chat_session_id)
.order_by(ChatMessage.message_number.asc(), ChatMessage.edit_number.asc())
)
result = db_session.execute(stmt).scalars().all()
return list(result)
def fetch_chat_message(
chat_session_id: int, message_number: int, edit_number: int, db_session: Session
) -> ChatMessage:
stmt = (
select(ChatMessage)
.where(
(ChatMessage.chat_session_id == chat_session_id)
& (ChatMessage.message_number == message_number)
& (ChatMessage.edit_number == edit_number)
)
.options(selectinload(ChatMessage.chat_session))
)
chat_message = db_session.execute(stmt).scalar_one_or_none()
if not chat_message:
raise ValueError("Invalid Chat Message specified")
return chat_message
def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession:
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
result = db_session.execute(stmt)
chat_session = result.scalar_one_or_none()
if not chat_session:
raise ValueError("Invalid Chat Session ID provided")
return chat_session
def verify_parent_exists(
chat_session_id: int,
message_number: int,
parent_edit_number: int | None,
db_session: Session,
) -> ChatMessage:
stmt = select(ChatMessage).where(
(ChatMessage.chat_session_id == chat_session_id)
& (ChatMessage.message_number == message_number - 1)
& (ChatMessage.edit_number == parent_edit_number)
)
result = db_session.execute(stmt)
try:
return result.scalar_one()
except NoResultFound:
raise ValueError("Invalid message, parent message not found")
def create_chat_session(
description: str, user_id: UUID | None, db_session: Session
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int | None = None,
one_shot: bool = False,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
persona_id=persona_id,
description=description,
one_shot=one_shot,
)
db_session.add(chat_session)
@@ -115,14 +91,13 @@ def create_chat_session(
def update_chat_session(
user_id: UUID | None, chat_session_id: int, description: str, db_session: Session
) -> ChatSession:
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
if chat_session.deleted:
raise ValueError("Trying to rename a deleted chat session")
if user_id != chat_session.user_id:
raise ValueError("User trying to update chat of another user.")
chat_session.description = description
db_session.commit()
@@ -136,10 +111,9 @@ def delete_chat_session(
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
if user_id != chat_session.user_id:
raise ValueError("User trying to delete chat of another user.")
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
if hard_delete:
stmt_messages = delete(ChatMessage).where(
@@ -156,185 +130,374 @@ def delete_chat_session(
db_session.commit()
def _set_latest_chat_message_no_commit(
chat_session_id: int,
message_number: int,
parent_edit_number: int | None,
edit_number: int,
db_session: Session,
) -> None:
if message_number != 0 and parent_edit_number is None:
raise ValueError(
"Only initial message in a chat is allowed to not have a parent"
)
db_session.query(ChatMessage).filter(
and_(
ChatMessage.chat_session_id == chat_session_id,
ChatMessage.message_number == message_number,
ChatMessage.parent_edit_number == parent_edit_number,
)
).update({ChatMessage.latest: False})
db_session.query(ChatMessage).filter(
and_(
ChatMessage.chat_session_id == chat_session_id,
ChatMessage.message_number == message_number,
ChatMessage.edit_number == edit_number,
)
).update({ChatMessage.latest: True})
def get_chat_message(
chat_message_id: int,
user_id: UUID | None,
db_session: Session,
) -> ChatMessage:
stmt = select(ChatMessage).where(ChatMessage.id == chat_message_id)
result = db_session.execute(stmt)
chat_message = result.scalar_one_or_none()
if not chat_message:
raise ValueError("Invalid Chat Message specified")
chat_user = chat_message.chat_session.user
expected_user_id = chat_user.id if chat_user is not None else None
if expected_user_id != user_id:
logger.error(
f"User {user_id} tried to fetch a chat message that does not belong to them"
)
raise ValueError("Chat message does not belong to user")
return chat_message
def get_chat_messages_by_session(
chat_session_id: int,
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
) -> list[ChatMessage]:
if not skip_permission_check:
get_chat_session_by_id(
chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
)
stmt = (
select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
# Start with the root message which has no parent
.order_by(nullsfirst(ChatMessage.parent_message))
)
result = db_session.execute(stmt).scalars().all()
return list(result)
def get_or_create_root_message(
chat_session_id: int,
db_session: Session,
) -> ChatMessage:
try:
root_message: ChatMessage | None = (
db_session.query(ChatMessage)
.filter(
ChatMessage.chat_session_id == chat_session_id,
ChatMessage.parent_message.is_(None),
)
.one_or_none()
)
except MultipleResultsFound:
raise Exception(
"Multiple root messages found for chat session. Data inconsistency detected."
)
if root_message is not None:
return root_message
else:
new_root_message = ChatMessage(
chat_session_id=chat_session_id,
prompt_id=None,
parent_message=None,
latest_child_message=None,
message="",
token_count=0,
message_type=MessageType.SYSTEM,
)
db_session.add(new_root_message)
db_session.commit()
return new_root_message
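For orientation, a minimal usage sketch of the new root-message flow (the session handling mirrors the Session(get_sqlalchemy_engine()) pattern used elsewhere in this changeset, the module path follows the danswer.db.chat import seen later in the diff, and the description string is illustrative):

from sqlalchemy.orm import Session

from danswer.db.chat import create_chat_session
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
    # Create a regular (non one-shot) session with no persona attached
    chat_session = create_chat_session(
        db_session=db_session,
        description="New chat",  # illustrative value
        user_id=None,
    )
    # Every session hangs off a single empty SYSTEM root message
    root_message = get_or_create_root_message(
        chat_session_id=chat_session.id, db_session=db_session
    )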
def create_new_chat_message(
chat_session_id: int,
message_number: int,
parent_message: ChatMessage,
message: str,
prompt_id: int | None,
token_count: int,
parent_edit_number: int | None,
message_type: MessageType,
db_session: Session,
retrieval_docs: dict[str, Any] | None = None,
rephrased_query: str | None = None,
error: str | None = None,
reference_docs: list[DBSearchDoc] | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
commit: bool = True,
) -> ChatMessage:
"""Creates a new chat message and sets it to the latest message of its parent message"""
# Get the count of existing edits at the provided message number
latest_edit_number = (
db_session.query(func.max(ChatMessage.edit_number))
.filter_by(
chat_session_id=chat_session_id,
message_number=message_number,
)
.scalar()
)
# The new message is a new edit at the provided message number
new_edit_number = latest_edit_number + 1 if latest_edit_number is not None else 0
# Create a new message and set it to be the latest for its parent message
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
message_number=message_number,
parent_edit_number=parent_edit_number,
edit_number=new_edit_number,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
reference_docs=retrieval_docs,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
error=error,
)
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
if reference_docs:
new_chat_message.search_docs = reference_docs
db_session.add(new_chat_message)
# Set the previous latest message of the same parent, as no longer the latest
_set_latest_chat_message_no_commit(
chat_session_id=chat_session_id,
message_number=message_number,
parent_edit_number=parent_edit_number,
edit_number=new_edit_number,
db_session=db_session,
)
# Flush the session to get an ID for the new chat message
db_session.flush()
db_session.commit()
parent_message.latest_child_message = new_chat_message.id
if commit:
db_session.commit()
return new_chat_message
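Continuing the sketch above, a user message is attached under the root node; with the tree model the caller passes the parent ChatMessage directly, assuming the removed message_number/parent_edit_number parameters are gone from the final signature (values below are illustrative):

user_message = create_new_chat_message(
    chat_session_id=chat_session.id,
    parent_message=root_message,
    message="What connectors does Danswer support?",  # illustrative
    prompt_id=None,
    token_count=0,
    message_type=MessageType.USER,
    db_session=db_session,
)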
def set_latest_chat_message(
chat_session_id: int,
message_number: int,
parent_edit_number: int | None,
edit_number: int,
db_session: Session,
) -> None:
_set_latest_chat_message_no_commit(
chat_session_id=chat_session_id,
message_number=message_number,
parent_edit_number=parent_edit_number,
edit_number=edit_number,
db_session=db_session,
)
def set_as_latest_chat_message(
chat_message: ChatMessage,
user_id: UUID | None,
db_session: Session,
) -> None:
parent_message_id = chat_message.parent_message
if parent_message_id is None:
raise RuntimeError(
f"Trying to set a latest message without parent, message id: {chat_message.id}"
)
parent_message = get_chat_message(
chat_message_id=parent_message_id, user_id=user_id, db_session=db_session
)
parent_message.latest_child_message = chat_message.id
db_session.commit()
def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
def get_prompt_by_id(
prompt_id: int,
user_id: UUID | None,
db_session: Session,
include_deleted: bool = False,
) -> Prompt:
stmt = select(Prompt).where(
Prompt.id == prompt_id, or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
)
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise ValueError(
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
)
return prompt
def get_persona_by_id(
persona_id: int,
# if user_id is `None` assume the user is an admin or auth is disabled
user_id: UUID | None,
db_session: Session,
include_deleted: bool = False,
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
if user_id is not None:
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
result = db_session.execute(stmt)
persona = result.scalar_one_or_none()
if persona is None:
raise ValueError(f"Persona with ID {persona_id} does not exist")
raise ValueError(
f"Persona with ID {persona_id} does not exist or does not belong to user"
)
return persona
def fetch_default_persona_by_name(
persona_name: str, db_session: Session
) -> Persona | None:
stmt = select(Persona).where(
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
return prompts
def get_personas_by_ids(
persona_ids: list[int], db_session: Session
) -> Sequence[Persona]:
"""Unsafe, can fetch personas from all users"""
if not persona_ids:
return []
personas = db_session.scalars(
select(Persona).where(Persona.id.in_(persona_ids))
).all()
return personas
def get_prompt_by_name(
prompt_name: str, user_id: UUID | None, shared: bool, db_session: Session
) -> Prompt | None:
"""Cannot do shared and user owned simultaneously as there may be two of those"""
stmt = select(Prompt).where(Prompt.name == prompt_name)
if shared:
stmt = stmt.where(Prompt.user_id.is_(None))
else:
stmt = stmt.where(Prompt.user_id == user_id)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | None:
"""Try to fetch a default persona by name first,
if not exist, try to find any persona with the name
Note that name is not guaranteed unique unless default is true"""
persona = fetch_default_persona_by_name(persona_name, db_session)
if persona is not None:
return persona
stmt = select(Persona).where(Persona.name == persona_name) # noqa: E712
result = db_session.execute(stmt).first()
if result:
return result[0]
return None
def get_persona_by_name(
persona_name: str, user_id: UUID | None, shared: bool, db_session: Session
) -> Persona | None:
"""Cannot do shared and user owned simultaneously as there may be two of those"""
stmt = select(Persona).where(Persona.name == persona_name)
if shared:
stmt = stmt.where(Persona.user_id.is_(None))
else:
stmt = stmt.where(Persona.user_id == user_id)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def upsert_prompt(
user_id: UUID | None,
name: str,
description: str,
system_prompt: str,
task_prompt: str,
include_citations: bool,
datetime_aware: bool,
personas: list[Persona] | None,
shared: bool,
db_session: Session,
prompt_id: int | None = None,
default_prompt: bool = True,
commit: bool = True,
) -> Prompt:
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(
prompt_name=name, user_id=user_id, shared=shared, db_session=db_session
)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=None if shared else user_id,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
if commit:
db_session.commit()
else:
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
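A usage sketch for upsert_prompt, run against an open db_session like the one in the earlier sketch; every string below is made up for illustration, and shared=True stores the prompt with user_id=None:

concise_prompt = upsert_prompt(
    user_id=None,
    name="Concise-QA",  # illustrative name
    description="Short answers with citations",
    system_prompt="You are a concise assistant.",
    task_prompt="Answer using only the provided context.",
    include_citations=True,
    datetime_aware=True,
    personas=None,
    shared=True,
    db_session=db_session,
    default_prompt=False,
)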
def upsert_persona(
user_id: UUID | None,
name: str,
retrieval_enabled: bool,
datetime_aware: bool,
system_text: str | None,
tools: list[ToolInfo] | None,
hint_text: str | None,
description: str,
num_chunks: float,
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
prompts: list[Prompt] | None,
document_sets: list[DBDocumentSet] | None,
llm_model_version_override: str | None,
shared: bool,
db_session: Session,
persona_id: int | None = None,
default_persona: bool = False,
document_sets: list[DocumentSetDBModel] | None = None,
commit: bool = True,
) -> Persona:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
# Default personas are defined via yaml files at deployment time
if persona is None and default_persona:
persona = fetch_default_persona_by_name(name, db_session)
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = get_persona_by_name(
persona_name=name, user_id=user_id, shared=shared, db_session=db_session
)
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
persona.name = name
persona.retrieval_enabled = retrieval_enabled
persona.datetime_aware = datetime_aware
persona.system_text = system_text
persona.tools = tools
persona.hint_text = hint_text
persona.description = description
persona.num_chunks = num_chunks
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.llm_model_version_override = llm_model_version_override
persona.deleted = False # Un-delete if previously deleted
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets
persona.document_sets = document_sets or []
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts
else:
persona = Persona(
id=persona_id,
user_id=None if shared else user_id,
name=name,
retrieval_enabled=retrieval_enabled,
datetime_aware=datetime_aware,
system_text=system_text,
tools=tools,
hint_text=hint_text,
description=description,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
default_persona=default_persona,
document_sets=document_sets if document_sets else [],
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_version_override=llm_model_version_override,
)
db_session.add(persona)
@@ -345,3 +508,204 @@ def upsert_persona(
db_session.flush()
return persona
def mark_prompt_as_deleted(
prompt_id: int,
user_id: UUID | None,
db_session: Session,
) -> None:
prompt = get_prompt_by_id(
prompt_id=prompt_id, user_id=user_id, db_session=db_session
)
prompt.deleted = True
db_session.commit()
def mark_persona_as_deleted(
persona_id: int,
user_id: UUID | None,
db_session: Session,
) -> None:
persona = get_persona_by_id(
persona_id=persona_id, user_id=user_id, db_session=db_session
)
persona.deleted = True
db_session.commit()
def update_persona_visibility(
persona_id: int,
is_visible: bool,
db_session: Session,
) -> None:
persona = get_persona_by_id(
persona_id=persona_id, user_id=None, db_session=db_session
)
persona.is_visible = is_visible
db_session.commit()
def update_all_personas_display_priority(
display_priority_map: dict[int, int],
db_session: Session,
) -> None:
"""Updates the display priority of all lives Personas"""
personas = get_personas(user_id=None, db_session=db_session)
available_persona_ids = {persona.id for persona in personas}
if available_persona_ids != set(display_priority_map.keys()):
raise ValueError("Invalid persona IDs provided")
for persona in personas:
persona.display_priority = display_priority_map[persona.id]
db_session.commit()
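The map passed to update_all_personas_display_priority must contain exactly the IDs of all live personas or a ValueError is raised; a sketch that re-asserts the current ordering with explicit priorities (illustrative only, using the same open db_session as above):

personas = get_personas(user_id=None, db_session=db_session)
priority_map = {persona.id: index for index, persona in enumerate(personas)}
update_all_personas_display_priority(
    display_priority_map=priority_map, db_session=db_session
)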
def get_prompts(
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_deleted: bool = False,
) -> Sequence[Prompt]:
stmt = select(Prompt).where(
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
)
if not include_default:
stmt = stmt.where(Prompt.default_prompt.is_(False))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
return db_session.scalars(stmt).all()
def get_personas(
# if user_id is `None` assume the user is an admin or auth is disabled
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona)
if user_id is not None:
stmt = stmt.where(or_(Persona.user_id == user_id, Persona.user_id.is_(None)))
if not include_default:
stmt = stmt.where(Persona.default_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.scalars(stmt).all()
def get_doc_query_identifiers_from_model(
search_doc_ids: list[int],
chat_session: ChatSession,
user_id: UUID | None,
db_session: Session,
) -> list[tuple[str, int]]:
"""Given a list of search_doc_ids"""
search_docs = (
db_session.query(SearchDoc).filter(SearchDoc.id.in_(search_doc_ids)).all()
)
if user_id != chat_session.user_id:
logger.error(
f"Docs referenced are from a chat session not belonging to user {user_id}"
)
raise ValueError("Docs references do not belong to user")
if any(
[doc.chat_messages[0].chat_session_id != chat_session.id for doc in search_docs]
):
raise ValueError("Invalid reference doc, not from this chat session.")
doc_query_identifiers = [(doc.document_id, doc.chunk_ind) for doc in search_docs]
return doc_query_identifiers
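A sketch of the intended call pattern: the client sends back the SearchDoc row IDs it wants reused, and this helper verifies ownership before the document index is queried (the IDs are illustrative; chat_session and db_session are from the earlier sketch):

doc_identifiers = get_doc_query_identifiers_from_model(
    search_doc_ids=[12, 13],
    chat_session=chat_session,
    user_id=None,
    db_session=db_session,
)
# doc_identifiers is a list of (document_id, chunk_ind) tuples for the index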
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
) -> SearchDoc:
db_search_doc = SearchDoc(
document_id=server_search_doc.document_id,
chunk_ind=server_search_doc.chunk_ind,
semantic_id=server_search_doc.semantic_identifier,
link=server_search_doc.link,
blurb=server_search_doc.blurb,
source_type=server_search_doc.source_type,
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
score=server_search_doc.score,
match_highlights=server_search_doc.match_highlights,
updated_at=server_search_doc.updated_at,
primary_owners=server_search_doc.primary_owners,
secondary_owners=server_search_doc.secondary_owners,
)
db_session.add(db_search_doc)
db_session.commit()
return db_search_doc
def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | None:
"""There are no safety checks here like user permission etc., use with caution"""
search_doc = db_session.query(SearchDoc).filter(SearchDoc.id == doc_id).first()
return search_doc
def translate_db_search_doc_to_server_search_doc(
db_search_doc: SearchDoc,
) -> SavedSearchDoc:
return SavedSearchDoc(
db_doc_id=db_search_doc.id,
document_id=db_search_doc.document_id,
chunk_ind=db_search_doc.chunk_ind,
semantic_identifier=db_search_doc.semantic_id,
link=db_search_doc.link,
blurb=db_search_doc.blurb,
source_type=db_search_doc.source_type,
boost=db_search_doc.boost,
hidden=db_search_doc.hidden,
metadata=db_search_doc.doc_metadata,
score=db_search_doc.score,
match_highlights=db_search_doc.match_highlights,
updated_at=db_search_doc.updated_at,
primary_owners=db_search_doc.primary_owners,
secondary_owners=db_search_doc.secondary_owners,
)
def get_retrieval_docs_from_chat_message(chat_message: ChatMessage) -> RetrievalDocs:
return RetrievalDocs(
top_documents=[
translate_db_search_doc_to_server_search_doc(db_doc)
for db_doc in chat_message.search_docs
]
)
def translate_db_message_to_chat_message_detail(
chat_message: ChatMessage,
) -> ChatMessageDetail:
chat_msg_detail = ChatMessageDetail(
message_id=chat_message.id,
parent_message=chat_message.parent_message,
latest_child_message=chat_message.latest_child_message,
message=chat_message.message,
rephrased_query=chat_message.rephrased_query,
context_docs=get_retrieval_docs_from_chat_message(chat_message),
message_type=chat_message.message_type,
time_sent=chat_message.time_sent,
citations=chat_message.citations,
)
return chat_msg_detail
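Putting the translation helpers together, a sketch of shaping a stored message for the chat API, continuing the earlier sketch (the message is assumed to exist and to belong to the anonymous session):

stored_message = get_chat_message(
    chat_message_id=user_message.id, user_id=None, db_session=db_session
)
message_detail = translate_db_message_to_chat_message_detail(stored_message)
# message_detail.context_docs carries SavedSearchDoc entries rebuilt from the
# chat_message__search_doc association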

View File

@@ -11,8 +11,8 @@ from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
from danswer.server.models import ConnectorBase
from danswer.server.models import ObjectCreationIdResponse
from danswer.server.documents.models import ConnectorBase
from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
@@ -36,8 +36,12 @@ def fetch_connectors(
return list(results.all())
def connector_by_name_exists(connector_name: str, db_session: Session) -> bool:
stmt = select(Connector).where(Connector.name == connector_name)
def connector_by_name_source_exists(
connector_name: str, source: DocumentSource, db_session: Session
) -> bool:
stmt = select(Connector).where(
Connector.name == connector_name, Connector.source == source
)
result = db_session.execute(stmt)
connector = result.scalar_one_or_none()
return connector is not None
@@ -50,11 +54,26 @@ def fetch_connector_by_id(connector_id: int, db_session: Session) -> Connector |
return connector
def fetch_ingestion_connector_by_name(
connector_name: str, db_session: Session
) -> Connector | None:
stmt = (
select(Connector)
.where(Connector.name == connector_name)
.where(Connector.source == DocumentSource.INGESTION_API)
)
result = db_session.execute(stmt)
connector = result.scalar_one_or_none()
return connector
def create_connector(
connector_data: ConnectorBase,
db_session: Session,
) -> ObjectCreationIdResponse:
if connector_by_name_exists(connector_data.name, db_session):
if connector_by_name_source_exists(
connector_data.name, connector_data.source, db_session
):
raise ValueError(
"Connector by this name already exists, duplicate naming not allowed."
)
@@ -82,8 +101,8 @@ def update_connector(
if connector is None:
return None
if connector_data.name != connector.name and connector_by_name_exists(
connector_data.name, db_session
if connector_data.name != connector.name and connector_by_name_source_exists(
connector_data.name, connector_data.source, db_session
):
raise ValueError(
"Connector by this name already exists, duplicate naming not allowed."
@@ -202,3 +221,44 @@ def fetch_latest_index_attempts_by_status(
),
)
return cast(list[IndexAttempt], query.all())
def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
distinct_sources = db_session.query(Connector.source).distinct().all()
sources = [
source[0]
for source in distinct_sources
if source[0] != DocumentSource.INGESTION_API
]
return sources
def create_initial_default_connector(db_session: Session) -> None:
default_connector_id = 0
default_connector = fetch_connector_by_id(default_connector_id, db_session)
if default_connector is not None:
if (
default_connector.source != DocumentSource.INGESTION_API
or default_connector.input_type != InputType.LOAD_STATE
or default_connector.refresh_freq is not None
or default_connector.disabled
):
raise ValueError(
"DB is not in a valid initial state. "
"Default connector does not have expected values."
)
return
connector = Connector(
id=default_connector_id,
name="Ingestion API",
source=DocumentSource.INGESTION_API,
input_type=InputType.LOAD_STATE,
connector_specific_config={},
refresh_freq=None,
)
db_session.add(connector)
db_session.commit()
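Taken together with create_initial_public_credential and associate_default_cc_pair below, this gives a startup seeding sequence roughly like the following sketch (the module paths in the first three imports are assumed, not shown in this diff):

from sqlalchemy.orm import Session

from danswer.db.connector import create_initial_default_connector  # path assumed
from danswer.db.connector_credential_pair import associate_default_cc_pair  # path assumed
from danswer.db.credential import create_initial_public_credential  # path assumed
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    create_initial_public_credential(db_session)  # empty public credential, id 0
    create_initial_default_connector(db_session)  # "Ingestion API" connector, id 0
    associate_default_cc_pair(db_session)  # links connector 0 to credential 0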

View File

@@ -54,6 +54,8 @@ def get_last_successful_attempt_time(
credential_id: int,
db_session: Session,
) -> float:
"""Gets the timestamp of the last successful index run stored in
the CC Pair row in the database"""
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
)
@@ -84,7 +86,10 @@ def update_connector_credential_pair(
cc_pair.last_attempt_status = attempt_status
# simply don't update last_successful_index_time if run_dt is not specified
# at worst, this would result in re-indexing documents that were already indexed
if attempt_status == IndexingStatus.SUCCESS and run_dt is not None:
if (
attempt_status == IndexingStatus.SUCCESS
or attempt_status == IndexingStatus.IN_PROGRESS
) and run_dt is not None:
cc_pair.last_successful_index_time = run_dt
if net_docs is not None:
cc_pair.total_docs_indexed += net_docs
@@ -117,6 +122,27 @@ def mark_all_in_progress_cc_pairs_failed(
db_session.commit()
def associate_default_cc_pair(db_session: Session) -> None:
existing_association = (
db_session.query(ConnectorCredentialPair)
.filter(
ConnectorCredentialPair.connector_id == 0,
ConnectorCredentialPair.credential_id == 0,
)
.one_or_none()
)
if existing_association is not None:
return
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
name="DefaultCCPair",
)
db_session.add(association)
db_session.commit()
def add_credential_to_connector(
connector_id: int,
credential_id: int,

View File

@@ -0,0 +1 @@
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"

View File

@@ -9,18 +9,20 @@ from danswer.auth.schemas import UserRole
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Credential
from danswer.db.models import User
from danswer.server.models import CredentialBase
from danswer.server.models import ObjectCreationIdResponse
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) -> Select:
def _attach_user_filters(
stmt: Select[tuple[Credential]],
user: User | None,
assume_admin: bool = False, # Used with API key
) -> Select:
"""Attaches filters to the statement to ensure that the user can only
access the appropriate credentials"""
if user:
@@ -29,11 +31,18 @@ def _attach_user_filters(stmt: Select[tuple[Credential]], user: User | None) ->
or_(
Credential.user_id == user.id,
Credential.user_id.is_(None),
Credential.is_admin == True, # noqa: E712
Credential.admin_public == True, # noqa: E712
)
)
else:
stmt = stmt.where(Credential.user_id == user.id)
elif assume_admin:
stmt = stmt.where(
or_(
Credential.user_id.is_(None),
Credential.admin_public == True, # noqa: E712
)
)
return stmt
@@ -49,10 +58,13 @@ def fetch_credentials(
def fetch_credential_by_id(
credential_id: int, user: User | None, db_session: Session
credential_id: int,
user: User | None,
db_session: Session,
assume_admin: bool = False,
) -> Credential | None:
stmt = select(Credential).where(Credential.id == credential_id)
stmt = _attach_user_filters(stmt, user)
stmt = _attach_user_filters(stmt, user, assume_admin=assume_admin)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
@@ -62,16 +74,16 @@ def create_credential(
credential_data: CredentialBase,
user: User | None,
db_session: Session,
) -> ObjectCreationIdResponse:
) -> Credential:
credential = Credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
is_admin=credential_data.is_admin,
admin_public=credential_data.admin_public,
)
db_session.add(credential)
db_session.commit()
return ObjectCreationIdResponse(id=credential.id)
return credential
def update_credential(
@@ -131,30 +143,26 @@ def delete_credential(
db_session.commit()
def create_initial_public_credential() -> None:
def create_initial_public_credential(db_session: Session) -> None:
public_cred_id = 0
error_msg = (
"DB is not in a valid initial state."
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
first_credential = fetch_credential_by_id(public_cred_id, None, db_session)
if first_credential is not None:
if (
first_credential.credential_json != {}
or first_credential.user is not None
):
raise ValueError(error_msg)
return
if first_credential is not None:
if first_credential.credential_json != {} or first_credential.user is not None:
raise ValueError(error_msg)
return
credential = Credential(
id=public_cred_id,
credential_json={},
user_id=None,
)
db_session.add(credential)
db_session.commit()
credential = Credential(
id=public_cred_id,
credential_json={},
user_id=None,
)
db_session.add(credential)
db_session.commit()
def delete_google_drive_service_account_credentials(

View File

@@ -1,5 +1,6 @@
import time
from collections.abc import Sequence
from datetime import datetime
from uuid import UUID
from sqlalchemy import and_
@@ -16,9 +17,10 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.tag import delete_document_tags_for_documents
from danswer.db.utils import model_to_dict
from danswer.document_index.interfaces import DocumentMetadata
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -39,6 +41,15 @@ def get_documents_for_connector_credential_pair(
return db_session.scalars(stmt).all()
def get_documents_by_ids(
document_ids: list[str],
db_session: Session,
) -> list[DbDocument]:
stmt = select(DbDocument).where(DbDocument.id.in_(document_ids))
documents = db_session.execute(stmt).scalars().all()
return list(documents)
def get_document_connector_cnts(
db_session: Session,
document_ids: list[str],
@@ -136,9 +147,13 @@ def get_acccess_info_for_documents(
def upsert_documents(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
initial_boost: int = DEFAULT_BOOST,
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.
Also note, this function should not be used for updating documents, only creating and
ensuring that it exists. It IGNORES the doc_updated_at field"""
seen_documents: dict[str, DocumentMetadata] = {}
for document_metadata in document_metadata_batch:
doc_id = document_metadata.document_id
@@ -154,11 +169,12 @@ def upsert_documents(
model_to_dict(
DbDocument(
id=doc.document_id,
boost=DEFAULT_BOOST,
from_ingestion_api=doc.from_ingestion_api,
boost=initial_boost,
hidden=False,
semantic_id=doc.semantic_identifier,
link=doc.first_link,
doc_updated_at=doc.doc_updated_at,
doc_updated_at=None, # this is intentional
primary_owners=doc.primary_owners,
secondary_owners=doc.secondary_owners,
)
@@ -200,6 +216,21 @@ def upsert_document_by_connector_credential_pair(
db_session.commit()
def update_docs_updated_at(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,
) -> None:
doc_ids = list(ids_to_new_updated_at.keys())
documents_to_update = (
db_session.query(DbDocument).filter(DbDocument.id.in_(doc_ids)).all()
)
for document in documents_to_update:
document.doc_updated_at = ids_to_new_updated_at[document.id]
db_session.commit()
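A sketch of the intended use of update_docs_updated_at: after an indexing pass succeeds, the per-document timestamps pulled from the source are written back in one batch (the document ID below is made up; db_session is an open session as in the earlier sketches):

from datetime import datetime
from datetime import timezone

update_docs_updated_at(
    ids_to_new_updated_at={"example-doc-id": datetime.now(timezone.utc)},
    db_session=db_session,
)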
def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
@@ -242,6 +273,7 @@ def delete_documents_complete(db_session: Session, document_ids: list[str]) -> N
delete_document_feedback_for_documents(
document_ids=document_ids, db_session=db_session
)
delete_document_tags_for_documents(document_ids=document_ids, db_session=db_session)
delete_documents(db_session, document_ids)
db_session.commit()

View File

@@ -14,8 +14,8 @@ from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import DocumentSet__ConnectorCredentialPair
from danswer.server.models import DocumentSetCreationRequest
from danswer.server.models import DocumentSetUpdateRequest
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
def _delete_document_set_cc_pairs__no_commit(
@@ -60,6 +60,8 @@ def get_document_set_by_name(
def get_document_sets_by_ids(
db_session: Session, document_set_ids: list[int]
) -> Sequence[DocumentSetDBModel]:
if not document_set_ids:
return []
return db_session.scalars(
select(DocumentSetDBModel).where(DocumentSetDBModel.id.in_(document_set_ids))
).all()
@@ -396,3 +398,33 @@ def get_or_create_document_set_by_name(
db_session.commit()
return new_doc_set
def check_document_sets_are_public(
db_session: Session,
document_set_ids: list[int],
) -> bool:
connector_credential_pair_ids = (
db_session.query(
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
)
.filter(
DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids)
)
.subquery()
)
not_public_exists = (
db_session.query(ConnectorCredentialPair.id)
.filter(
ConnectorCredentialPair.id.in_(
connector_credential_pair_ids # type:ignore
),
ConnectorCredentialPair.is_public.is_(False),
)
.limit(1)
.first()
is not None
)
return not not_public_exists
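check_document_sets_are_public returns True only when every connector/credential pair behind the given document sets is public; a sketch of gating on it (the IDs and error text are illustrative):

all_public = check_document_sets_are_public(
    db_session=db_session, document_set_ids=[1, 2]
)
if not all_public:
    raise ValueError("Document sets reference non-public sources")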

View File

@@ -4,34 +4,19 @@ from sqlalchemy import asc
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.db.models import ChatMessage as DbChatMessage
from danswer.db.chat import get_chat_message
from danswer.db.models import ChatMessageFeedback
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentRetrievalFeedback
from danswer.db.models import QueryEvent
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.search.models import SearchType
def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
stmt = select(QueryEvent).where(QueryEvent.id == query_id)
result = db_session.execute(stmt)
query_event = result.scalar_one_or_none()
if not query_event:
raise ValueError("Invalid Query Event ID Provided")
return query_event
def fetch_docs_by_id(doc_id: str, db_session: Session) -> DbDocument:
def fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument:
stmt = select(DbDocument).where(DbDocument.id == doc_id)
result = db_session.execute(stmt)
doc = result.scalar_one_or_none()
@@ -97,80 +82,20 @@ def update_document_hidden(
db_session.commit()
def create_query_event(
db_session: Session,
query: str,
search_type: SearchType | None,
llm_answer: str | None,
user_id: UUID | None,
retrieved_document_ids: list[str] | None = None,
) -> int:
query_event = QueryEvent(
query=query,
selected_search_flow=search_type,
llm_answer=llm_answer,
retrieved_document_ids=retrieved_document_ids,
user_id=user_id,
)
db_session.add(query_event)
db_session.commit()
return query_event.id
def update_query_event_feedback(
db_session: Session,
feedback: QAFeedbackType,
query_id: int,
user_id: UUID | None,
) -> None:
query_event = fetch_query_event_by_id(query_id, db_session)
if user_id != query_event.user_id:
raise ValueError("User trying to give feedback on a query run by another user.")
query_event.feedback = feedback
db_session.commit()
def update_query_event_retrieved_documents(
db_session: Session,
retrieved_document_ids: list[str],
query_id: int,
user_id: UUID | None,
) -> None:
query_event = fetch_query_event_by_id(query_id, db_session)
if user_id != query_event.user_id:
raise ValueError("User trying to update docs on a query run by another user.")
query_event.retrieved_document_ids = retrieved_document_ids
db_session.commit()
def create_doc_retrieval_feedback(
qa_event_id: int,
message_id: int,
document_id: str,
document_rank: int,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session,
clicked: bool = False,
feedback: SearchFeedbackType | None = None,
) -> None:
"""Creates a new Document feedback row and updates the boost value in Postgres and Vespa"""
if not clicked and feedback is None:
raise ValueError("No action taken, not valid feedback")
query_event = fetch_query_event_by_id(qa_event_id, db_session)
if user_id != query_event.user_id:
raise ValueError("User trying to give feedback on a query run by another user.")
doc_m = fetch_docs_by_id(document_id, db_session)
db_doc = fetch_db_doc_by_id(document_id, db_session)
retrieval_feedback = DocumentRetrievalFeedback(
qa_event_id=qa_event_id,
chat_message_id=message_id,
document_id=document_id,
document_rank=document_rank,
clicked=clicked,
@@ -179,20 +104,23 @@ def create_doc_retrieval_feedback(
if feedback is not None:
if feedback == SearchFeedbackType.ENDORSE:
doc_m.boost += 1
db_doc.boost += 1
elif feedback == SearchFeedbackType.REJECT:
doc_m.boost -= 1
db_doc.boost -= 1
elif feedback == SearchFeedbackType.HIDE:
doc_m.hidden = True
db_doc.hidden = True
elif feedback == SearchFeedbackType.UNHIDE:
doc_m.hidden = False
db_doc.hidden = False
else:
raise ValueError("Unhandled document feedback type")
if feedback in [SearchFeedbackType.ENDORSE, SearchFeedbackType.REJECT]:
if feedback in [
SearchFeedbackType.ENDORSE,
SearchFeedbackType.REJECT,
SearchFeedbackType.HIDE,
]:
update = UpdateRequest(
document_ids=[document_id],
boost=doc_m.boost,
document_ids=[document_id], boost=db_doc.boost, hidden=db_doc.hidden
)
# Updates are generally batched for efficiency, this case only 1 doc/value is updated
document_index.update([update])
@@ -213,40 +141,24 @@ def delete_document_feedback_for_documents(
def create_chat_message_feedback(
chat_session_id: int,
message_number: int,
edit_number: int,
is_positive: bool | None,
feedback_text: str | None,
chat_message_id: int,
user_id: UUID | None,
db_session: Session,
is_positive: bool | None = None,
feedback_text: str | None = None,
) -> None:
if is_positive is None and feedback_text is None:
raise ValueError("No feedback provided")
try:
chat_message = (
db_session.query(DbChatMessage)
.filter_by(
chat_session_id=chat_session_id,
message_number=message_number,
edit_number=edit_number,
)
.one()
)
except NoResultFound:
raise ValueError("ChatMessage not found")
chat_message = get_chat_message(
chat_message_id=chat_message_id, user_id=user_id, db_session=db_session
)
if chat_message.message_type != MessageType.ASSISTANT:
raise ValueError("Can only provide feedback on LLM Outputs")
if user_id is not None and chat_message.chat_session.user_id != user_id:
raise ValueError("User trying to give feedback on a message by another user.")
message_feedback = ChatMessageFeedback(
chat_message_chat_session_id=chat_session_id,
chat_message_message_number=message_number,
chat_message_edit_number=edit_number,
chat_message_id=chat_message_id,
is_positive=is_positive,
feedback_text=feedback_text,
)

View File

@@ -7,13 +7,15 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
logger = setup_logger()
@@ -55,8 +57,13 @@ def get_inprogress_index_attempts(
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage"""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.options(
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
)
new_attempts = db_session.scalars(stmt)
return list(new_attempts.all())
@@ -88,6 +95,9 @@ def mark_attempt_failed(
db_session.add(index_attempt)
db_session.commit()
source = index_attempt.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
def update_docs_indexed(
db_session: Session,

View File

@@ -4,6 +4,7 @@ from typing import Any
from typing import List
from typing import Literal
from typing import NotRequired
from typing import Optional
from typing import TypedDict
from uuid import UUID
@@ -13,14 +14,15 @@ from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTa
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import Enum
from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Text
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
@@ -31,9 +33,9 @@ from danswer.auth.schemas import UserRole
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
from danswer.search.models import RecencyBiasSetting
from danswer.search.models import SearchType
@@ -64,6 +66,11 @@ class Base(DeclarativeBase):
pass
"""
Auth/Authz (users, permissions, access) Tables
"""
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
# even an almost empty token from keycloak will not fit the default 1024 bytes
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
@@ -79,12 +86,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
credentials: Mapped[List["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
)
query_events: Mapped[List["QueryEvent"]] = relationship(
"QueryEvent", back_populates="user"
)
chat_sessions: Mapped[List["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
)
prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user")
personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user")
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@@ -92,7 +98,7 @@ class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
"""
Association tables
Association Tables
NOTE: must be at the top since they are referenced by other tables
"""
@@ -106,6 +112,13 @@ class Persona__DocumentSet(Base):
)
class Persona__Prompt(Base):
__tablename__ = "persona__prompt"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
prompt_id: Mapped[int] = mapped_column(ForeignKey("prompt.id"), primary_key=True)
class DocumentSet__ConnectorCredentialPair(Base):
__tablename__ = "document_set__connector_credential_pair"
@@ -130,6 +143,31 @@ class DocumentSet__ConnectorCredentialPair(Base):
document_set: Mapped["DocumentSet"] = relationship("DocumentSet")
class ChatMessage__SearchDoc(Base):
__tablename__ = "chat_message__search_doc"
chat_message_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id"), primary_key=True
)
search_doc_id: Mapped[int] = mapped_column(
ForeignKey("search_doc.id"), primary_key=True
)
class Document__Tag(Base):
__tablename__ = "document__tag"
document_id: Mapped[str] = mapped_column(
ForeignKey("document.id"), primary_key=True
)
tag_id: Mapped[int] = mapped_column(ForeignKey("tag.id"), primary_key=True)
"""
Documents/Indexing Tables
"""
class ConnectorCredentialPair(Base):
"""Connectors and Credentials can have a many-to-many relationship
I.e. A Confluence Connector may have multiple admin users who can run it with their own credentials
@@ -145,9 +183,7 @@ class ConnectorCredentialPair(Base):
unique=True,
nullable=False,
)
name: Mapped[str] = mapped_column(
String, unique=True, nullable=True
) # nullable for backwards compatibility
name: Mapped[str] = mapped_column(String, nullable=False)
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
@@ -185,6 +221,70 @@ class ConnectorCredentialPair(Base):
)
class Document(Base):
__tablename__ = "document"
# this should correspond to the ID of the document
# (as is passed around in Danswer)
id: Mapped[str] = mapped_column(String, primary_key=True)
from_ingestion_api: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=True
)
# 0 for neutral, positive for mostly endorse, negative for mostly reject
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
semantic_id: Mapped[str] = mapped_column(String)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
# The updated time is also used as a measure of the last successful state of the doc
# pulled from the source (to help skip reindexing already updated docs in case of
# connector retries)
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# The following are not attached to User because the account/email may not be known
# within Danswer
# Something like the document creator
primary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# Something like assignee or space owner
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# TODO if more sensitive data is added here for display, make sure to add user/group permission
retrieval_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
tags = relationship(
"Tag",
secondary="document__tag",
back_populates="documents",
)
class Tag(Base):
__tablename__ = "tag"
id: Mapped[int] = mapped_column(primary_key=True)
tag_key: Mapped[str] = mapped_column(String)
tag_value: Mapped[str] = mapped_column(String)
source: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource))
documents = relationship(
"Document",
secondary="document__tag",
back_populates="tags",
)
__table_args__ = (
UniqueConstraint(
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
),
)
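A sketch of how the new Tag and Document__Tag tables are meant to be used, attaching a key/value tag to an already-indexed document (the tag values and document ID are illustrative, the DocumentSource member is assumed to exist, and db_session is an open session as in the earlier sketches):

tag = Tag(tag_key="department", tag_value="engineering", source=DocumentSource.WEB)
document = db_session.get(Document, "example-doc-id")
if document is not None:
    # Appending through the relationship also creates the document__tag row
    document.tags.append(tag)
    db_session.commit()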
class Connector(Base):
__tablename__ = "connector"
@@ -226,7 +326,7 @@ class Credential(Base):
credential_json: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB())
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
# if `true`, then all Admins will have access to the credential
is_admin: Mapped[bool] = mapped_column(Boolean, default=True)
admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
@@ -315,8 +415,7 @@ class IndexAttempt(Base):
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential
pair"""
"""Represents an indexing of a document by a specific connector / credential pair"""
__tablename__ = "document_by_connector_credential_pair"
@@ -337,47 +436,136 @@ class DocumentByConnectorCredentialPair(Base):
)
class QueryEvent(Base):
__tablename__ = "query_event"
id: Mapped[int] = mapped_column(primary_key=True)
query: Mapped[str] = mapped_column(Text)
# search_flow refers to user selection, None if user used auto
selected_search_flow: Mapped[SearchType | None] = mapped_column(
Enum(SearchType), nullable=True
)
llm_answer: Mapped[str | None] = mapped_column(Text, default=None)
# Document IDs of the top context documents retrieved for the query (if any)
# NOTE: not using a foreign key to enable easy deletion of documents without
# needing to adjust `QueryEvent` rows
retrieved_document_ids: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
feedback: Mapped[QAFeedbackType | None] = mapped_column(
Enum(QAFeedbackType), nullable=True
)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
user: Mapped[User | None] = relationship("User", back_populates="query_events")
document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="qa_event"
)
"""
Messages Tables
"""
class SearchDoc(Base):
"""Different from Document table. This one stores the state of a document from a retrieval.
This allows chat sessions to be replayed with the searched docs
Notably, this does not include the contents of the Document/Chunk, during inference if a stored
SearchDoc is selected, an inference must be remade to retrieve the contents
"""
__tablename__ = "search_doc"
id: Mapped[int] = mapped_column(primary_key=True)
document_id: Mapped[str] = mapped_column(String)
chunk_ind: Mapped[int] = mapped_column(Integer)
semantic_id: Mapped[str] = mapped_column(String)
link: Mapped[str | None] = mapped_column(String, nullable=True)
blurb: Mapped[str] = mapped_column(String)
boost: Mapped[int] = mapped_column(Integer)
source_type: Mapped[DocumentSource] = mapped_column(Enum(DocumentSource))
hidden: Mapped[bool] = mapped_column(Boolean)
doc_metadata: Mapped[dict[str, str | list[str]]] = mapped_column(postgresql.JSONB())
score: Mapped[float] = mapped_column(Float)
match_highlights: Mapped[list[str]] = mapped_column(postgresql.ARRAY(String))
# This is for the document, not this row in the table
updated_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
primary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
chat_messages = relationship(
"ChatMessage",
secondary="chat_message__search_doc",
back_populates="search_docs",
)
class ChatSession(Base):
__tablename__ = "chat_session"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
# Only ever set to True if system is set to not hard-delete chats
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
messages: Mapped[List["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="delete"
)
persona: Mapped["Persona"] = relationship("Persona")
class ChatMessage(Base):
"""Note, the first message in a chain has no contents, it's a workaround to allow edits
on the first message of a session, an empty root node basically
Since every user message is followed by a LLM response, chat messages generally come in pairs.
Keeping them as separate messages however for future Agentification extensions
Fields will be largely duplicated in the pair.
"""
__tablename__ = "chat_message"
id: Mapped[int] = mapped_column(primary_key=True)
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
message: Mapped[str] = mapped_column(Text)
rephrased_query: Mapped[str] = mapped_column(Text, nullable=True)
# If None, then there is no answer generation, it's the special case of only
# showing the user the retrieved docs
prompt_id: Mapped[int | None] = mapped_column(ForeignKey("prompt.id"))
# If prompt is None, then token_count is 0 as this message won't be passed into
# the LLM's context (not included in the history of messages)
token_count: Mapped[int] = mapped_column(Integer)
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
# Maps the citation numbers to a SearchDoc id
citations: Mapped[dict[int, int]] = mapped_column(postgresql.JSONB(), nullable=True)
# Only applies for LLM
error: Mapped[str | None] = mapped_column(Text, nullable=True)
time_sent: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
chat_session: Mapped[ChatSession] = relationship("ChatSession")
prompt: Mapped[Optional["Prompt"]] = relationship("Prompt")
chat_message_feedbacks: Mapped[List["ChatMessageFeedback"]] = relationship(
"ChatMessageFeedback", back_populates="chat_message"
)
document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="chat_message"
)
search_docs = relationship(
"SearchDoc",
secondary="chat_message__search_doc",
back_populates="chat_messages",
)
"""
Feedback, Logging, Metrics Tables
"""
class DocumentRetrievalFeedback(Base):
__tablename__ = "document_retrieval_feedback"
id: Mapped[int] = mapped_column(primary_key=True)
qa_event_id: Mapped[int] = mapped_column(
ForeignKey("query_event.id"),
)
document_id: Mapped[str] = mapped_column(
ForeignKey("document.id"),
)
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
# How high up this document is in the results, 1 for first
document_rank: Mapped[int] = mapped_column(Integer)
clicked: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -385,46 +573,32 @@ class DocumentRetrievalFeedback(Base):
Enum(SearchFeedbackType), nullable=True
)
qa_event: Mapped[QueryEvent] = relationship(
"QueryEvent", back_populates="document_feedbacks"
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage", back_populates="document_feedbacks"
)
document: Mapped["Document"] = relationship(
document: Mapped[Document] = relationship(
"Document", back_populates="retrieval_feedbacks"
)
class Document(Base):
__tablename__ = "document"
# this should correspond to the ID of the document
# (as is passed around in Danswer)
id: Mapped[str] = mapped_column(String, primary_key=True)
# 0 for neutral, positive for mostly endorse, negative for mostly reject
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
semantic_id: Mapped[str] = mapped_column(String)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
doc_updated_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# The following are not attached to User because the account/email may not be known
# within Danswer
# Something like the document creator
primary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# Something like assignee or space owner
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# TODO if more sensitive data is added here for display, make sure to add user/group permission
retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
class ChatMessageFeedback(Base):
__tablename__ = "chat_feedback"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
chat_message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage", back_populates="chat_message_feedbacks"
)
"""
Structures, Organizational, Configurations Tables
"""
class DocumentSet(Base):
__tablename__ = "document_set"
@@ -432,7 +606,7 @@ class DocumentSet(Base):
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str] = mapped_column(String)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
# whether or not changes to the document set have been propogated
# Whether changes to the document set have been propagated
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
@@ -448,59 +622,84 @@ class DocumentSet(Base):
)
class ChatSession(Base):
__tablename__ = "chat_session"
class Prompt(Base):
__tablename__ = "prompt"
id: Mapped[int] = mapped_column(primary_key=True)
# If it doesn't belong to a user, then it's shared
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
description: Mapped[str] = mapped_column(Text)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
system_prompt: Mapped[str] = mapped_column(Text)
task_prompt: Mapped[str] = mapped_column(Text)
include_citations: Mapped[bool] = mapped_column(Boolean, default=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
# Default prompts are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_prompt: Mapped[bool] = mapped_column(Boolean, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# The following texts help build up the model's ability to use the context effectively
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
messages: Mapped[List["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="delete"
user: Mapped[User] = relationship("User", back_populates="prompts")
personas: Mapped[list["Persona"]] = relationship(
"Persona",
secondary=Persona__Prompt.__table__,
back_populates="prompts",
)
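Since Prompt and Persona are joined through the Persona__Prompt association table with back_populates on both sides, wiring a prompt to a persona is plain relationship-list manipulation. A minimal sketch under that assumption (attach_prompt is an illustrative name, not an existing helper):

from sqlalchemy.orm import Session

from danswer.db.models import Persona, Prompt


def attach_prompt(db_session: Session, persona: Persona, prompt: Prompt) -> None:
    # Appending to the relationship inserts a row into the Persona__Prompt
    # association table on flush; back_populates keeps prompt.personas in sync.
    if prompt not in persona.prompts:
        persona.prompts.append(prompt)
    db_session.commit()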
class ToolInfo(TypedDict):
name: str
description: str
class Persona(Base):
# TODO introduce user and group ownership for personas
__tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True)
# If it doesn't belong to a user, then it's shared
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
name: Mapped[str] = mapped_column(String)
# Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
system_text: Mapped[str | None] = mapped_column(Text, nullable=True)
tools: Mapped[list[ToolInfo] | None] = mapped_column(
postgresql.JSONB(), nullable=True
description: Mapped[str] = mapped_column(String)
# Currently stored but unused, all flows use hybrid
search_type: Mapped[SearchType] = mapped_column(
Enum(SearchType), default=SearchType.HYBRID
)
# Number of chunks to pass to the LLM for generation.
# If unspecified, uses the default DEFAULT_NUM_CHUNKS_FED_TO_CHAT set in the env variable
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
# Pass every chunk through LLM for evaluation, fairly expensive
# Can be turned off globally by admin, in which case, this setting is ignored
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
# Enables using LLM to extract time and source type filters
# Can also be admin disabled globally
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(Enum(RecencyBiasSetting))
# Allows the Persona to specify a different LLM version than is controlled
# globally via env variables. For flexibility, validity is not currently enforced
# NOTE: this is only applied to the actual response generation - it is not used for things like
# auto-detected time filters, relevance filters, etc.
llm_model_version_override: Mapped[str | None] = mapped_column(
String, nullable=True
)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# If it's updated and no longer latest (should no longer be shown), it is also considered deleted
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI
# higher priority personas are displayed first, ties are resolved by the ID,
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
"Prompt",
secondary=Persona__Prompt.__table__,
back_populates="personas",
)
# These are only defaults, users can select from all if desired
document_sets: Mapped[list[DocumentSet]] = relationship(
"DocumentSet",
secondary=Persona__DocumentSet.__table__,
back_populates="personas",
)
user: Mapped[User] = relationship("User", back_populates="personas")
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
@@ -513,78 +712,13 @@ class Persona(Base):
)
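The is_visible / display_priority comments above describe a concrete listing order. A sketch of the matching query, assuming Postgres defaults and glossing over NULL display_priority handling (fetch_visible_personas is an illustrative name, not an existing helper):

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import Persona


def fetch_visible_personas(db_session: Session) -> list[Persona]:
    # Higher display_priority first; ties fall back to ascending ID so that
    # earlier-created personas come first, per the comments on the model.
    stmt = (
        select(Persona)
        .where(Persona.is_visible.is_(True), Persona.deleted.is_(False))
        .order_by(Persona.display_priority.desc(), Persona.id.asc())
    )
    return list(db_session.scalars(stmt).all())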
class ChatMessage(Base):
__tablename__ = "chat_message"
chat_session_id: Mapped[int] = mapped_column(
ForeignKey("chat_session.id"), primary_key=True
)
message_number: Mapped[int] = mapped_column(Integer, primary_key=True)
edit_number: Mapped[int] = mapped_column(Integer, default=0, primary_key=True)
parent_edit_number: Mapped[int | None] = mapped_column(
Integer, nullable=True
) # null if first message
latest: Mapped[bool] = mapped_column(Boolean, default=True)
message: Mapped[str] = mapped_column(Text)
token_count: Mapped[int] = mapped_column(Integer)
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
reference_docs: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
time_sent: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
chat_session: Mapped[ChatSession] = relationship("ChatSession")
persona: Mapped[Persona | None] = relationship("Persona")
class ChatMessageFeedback(Base):
__tablename__ = "chat_feedback"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
chat_message_chat_session_id: Mapped[int] = mapped_column(Integer)
chat_message_message_number: Mapped[int] = mapped_column(Integer)
chat_message_edit_number: Mapped[int] = mapped_column(Integer)
is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
__table_args__ = (
ForeignKeyConstraint(
[
"chat_message_chat_session_id",
"chat_message_message_number",
"chat_message_edit_number",
],
[
"chat_message.chat_session_id",
"chat_message.message_number",
"chat_message.edit_number",
],
),
)
chat_message: Mapped[ChatMessage] = relationship(
"ChatMessage",
foreign_keys=[
chat_message_chat_session_id,
chat_message_message_number,
chat_message_edit_number,
],
backref="feedbacks",
)
AllowedAnswerFilters = (
Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
)
class ChannelConfig(TypedDict):
"""NOTE: is a `TypedDict` so it can be used a type hint for a JSONB column
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
in Postgres"""
channel_names: list[str]


@@ -3,15 +3,19 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.db.chat import upsert_persona
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
from danswer.db.models import SlackBotConfig
from danswer.search.models import RecencyBiasSetting
def _build_persona_name(channel_names: list[str]) -> str:
return f"__slack_bot_persona__{'-'.join(channel_names)}"
return f"{SLACK_BOT_PERSONA_PREFIX}{'-'.join(channel_names)}"
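Swapping the hard-coded string for SLACK_BOT_PERSONA_PREFIX matters later in this file: update_slack_bot_config and remove_slack_bot_config only delete personas whose names start with that prefix, so the constant is what marks a persona as bot-owned. Assuming the constant keeps the old literal value, the behavior looks like:

# Illustrative only, assuming SLACK_BOT_PERSONA_PREFIX == "__slack_bot_persona__"
# (the literal the old return statement used):
_build_persona_name(["support", "help"])
# -> "__slack_bot_persona__support-help"
# which is exactly what the cleanup paths below test via
# persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)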
def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
@@ -26,55 +30,51 @@ def _cleanup_relationships(db_session: Session, persona_id: int) -> None:
db_session.delete(rel)
def _create_slack_bot_persona(
def create_slack_bot_persona(
db_session: Session,
channel_names: list[str],
document_sets: list[int],
document_set_ids: list[int],
existing_persona_id: int | None = None,
num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
) -> Persona:
"""NOTE: does not commit changes"""
document_sets = list(
get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)
)
# create/update persona associated with the slack bot
persona_name = _build_persona_name(channel_names)
persona = upsert_persona(
user_id=None, # Slack Bot Personas are not attached to users
persona_id=existing_persona_id,
name=persona_name,
datetime_aware=False,
retrieval_enabled=True,
system_text=None,
tools=None,
hint_text=None,
description="",
num_chunks=num_chunks,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
prompts=None,
document_sets=document_sets,
llm_model_version_override=None,
shared=True,
default_persona=False,
db_session=db_session,
commit=False,
)
if existing_persona_id:
_cleanup_relationships(db_session=db_session, persona_id=existing_persona_id)
# create relationship between the new persona and the desired document_sets
for document_set_id in document_sets:
db_session.add(
Persona__DocumentSet(persona_id=persona.id, document_set_id=document_set_id)
)
return persona
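A rough sketch of how a caller might chain the renamed create_slack_bot_persona with the slimmed-down insert_slack_bot_config, assuming a Session named db_session, that the updated insert signature takes persona_id rather than document sets, and that the ChannelConfig here needs only channel_names (any other keys are elided in this diff); the IDs and channel names are example values:

# The persona is created (but not committed) first, then its id is handed to the config.
persona = create_slack_bot_persona(
    db_session=db_session,
    channel_names=["support", "help"],
    document_set_ids=[1, 2],  # IDs of existing DocumentSet rows (example values)
)
db_session.commit()  # create_slack_bot_persona itself does not commit (see its docstring)

slack_bot_config = insert_slack_bot_config(
    persona_id=persona.id,
    channel_config={"channel_names": ["support", "help"]},
    db_session=db_session,
)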
def insert_slack_bot_config(
document_sets: list[int],
persona_id: int | None,
channel_config: ChannelConfig,
db_session: Session,
) -> SlackBotConfig:
persona = None
if document_sets:
persona = _create_slack_bot_persona(
db_session=db_session,
channel_names=channel_config["channel_names"],
document_sets=document_sets,
)
slack_bot_config = SlackBotConfig(
persona_id=persona.id if persona else None,
persona_id=persona_id,
channel_config=channel_config,
)
db_session.add(slack_bot_config)
@@ -85,7 +85,7 @@ def insert_slack_bot_config(
def update_slack_bot_config(
slack_bot_config_id: int,
document_sets: list[int],
persona_id: int | None,
channel_config: ChannelConfig,
db_session: Session,
) -> SlackBotConfig:
@@ -96,31 +96,29 @@ def update_slack_bot_config(
raise ValueError(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
# get the existing persona id before updating the object
existing_persona_id = slack_bot_config.persona_id
persona = None
if document_sets:
persona = _create_slack_bot_persona(
db_session=db_session,
channel_names=channel_config["channel_names"],
document_sets=document_sets,
existing_persona_id=slack_bot_config.persona_id,
# update the config
# NOTE: need to do this before cleaning up the old persona or else we
# will encounter `violates foreign key constraint` errors
slack_bot_config.persona_id = persona_id
slack_bot_config.channel_config = channel_config
# if the persona has changed, then clean up the old persona
if persona_id != existing_persona_id and existing_persona_id:
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
else:
# if no document sets and an existing persona exists, then
# remove persona + persona -> document set relationships
if existing_persona_id:
# if the existing persona was one created just for use with this Slack Bot,
# then clean it up
if existing_persona and existing_persona.name.startswith(
SLACK_BOT_PERSONA_PREFIX
):
_cleanup_relationships(
db_session=db_session, persona_id=existing_persona_id
)
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
db_session.delete(existing_persona)
slack_bot_config.persona_id = persona.id if persona else None
slack_bot_config.channel_config = channel_config
db_session.commit()
return slack_bot_config
@@ -140,11 +138,30 @@ def remove_slack_bot_config(
existing_persona_id = slack_bot_config.persona_id
if existing_persona_id:
_cleanup_relationships(db_session=db_session, persona_id=existing_persona_id)
existing_persona = db_session.scalar(
select(Persona).where(Persona.id == existing_persona_id)
)
# if the existing persona was one created just for use with this Slack Bot,
# then clean it up
if existing_persona and existing_persona.name.startswith(
SLACK_BOT_PERSONA_PREFIX
):
_cleanup_relationships(
db_session=db_session, persona_id=existing_persona_id
)
db_session.delete(existing_persona)
db_session.delete(slack_bot_config)
db_session.commit()
def fetch_slack_bot_config(
db_session: Session, slack_bot_config_id: int
) -> SlackBotConfig | None:
return db_session.scalar(
select(SlackBotConfig).where(SlackBotConfig.id == slack_bot_config_id)
)
def fetch_slack_bot_configs(db_session: Session) -> Sequence[SlackBotConfig]:
return db_session.scalars(select(SlackBotConfig)).all()
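The two new fetch helpers are plain read paths; a minimal usage sketch, assuming a Session named db_session and that SlackBotConfig exposes id, persona_id, and the channel_config JSON shown elsewhere in this diff:

for config in fetch_slack_bot_configs(db_session):
    # Each config carries its ChannelConfig JSON plus an optional persona_id.
    print(config.id, config.channel_config["channel_names"], config.persona_id)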

Some files were not shown because too many files have changed in this diff.