Compare commits

..

145 Commits

Author SHA1 Message Date
Richard Kuo (Danswer)
a13dea160a use redis completion signal to double check exit code 2024-12-12 10:56:11 -08:00
pablonyx
ca172f3306 Merge pull request #3442 from onyx-dot-app/vespa_seeding_fix
Update initial seeding for latency requirements
2024-12-12 09:59:50 -08:00
pablodanswer
e5d0587efa pre-commit 2024-12-12 09:12:08 -08:00
pablonyx
a9516202fe update conditional (#3446) 2024-12-12 17:07:30 +00:00
Richard Kuo
d23fca96c4 reverse commit (fix later) 2024-12-11 22:19:10 -08:00
pablodanswer
a45724c899 run black 2024-12-11 19:18:06 -08:00
pablodanswer
34e250407a k 2024-12-11 19:14:10 -08:00
pablodanswer
046c0fbe3e update indexing 2024-12-11 19:08:05 -08:00
pablonyx
76595facef Merge pull request #3432 from onyx-dot-app/vercel_preview
Enable Vercel Preview
2024-12-11 18:55:14 -08:00
pablodanswer
af2d548766 k 2024-12-11 18:52:47 -08:00
Weves
7c29b1e028 add more egnyte failure logging 2024-12-11 18:19:55 -08:00
pablonyx
a52c821e78 Merge pull request #3436 from onyx-dot-app/cloud_improvements
cloud improvements
2024-12-11 17:06:06 -08:00
pablonyx
0770a587f1 remove slack workspace (#3394)
* remove slack workspace

* update client tokens

* fix up

* clean up docs

* fix up tests
2024-12-12 01:01:43 +00:00
hagen-danswer
748b79b0ef Added text for empty table and cascade delete for slack bot deletion (#3390)
* fixed fk issue for slack bot deletion

* Added text for empty table and cascade delete for slack bot deletion
2024-12-12 01:00:32 +00:00
pablonyx
9cacb373ef let users specify resourcing caps (#3403)
* let users specify resourcing caps

* functioanl resource limits

* improve defaults

* k

* update

* update comment + refer to proper resource

* self nit

* update var names
2024-12-12 00:59:41 +00:00
pablodanswer
21967d4b6f cloud improvements 2024-12-11 16:48:00 -08:00
pablodanswer
f5d638161b k 2024-12-11 15:35:44 -08:00
pablodanswer
0b5013b47d k 2024-12-11 15:34:26 -08:00
pablodanswer
1b846fbf06 update config 2024-12-11 15:17:11 -08:00
hagen-danswer
cae8a131a2 Made frontend conditional check for source (#3434) 2024-12-11 22:46:32 +00:00
pablonyx
72b4e8e9fe Clean citation cards (#3396)
* seed

* initial steps

* clean up

* fully clickable
2024-12-11 21:37:11 +00:00
pablonyx
c04e2f14d9 remove double x (#3387) 2024-12-11 21:36:58 +00:00
pablonyx
b40a12d5d7 clean up cursor pointers (#3385)
* update

* nit
2024-12-11 21:36:43 +00:00
pablonyx
5e7d454ebe Merge pull request #3433 from onyx-dot-app/silence_integration
Silence Slack Permission Sync test flakiness
2024-12-11 13:49:52 -08:00
pablodanswer
238509c536 silence 2024-12-11 13:48:37 -08:00
pablodanswer
d7f8cf8f18 testing 2024-12-11 13:36:10 -08:00
pablodanswer
5d810d373e k 2024-12-11 13:32:09 -08:00
joachim-danswer
9455576078 Mismatch issue of Documents shown and Citation number in text fix (#3421)
* Mismatch issue of Documents shown and Citation number in text fix

When document order presented to LLM differs from order shown to user, wrong doc numbers are cited.

Fix:
 - SearchTool.get_search_result  returns now final and initial ranking
 - initial ranking is passed through a few objects and used for replacement in citation processing

Notes:
 - the citation_num in the CitationInfo() object has not been changed.

* PR fixes

 - linting
 - removed erroneous tab
 - added a substitution test case
 - adjusted original citation extraction use case

* Included a key test and

* Fixed extra spaces

* Updated test documentation

Updated:
 - test_citation_substitution (changed description)
 - test_citation_processing (removed data only relevant for the substitution)
2024-12-11 19:58:24 +00:00
rkuo-danswer
71421bb782 better handling around index attempts that don't exist and remove unn… (#3417)
* better handling around index attempts that don't exist and remove unnecessary index attempt deletions

* don't delete index attempts, just update them

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-11 19:32:04 +00:00
pablonyx
b88cb388b7 Faster api hashing (#3423)
* migrate hashing to run faster v1

* k
2024-12-11 19:30:05 +00:00
Wendi
639986001f Fix bug (title overflow) (#3431) 2024-12-11 12:09:44 -08:00
pablonyx
e7a7e78969 clean up csv prompt + frontend (#3393)
* clean up csv prompt + frontend

* nit

* nit

* detect uploading

* upload
2024-12-11 19:10:34 +00:00
rkuo-danswer
e255ff7d23 editable refresh and prune for connectors (#3406)
* editable refresh and prune for connectors

* add extra validations on pruning/refresh frequency

* fix validation

* fix icon usage

* fix TextFormField error formatting

* nit

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-12-11 19:04:09 +00:00
pablonyx
1be2502112 finalize (#3398)
Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-12-11 18:52:20 +00:00
pablonyx
f2bedb8fdd Borders (#3388)
* remove double x

* incorporate base default padding for modals
2024-12-11 18:47:26 +00:00
pablonyx
637404f482 Connector page lists (pending feedback) (#3415)
* v1 (pending feedback)

* nits

* nit
2024-12-11 18:45:27 +00:00
pablonyx
daae146920 recognize updates (#3397) 2024-12-11 18:19:00 +00:00
pablonyx
d95959fb41 base role setting fix (#3381)
* base role setting fix

* update user tables

* finalize

* minor cleanup

* fix chromatic
2024-12-11 18:09:47 +00:00
rkuo-danswer
c667d28e7a update helm charts for onyx-dot-app rebrand (#3412)
* update helm charts for onyx-dot-app rebrand

* fix helm chart testing config

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-11 18:08:39 +00:00
pablonyx
9e0b482f47 k (#3399) 2024-12-11 18:05:39 +00:00
pablonyx
fa84eb657f cleaner citations (#3389) 2024-12-11 17:36:15 +00:00
pablonyx
264df3441b Various clean ups (#3413)
* tbd

* minor

* prettify

* update sidebar values
2024-12-11 17:19:14 +00:00
pablonyx
b9bad8b7a0 fix wikipedia icon (#3395) 2024-12-11 09:03:29 -08:00
pablonyx
600ebb6432 remove doc sets (#3400) 2024-12-11 16:31:14 +00:00
pablonyx
09fe8ea868 improved display - no odd cutoffs (#3401) 2024-12-11 16:09:19 +00:00
evan-danswer
ad6be03b4d centered score in feedbac panel (#3426) 2024-12-11 08:19:53 -08:00
rkuo-danswer
65d2511216 change text and formatting to guide users away from thinking "Back to… (#3382)
* change text and formatting to guide users away from thinking "Back to Danswer" is a back button

* regular text color and different icon

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-11 03:31:27 +00:00
Weves
113bf19c65 Remove dev-only check 2024-12-10 19:04:21 -08:00
Yuhong Sun
6026536110 Model Server Async (#3386)
* need-verify

* fix some lib calls

* k

* tests

* k

* k

* k

* Address the comments

* fix comment
2024-12-11 01:33:44 +00:00
Weves
056b671cd4 Small tweaks to get Egynte to work on our cloud 2024-12-10 17:43:46 -08:00
pablonyx
8d83ae2ee8 fix linear (#3402) 2024-12-11 00:45:06 +00:00
Yuhong Sun
ca988f5c5f Max File Size (#3422)
* k

* k

* k
2024-12-11 00:06:47 +00:00
Chris Weaver
4e4214b82c Egnyte connector (#3420) 2024-12-10 16:07:33 -08:00
Yuhong Sun
fe83f676df k (#3404) 2024-12-10 23:27:48 +00:00
hagen-danswer
6d6e12119b made external group emails lowercase (#3410) 2024-12-10 22:08:00 +00:00
pablonyx
1f2b7cb9c8 strip text for slackbot (#3416)
* stripe text for slackbot

* k
2024-12-10 21:42:35 +00:00
pablonyx
878a189011 delete input prompts (#3380)
* delete input prompts

* nit

* remove vestigial test

* nit
2024-12-10 21:36:40 +00:00
hagen-danswer
48c10271c2 fixed ephemeral slackbot messages (#3409) 2024-12-10 18:00:34 +00:00
evan-danswer
c6a79d847e fix typo (#3408)
expliticly -> explicitly
2024-12-10 16:44:42 +00:00
hagen-danswer
1bc3f8b96f Revert "Fixed ephemeral slackbot messages"
This reverts commit 7f6a6944d6.
2024-12-10 08:18:31 -08:00
hagen-danswer
7f6a6944d6 Fixed ephemeral slackbot messages 2024-12-10 07:57:28 -08:00
Weves
06f4146597 Bump litellm to support Nova models from AWS 2024-12-09 21:19:11 -08:00
hagen-danswer
7ea73d5a5a Temp slackbot url error Fix (#3392) 2024-12-09 18:34:38 -08:00
Weves
30dfe6dcb4 Add better vertex support + LLM form cleanup 2024-12-09 13:44:44 -08:00
Yuhong Sun
dc5d5dfe05 README Update (#3383) 2024-12-09 13:17:53 -08:00
pablonyx
0746e0be5b unify toggling (#3378) 2024-12-09 19:48:06 +00:00
Chris Weaver
970320bd49 Persona / prompt hardening (#3375)
* Persona / prompt hardening

* fix it
2024-12-09 03:39:59 +00:00
Chris Weaver
4a7bd5578e Fix Confluence perm sync for cloud users (#3374) 2024-12-09 01:41:30 +00:00
Chris Weaver
874b098a4b Add more logging + retries to teams connector (#3369) 2024-12-08 00:56:34 +00:00
pablodanswer
ce18b63eea hide oauth sources (#3368) 2024-12-07 23:57:37 +00:00
Yuhong Sun
7a919c3589 Dev Version Niceness 2024-12-07 15:10:13 -08:00
rkuo-danswer
631bac4432 Bugfix/log exit code (#3362)
* log the exit code of the spawned task

* exitcode can be negative

* mypy fixes
2024-12-06 22:32:59 +00:00
hagen-danswer
53428f6e9c More logging/fixes (#3364)
* More logging for external group syncing

* Fixed edge case where some spaces were not being fetched

* made refresh frequency for confluence syncs configurable

* clarity
2024-12-06 21:56:29 +00:00
pablodanswer
53b3dcbace fix slackbot channel config nullable (#3363)
* fix slackbot

* nit
2024-12-06 21:24:36 +00:00
rkuo-danswer
7a3c06c2d2 first cut at slack oauth flow (#3323)
* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* cleanup

* comment work in progress

* move some stuff to ee, add some playwright tests for the oauth callback edge cases

* fix ee, fix test name

* fix tests

* code review fixes
2024-12-06 19:55:21 +00:00
pablodanswer
7a0d823c89 Improved file handling (#3353)
* update props

* update documents

* nit

* update chat processing

* k

* k

* nit

* minor nit

* minor nits

* k

* nits
2024-12-06 19:16:54 +00:00
Yuhong Sun
db69e445d6 k (#3358) 2024-12-06 18:08:44 +00:00
Weves
18e63889b7 Change default log level back to info 2024-12-06 10:07:14 -08:00
Weves
738e60c8ed Increase vespa attempts on startup 2024-12-06 09:46:33 -08:00
hagen-danswer
8aec873e66 Merge pull request #3359 from danswer-ai/conf-logging-filter
Added filter to slim connector and logging for space permissions
2024-12-06 09:03:07 -08:00
hagen-danswer
7c57dde8ab fixed test 2024-12-06 08:33:12 -08:00
hagen-danswer
f30adab853 Merge remote-tracking branch 'origin/main' into conf-logging-filter 2024-12-06 08:30:07 -08:00
hagen-danswer
601687a522 Add test for Confluence permissions 2024-12-06 08:28:42 -08:00
hagen-danswer
350cf407c9 explicitly set page and attachment restrictions and space keys 2024-12-06 08:12:07 -08:00
hagen-danswer
32ec4efc7a tygod for tests 2024-12-06 08:03:34 -08:00
hagen-danswer
7c6981e052 Added filter to slim connector and logging for space permissions 2024-12-06 07:55:54 -08:00
Yuhong Sun
c50cd20156 Fix SlackBot Page Bugs (#3354) 2024-12-05 13:17:04 -08:00
hagen-danswer
14772dee71 Add persona stats (#3282)
* Added a chart to display persona message stats

* polish

* k

* hope this works

* cleanup
2024-12-05 17:15:56 +00:00
pablodanswer
c81e704c95 various niceties (#3348) 2024-12-05 17:12:52 +00:00
Chris Weaver
3266ef6321 Improve chat page performance (#3347)
* Simplify /manage/indexing-status

* Rename endpoint
2024-12-04 20:28:30 -08:00
pablodanswer
c89b98b4f2 update email invites (#3349) 2024-12-05 03:29:07 +00:00
rkuo-danswer
e70e0ab859 Merge pull request #3346 from danswer-ai/bugfix/chromatic-tests-2
Bugfix/chromatic tests 2
2024-12-04 19:44:05 -08:00
Richard Kuo (Danswer)
69b6e9321e Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/chromatic-tests-2
# Conflicts:
#	web/tests/e2e/home.spec.ts
2024-12-04 19:10:25 -08:00
Chris Weaver
7e53af18b6 Add b64 image support for image generation (#3342)
* Add b64 image support

* Fix

* enhance

* Fix mypy

* Fix imports
2024-12-05 02:24:54 +00:00
Richard Kuo (Danswer)
b9eb1ca2ba wait for whole placeholder string 2024-12-04 18:23:06 -08:00
rkuo-danswer
91d44c83d2 fixing chromatic tests (#3344)
* wait for the page to load

* fix up tests

* make sure "Initializing Danswer" is gone
2024-12-05 02:19:43 +00:00
Richard Kuo (Danswer)
4dbc6bb4d1 make sure "Initializing Danswer" is gone 2024-12-04 17:49:59 -08:00
Richard Kuo (Danswer)
4b6a4c6bbf fix up tests 2024-12-04 17:19:16 -08:00
pablodanswer
fd1999454a ensure we can order by doc id (#3343) 2024-12-05 01:10:37 +00:00
Richard Kuo (Danswer)
0a35422d1d wait for the page to load 2024-12-04 16:47:42 -08:00
pablodanswer
69b99056b2 Redirect to chat (#3341)
* k

* nit
2024-12-05 00:08:52 +00:00
Yuhong Sun
2a55696545 Move Answer (#3339) 2024-12-04 16:30:47 -08:00
hagen-danswer
ef9942b751 Related permission docs to cc_pair to prevent orphan docs (#3336)
* Related permission docs to cc_pair to prevent orphan docs

* added script

* group sync deduping

* logging
2024-12-04 21:00:54 +00:00
pablodanswer
993acec5e9 Update memoization + silence unnecessary errors (#3337)
* update memoization + silence unnecessary errors

* proper org
2024-12-04 20:08:15 +00:00
Weves
b01a1b509a Add basic loadtest script 2024-12-04 10:53:48 -08:00
pablodanswer
4f994124ef remove now unnecessary user loading indicatort log (#3333) 2024-12-04 00:09:22 +00:00
rkuo-danswer
14863bd457 try single threaded playwright testing (#3322) 2024-12-03 23:21:46 +00:00
Yuhong Sun
aa1c4c635a Combining Search and Chat Backend (#3273)
* k

* k

* fix slack issues

* rebase

* k
2024-12-03 22:37:14 +00:00
rkuo-danswer
13f6e8a6b4 disable thread local locking in callbacks (#3319) 2024-12-03 22:32:56 +00:00
pablodanswer
66f47d294c Shared filter utility for clarity (#3270)
* shared filter util

* clearer comment
2024-12-03 19:30:42 +00:00
pablodanswer
0a685bda7d add comments for clarity (#3249) 2024-12-03 19:27:28 +00:00
pablodanswer
23dc8b5dad Search flow improvements (#3314)
* untoggle if no docs

* update

* nits

* nit

* typing

* nit
2024-12-03 18:56:27 +00:00
pablodanswer
cd5f2293ad Temperature (#3310)
* fix temperatures for default llm

* ensure anthropic models don't overflow

* minor cleanup

* k

* k

* k

* fix typing
2024-12-03 17:22:22 +00:00
rkuo-danswer
6c2269e565 refactor celery task names to constants (#3296) 2024-12-03 16:02:17 +00:00
Weves
46315cddf1 Adjust default confulence timezone 2024-12-02 22:25:29 -08:00
rkuo-danswer
5f28a1b0e4 Bugfix/confluence time zone (#3265)
* RedisLock typing

* checkpoint

* put in debug logging

* improve comments

* mypy fixes
2024-12-02 22:23:23 -08:00
rkuo-danswer
9e9b7ed61d Bugfix/connector aborted logging (#3309)
* improve error logging on task failure.

* add db exception hardening to the indexing watchdog

* log on db exception
2024-12-03 02:34:40 +00:00
pablodanswer
3fb2bfefec Update Chromatic Tests (#3300)
* remove / update search tests

* minor update
2024-12-02 23:08:54 +00:00
pablodanswer
7c618c9d17 Unified UI (#3308)
* fix typing

* add filters display
2024-12-02 15:12:13 -08:00
pablodanswer
03e2789392 Text embedding (PDF, TXT) (#3113)
* add text embedding

* post rebase cleanup

* fully functional post rebase

* rm logs

* rm '

* quick clean up

* k
2024-12-02 22:43:53 +00:00
Chris Weaver
2783fa08a3 Update openai version in model server (#3306) 2024-12-02 21:39:10 +00:00
pablodanswer
edeaee93a2 hard refresh on auth (#3305)
* hard refresh on auth

* k

* k

* comment for clarity
2024-12-02 20:12:12 +00:00
hagen-danswer
5385bae100 Add slim connector description (#3303)
* added docs example and test

* updated docs

* needed to make the tests run

* updated docs
2024-12-02 19:52:13 +00:00
pablodanswer
813445ab59 Minor JWT Feature (#3290)
* first pass

* k

* k

* finalize

* minor cleanup

* k

* address

* minor typing updates
2024-12-02 19:14:31 +00:00
pablodanswer
af814823c8 display name + model truncation (#3304) 2024-12-02 18:54:08 +00:00
pablodanswer
607f61eaeb Reusable function for search settings spread operation (#3301)
* combine for clarity once and for all

* remove logs

* k
2024-12-02 17:23:01 +00:00
pablodanswer
de66f7adb2 Updated chat flow (#3244)
* proper no assistant typing + no assistant modal

* updated chat flow

* k

* updates

* update

* k

* clean up

* fix mystery reorg

* cleanup

* update scroll

* default

* update logs

* push fade

* scroll nit

* finalize tags

* updates

* k

* various updates

* viewport height update

* source types update

* clean up unused components

* minor cleanup

* cleanup complete

* finalize changes

* badge up

* update filters

* small nit

* k

* k

* address comments

* quick unification of icons

* minor date range clarity

* minor nit

* k

* update sidebar line

* update for all screen sizes

* k

* k

* k

* k

* rm shs

* fix memoization

* fix memoization

* slack chat

* k

* k

* build org
2024-12-02 01:58:28 +00:00
Yuhong Sun
3432d932d1 Citation code comments 2024-12-01 14:10:11 -08:00
Yuhong Sun
9bd0cb9eb5 Fix Citation Minor Bugs (#3294) 2024-12-01 13:55:24 -08:00
Chris Weaver
f12eb4a5cf Fix assistant prompt zero-ing (#3293) 2024-11-30 04:45:40 +00:00
Chris Weaver
16863de0aa Improve model token limit detection (#3292)
* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Ugprade OpenAI as well

* Fix mypy
2024-11-30 04:42:56 +00:00
Weves
63d1eefee5 Add read_only=True for xlsx parsing 2024-11-28 16:02:02 -08:00
pablodanswer
e338677896 order seeding 2024-11-28 15:41:10 -08:00
hagen-danswer
7be80c4af9 increased the pagination limit for confluence spaces (#3288) 2024-11-28 19:04:38 +00:00
rkuo-danswer
7f1e4a02bf Feature/kill indexing (#3213)
* checkpoint

* add celery termination of the task

* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings

* rename payload

* pretty sure these weren't named correctly

* testing in progress

* cleanup

* remove space

* merge fix

* three dots animation on Pausing

* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-28 05:32:45 +00:00
rkuo-danswer
5be7d27285 use indexing flag in db for manually triggering indexing (#3264)
* use indexing flag in db for manually trigger indexing

* add comment.

* only try to release the lock if we actually succeeded with the lock

* ensure we don't trigger manual indexing on anything but the primary search settings

* comment usage of primary search settings

* run check for indexing immediately after indexing triggers are set

* reorder fix
2024-11-28 01:34:34 +00:00
Weves
fd84b7a768 Remove duplicate API key router 2024-11-27 16:30:59 -08:00
Subash-Mohan
36941ae663 fix: Cannot configure API keys #3191 2024-11-27 16:25:00 -08:00
Matthew Holland
212353ed4a Fixed default feedback options 2024-11-27 16:23:52 -08:00
Richard Kuo (Danswer)
eb8708f770 the word "error" might be throwing off sentry 2024-11-27 14:31:21 -08:00
Chris Weaver
ac448956e9 Add handling for rate limiting (#3280) 2024-11-27 14:22:15 -08:00
pablodanswer
634a0b9398 no stack by default (#3278) 2024-11-27 20:58:21 +00:00
hagen-danswer
09d3e47c03 Perm sync behavior change (#3262)
* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
2024-11-27 20:04:15 +00:00
pablodanswer
9c0cc94f15 refresh router -> refresh assistants (#3271) 2024-11-27 19:11:58 +00:00
hagen-danswer
07dfde2209 add continue in danswer button to slack bot responses (#3239)
* all done except routing

* fixed initial changes

* added backend endpoint for duplicating a chat session from Slack

* got chat duplication routing done

* got login routing working

* improved answer handling

* finished all checks

* finished all!

* made sure it works with google oauth

* dont remove that lol

* fixed weird thing

* bad comments
2024-11-27 18:25:38 +00:00
423 changed files with 13394 additions and 9643 deletions

View File

@@ -24,6 +24,8 @@ env:
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

View File

@@ -1,48 +1,48 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
</h2>
<p align="center">
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
<p align="center">
<a href="https://docs.danswer.dev/" target="_blank">
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2sslpdbyq-iIbTaNIVPBw_i_4vrujLYQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
</p>
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring Personas (AI Assistants) and their Prompts.
configuring AI Assistants.
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Usage</h3>
Danswer Web App:
Onyx Web App:
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
@@ -52,16 +52,16 @@ For more details on the Admin UI to manage connectors and users, check out our
## Deployment
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
## 💃 Main Features
* Chat UI with the ability to select documents to chat with.
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
* Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
* Document Search + AI Answers for natural language queries.
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
* Slack integration to get answers and search results directly in Slack.
@@ -75,12 +75,12 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
* Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Danswer
## Other Notable Benefits of Onyx
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.
* Custom deep learning models + learn from user feedback.
* Easy deployment and ability to host Danswer anywhere of your choosing.
* Easy deployment and ability to host Onyx anywhere of your choosing.
## 🔌 Connectors
@@ -108,10 +108,10 @@ Efficiently pulls the latest changes from:
## 📚 Editions
There are two editions of Danswer:
There are two editions of Onyx:
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
* Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Single Sign-On (SSO), with support for both SAML and OIDC
* Role-based access control
* Document permission inheritance from connected sources
@@ -119,24 +119,24 @@ There are two editions of Danswer:
* Whitelabeling
* API key authentication
* Encryption of secrets
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
* Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Danswer Enterprise Edition:
To try the Onyx Enterprise Edition:
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)
## ✨Contributors
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">

View File

@@ -73,6 +73,7 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

View File

@@ -1,5 +1,5 @@
from sqlalchemy.engine.base import Connection
from typing import Any
from typing import Literal
import asyncio
from logging.config import fileConfig
import logging
@@ -8,6 +8,7 @@ from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
@@ -35,7 +36,18 @@ logger = logging.getLogger(__name__)
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.

View File

@@ -0,0 +1,35 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -0,0 +1,36 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")
def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)

View File

@@ -0,0 +1,27 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -0,0 +1,30 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -0,0 +1,57 @@
"""delete_input_prompts
Revision ID: bf7a81109301
Revises: f7a894b06d02
Create Date: 2024-12-09 12:00:49.884228
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bf7a81109301"
down_revision = "f7a894b06d02"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")
def downgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)

View File

@@ -0,0 +1,40 @@
"""non-nullbale slack bot id in channel config
Revision ID: f7a894b06d02
Revises: 9f696734098f
Create Date: 2024-12-06 12:55:42.845723
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7a894b06d02"
down_revision = "9f696734098f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Delete all rows with null slack_bot_id
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
# Make slack_bot_id non-nullable
op.alter_column(
"slack_channel_config",
"slack_bot_id",
existing_type=sa.Integer(),
nullable=False,
)
def downgrade() -> None:
# Make slack_bot_id nullable again
op.alter_column(
"slack_channel_config",
"slack_bot_id",
existing_type=sa.Integer(),
nullable=True,
)

View File

@@ -1,5 +1,6 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:

View File

@@ -18,6 +18,11 @@ class ExternalAccess:
@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""
external_access: ExternalAccess
# The document ID
doc_id: str

View File

@@ -1,3 +1,4 @@
import hashlib
import secrets
import uuid
from urllib.parse import quote
@@ -18,7 +19,8 @@ _API_KEY_HEADER_NAME = "Authorization"
# organizations like the Internet Engineering Task Force (IETF).
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "dn_"
_API_KEY_PREFIX = "on_"
_DEPRECATED_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192
@@ -52,7 +54,9 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX):
if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith(
_DEPRECATED_API_KEY_PREFIX
):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
@@ -63,10 +67,19 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
return unquote(tenant_id) if tenant_id else None
def _deprecated_hash_api_key(api_key: str) -> str:
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
if api_key.startswith(_API_KEY_PREFIX):
return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
return _deprecated_hash_api_key(api_key)
else:
raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
def build_displayable_api_key(api_key: str) -> str:

View File

@@ -9,7 +9,6 @@ from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
return list()

View File

@@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -58,7 +58,6 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -87,6 +86,7 @@ from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -99,11 +99,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -136,11 +131,12 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
return not DISABLE_VERIFICATION and (
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
)
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
return REQUIRE_EMAIL_VERIFICATION
# For other auth types, if the user is authenticated it's assumed that
# the user is already verified via the external IDP
return False
def verify_email_is_invited(email: str) -> None:

View File

@@ -11,6 +11,7 @@ from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
@@ -332,16 +333,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
lock: RedisLock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
except Exception:
logger.exception("Failed to release primary worker lock")
except Exception:
logger.exception("Failed to check if primary worker lock is owned")
def on_setup_logging(

View File

@@ -11,6 +11,7 @@ from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from redis.lock import Lock as RedisLock
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
@@ -38,7 +39,6 @@ from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -116,9 +116,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
# set thread_local=False since we don't control what thread the periodic task might
# reacquire the lock with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
thread_local=False,
)
logger.info("Primary worker lock: Acquire starting.")
@@ -227,7 +231,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not hasattr(worker, "primary_worker_lock"):
return
lock = worker.primary_worker_lock
lock: RedisLock = worker.primary_worker_lock
r = get_redis_client(tenant_id=None)

View File

@@ -2,54 +2,55 @@ from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": "check_for_doc_permissions_sync",
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": "check_for_external_group_sync",
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},

View File

@@ -5,13 +5,13 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -29,7 +29,7 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name="check_for_connector_deletion_task",
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -37,7 +37,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +86,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -17,9 +18,11 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@@ -27,7 +30,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -81,7 +84,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
@shared_task(
name="check_for_doc_permissions_sync",
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -138,7 +141,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -162,8 +165,8 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
"connector_permission_sync_generator_task",
result = app.send_task(
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
@@ -174,8 +177,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
@@ -190,7 +193,7 @@ def try_creating_permissions_sync_task(
@shared_task(
name="connector_permission_sync_generator_task",
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -216,7 +219,7 @@ def connector_permission_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
@@ -241,13 +244,17 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
@@ -256,7 +263,12 @@ def connector_permission_sync_generator_task(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
self.app, lock, document_external_accesses, source_type
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
@@ -281,7 +293,7 @@ def connector_permission_sync_generator_task(
@shared_task(
name="update_external_document_permissions_task",
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
@@ -292,6 +304,8 @@ def update_external_document_permissions_task(
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
@@ -300,18 +314,28 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Then we build the update requests to update vespa
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
upsert_document_external_perms(
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
if created_new_doc:
# If a new document was created, we associate it with the cc_pair
upsert_document_by_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -16,6 +17,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -24,13 +26,20 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
logger = setup_logger()
@@ -49,7 +58,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
@@ -81,7 +90,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name="check_for_external_group_sync",
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -102,12 +111,28 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
for cc_pair in cc_pairs
if cc_pair.id != cc_pair_to_remove.id
]
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -125,7 +150,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -156,8 +181,8 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
"connector_external_group_sync_generator_task",
result = app.send_task(
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
@@ -166,8 +191,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
@@ -182,7 +212,7 @@ def try_creating_permissions_sync_task(
@shared_task(
name="connector_external_group_sync_generator_task",
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -195,7 +225,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -203,7 +233,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -228,9 +258,13 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
@@ -249,7 +283,6 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -260,6 +293,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()

View File

@@ -23,13 +23,16 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
@@ -37,12 +40,13 @@ from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -153,13 +157,13 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
@shared_task(
name="check_for_indexing",
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -172,6 +176,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
@@ -205,17 +211,10 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -231,22 +230,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
reindex,
db_session,
r,
tenant_id,
@@ -256,7 +279,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
@@ -281,7 +304,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -289,13 +311,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -304,6 +327,7 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -368,6 +392,11 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -458,7 +487,7 @@ def try_creating_indexing_task(
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -495,8 +524,14 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -509,6 +544,10 @@ def connector_indexing_proxy_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
@@ -537,32 +576,106 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(10)
sleep(5)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
task_logger.error(
f"Indexing watchdog - spawned task exceptioned: "
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={job.exception()}"
f"search_settings={search_settings_id}"
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
job.cancel()
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
else:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
job.release()
break
@@ -703,9 +816,12 @@ def connector_indexing_task(
)
break
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)

View File

@@ -13,12 +13,13 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
@shared_task(
name="kombu_message_cleanup_task",
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -20,6 +21,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
@@ -75,7 +77,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name="check_for_pruning",
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -184,7 +186,7 @@ def try_creating_prune_generator_task(
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
@@ -209,7 +211,7 @@ def try_creating_prune_generator_task(
@shared_task(
name="connector_pruning_generator_task",
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -238,9 +240,12 @@ def connector_pruning_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)

View File

@@ -9,6 +9,7 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
@@ -31,7 +32,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
@shared_task(
name="document_by_cc_pair_cleanup_task",
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,

View File

@@ -25,6 +25,7 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
@@ -46,6 +47,7 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -58,7 +60,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -79,7 +81,7 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -588,7 +590,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -596,9 +598,7 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
@@ -655,33 +655,52 @@ def monitor_ccpair_indexing_taskset(
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # inner signal not set ... possible error
result_state = result.state
task_state = result.state
if (
result_state in READY_STATES
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
)
task_logger.warning(msg)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
try:
index_attempt = get_index_attempt(
db_session, payload.index_attempt_id
)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
except Exception:
task_logger.exception(
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector_index.reset()
@@ -692,6 +711,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -699,7 +719,7 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
@@ -724,7 +744,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
@@ -810,7 +830,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
@shared_task(
name="vespa_metadata_sync_task",
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app: Celery = celery_app

View File

@@ -1,8 +1,10 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
app: Celery = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -82,7 +82,7 @@ class SimpleJob:
return "running"
elif self.process.exitcode is None:
return "cancelled"
elif self.process.exitcode > 0:
elif self.process.exitcode != 0:
return "error"
else:
return "finished"
@@ -123,7 +123,8 @@ class SimpleJobClient:
self._cleanup_completed_jobs()
if len(self.jobs) >= self.n_workers:
logger.debug(
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
f"No available workers to run job. "
f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
)
return None

View File

@@ -19,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -87,6 +88,10 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
@@ -208,9 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError(
"_run_indexing: Connector stop signal detected"
)
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -304,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -335,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure

View File

@@ -6,33 +6,27 @@ from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from danswer.chat.llm_response_handler import LLMResponseHandlerManager
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.build import default_build_system_message
from danswer.chat.prompt_builder.build import default_build_user_message
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
)
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.chat.stream_processing.utils import map_document_id_order
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
@@ -212,20 +206,28 @@ class Answer:
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
search_result = SearchTool.get_search_result(current_llm_call) or []
search_result, displayed_search_results_map = SearchTool.get_search_result(
current_llm_call
) or ([], {})
answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
)
else:
raise ValueError("No answer style config provided")
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
display_doc_order_dict=displayed_search_results_map,
)
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled

View File

@@ -2,20 +2,79 @@ import re
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
def prepare_chat_message_request(
message_text: str,
user: User | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
db_session=db_session,
description=None,
user_id=user.id if user else None,
# If using an override, this id will be ignored later on
persona_id=persona_id or DEFAULT_PERSONA_ID,
danswerbot_flow=True,
slack_thread_id=message_ts_to_respond_to,
)
return CreateChatMessageRequest(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
prompt_id=prompt.id if prompt else None,
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
@@ -31,9 +90,49 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
match_highlights=inference_section.center_chunk.match_highlights,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)
def create_chat_chain(
chat_session_id: UUID,
db_session: Session,
@@ -196,3 +295,71 @@ def extract_headers(
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@@ -1,60 +1,22 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import ResponsePart
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| DanswerQuotes
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
class LLMResponseHandlerManager:
def __init__(
self,
tool_handler: "ToolResponseHandler",
answer_handler: "AnswerResponseHandler",
tool_handler: ToolResponseHandler,
answer_handler: AnswerResponseHandler,
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler

View File

@@ -1,17 +1,30 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.context.search.models import SearchResponse
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
if TYPE_CHECKING:
from danswer.db.models import Prompt
class LlmDoc(BaseModel):
"""This contains the minimal set information for the LLM portion including citations"""
@@ -25,6 +38,7 @@ class LlmDoc(BaseModel):
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
match_highlights: list[str] | None
# First chunk of info for streaming QA
@@ -117,20 +131,6 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerContext(BaseModel):
content: str
document_id: str
@@ -146,14 +146,20 @@ class DanswerAnswer(BaseModel):
answer: str | None
class QAResponse(SearchResponse, DanswerAnswer):
quotes: list[DanswerQuote] | None
contexts: list[DanswerContexts] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class ChatDanswerBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
class FileChatDisplay(BaseModel):
@@ -165,9 +171,41 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay
@@ -183,3 +221,109 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize
# the document content, which adds to the token count.
using_tool_message: bool = False
class ContextualPruningConfig(DocumentPruningConfig):
num_chunk_multiple: int
@classmethod
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
model_config = ConfigDict(frozen=True)
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)

View File

@@ -6,16 +6,24 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.answer import Answer
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import create_temporary_persona
from danswer.chat.models import AllCitations
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import CitationConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import PromptConfig
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
@@ -54,16 +62,11 @@ from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.file_store.utils import load_all_chat_files
from danswer.file_store.utils import save_files_from_urls
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.file_store.utils import save_files
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -102,6 +105,7 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -113,7 +117,10 @@ from danswer.tools.tool_implementations.search.search_tool import (
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -256,6 +263,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| DanswerContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -286,6 +294,8 @@ def stream_chat_message_objects(
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -293,6 +303,7 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
@@ -322,17 +333,31 @@ def stream_chat_message_objects(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = chat_session.persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
# If a prompt override is specified via the API, use that with highest priority
# but for saving it, we are just mapping it to an existing prompt
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
@@ -555,19 +580,34 @@ def stream_chat_message_objects(
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
system_prompt=new_msg_req.persona_override_config.prompts[
0
].system_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
datetime_aware=new_msg_req.persona_override_config.prompts[
0
].datetime_aware,
include_citations=new_msg_req.persona_override_config.prompts[
0
].include_citations,
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
elif prompt_override:
if not final_msg.prompt:
raise ValueError(
"Prompt override cannot be applied, no base prompt found."
)
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=prompt_override,
)
elif final_msg.prompt:
prompt_config = PromptConfig.from_model(final_msg.prompt)
else:
prompt_config = PromptConfig.from_model(persona.prompts[0])
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
@@ -587,11 +627,13 @@ def stream_chat_message_objects(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
rerank_settings=new_msg_req.rerank_settings,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
bypass_acl=bypass_acl,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
@@ -605,6 +647,7 @@ def stream_chat_message_objects(
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
@@ -637,7 +680,8 @@ def stream_chat_message_objects(
reference_db_search_docs = None
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
# any files to associate with the AI message e.g. dall-e generated images
ai_message_files = []
dropped_indices = None
tool_result = None
@@ -692,8 +736,14 @@ def stream_chat_message_objects(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
@@ -719,15 +769,19 @@ def stream_chat_message_objects(
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV
),
)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
@@ -736,6 +790,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(DanswerContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
pass
@@ -775,7 +831,8 @@ def stream_chat_message_objects(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield AllCitations(citations=answer.citations)
if not answer.is_cancelled():
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
@@ -844,3 +901,30 @@ def stream_chat_message(
)
for obj in objects:
yield get_json_line(obj.model_dump())
@log_function_time()
def gather_stream_for_slack(
packets: ChatPacketStream,
) -> ChatDanswerBotResponse:
response = ChatDanswerBotResponse()
answer = ""
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.docs = packet
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.chat_message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
if answer:
response.answer = answer
return response

View File

@@ -4,20 +4,26 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
def default_build_system_message(
@@ -139,3 +145,15 @@ class AnswerPromptBuilder:
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True

View File

@@ -2,12 +2,12 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig

View File

@@ -1,10 +1,10 @@
from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK

View File

@@ -0,0 +1,62 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def translate_danswer_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
return HumanMessage(content=content)
raise ValueError(f"New message type {msg.message_type} not handled")
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts

View File

@@ -5,16 +5,16 @@ from typing import TypeVar
from pydantic import BaseModel
from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import (
LlmDoc,
)
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content

View File

@@ -3,16 +3,11 @@ from collections.abc import Generator
from langchain_core.messages import BaseMessage
from danswer.chat.llm_response_handler import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -40,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []
@@ -70,28 +70,29 @@ class CitationResponseHandler(AnswerResponseHandler):
yield from self.citation_processor.process_token(content)
class QuotesResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.quotes_processor = QuotesProcessor(
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
# No longer in use, remove later
# class QuotesResponseHandler(AnswerResponseHandler):
# def __init__(
# self,
# context_docs: list[LlmDoc],
# is_json_prompt: bool = True,
# ):
# self.quotes_processor = QuotesProcessor(
# context_docs=context_docs,
# is_json_prompt=is_json_prompt,
# )
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self.quotes_processor.process_token(None)
return
# def handle_response_part(
# self,
# response_item: BaseMessage | None,
# previous_response_items: list[BaseMessage],
# ) -> Generator[ResponsePart, None, None]:
# if response_item is None:
# yield from self.quotes_processor.process_token(None)
# return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
# content = (
# response_item.content if isinstance(response_item.content, str) else ""
# )
yield from self.quotes_processor.process_token(content)
# yield from self.quotes_processor.process_token(content)

View File

@@ -4,8 +4,8 @@ from collections.abc import Generator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -22,12 +22,16 @@ class CitationProcessor:
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict # original order of docs to displayed to user
)
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
@@ -67,9 +71,9 @@ class CitationProcessor:
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
@@ -77,13 +81,15 @@ class CitationProcessor:
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
result = "" # Initialize result here
result = ""
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
@@ -96,6 +102,18 @@ class CitationProcessor:
self.citation_order.index(real_citation_num) + 1
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
start, end = citation.span()
@@ -116,6 +134,7 @@ class CitationProcessor:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -131,29 +150,24 @@ class CitationProcessor:
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
@@ -161,7 +175,8 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length

View File

@@ -1,3 +1,4 @@
# THIS IS NO LONGER IN USE
import math
import re
from collections.abc import Generator
@@ -5,11 +6,10 @@ from json import JSONDecodeError
from typing import Optional
import regex
from pydantic import BaseModel
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.context.search.models import InferenceChunk
@@ -26,6 +26,20 @@ logger = setup_logger()
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
def _extract_answer_quotes_freeform(
answer_raw: str,
) -> tuple[Optional[str], Optional[list[str]]]:

View File

@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.chat.models import ResponsePart
from danswer.chat.prompt_builder.build import LLMCall
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message

View File

@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -84,7 +81,14 @@ OAUTH_CLIENT_SECRET = (
or ""
)
# for future OAuth connector support
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -118,6 +122,8 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# the number of times to try and connect to vespa on startup before giving up
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
@@ -308,6 +314,22 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
# Due to breakages in the confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.
# The current state of affairs:
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
# no API retrieves the user's timezone
# All data is returned in UTC, so we can't derive the user's timezone from that
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector as some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -326,6 +348,12 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -389,21 +417,28 @@ LARGE_CHUNK_RATIO = 4
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
CLEANUP_INDEXING_JOBS_TIMEOUT = int(
os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3
)
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
)
# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
# Maximum file size in a document to be indexed
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
#####
# Miscellaneous
@@ -493,10 +528,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
@@ -510,3 +541,6 @@ API_KEY_HASH_ROUNDS = (
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

View File

@@ -3,7 +3,6 @@ import os
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking

View File

@@ -31,6 +31,8 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
DEFAULT_PERSONA_ID = 0
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -130,6 +132,7 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -259,6 +262,32 @@ class DanswerCeleryPriority(int, Enum):
LOWEST = auto()
class DanswerCeleryTask:
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
CHECK_FOR_PRUNING = "check_for_pruning"
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
)
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
"update_external_document_permissions_task"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3

View File

@@ -4,11 +4,8 @@ import os
# Danswer Slack Bot Configs
#####
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
# How much of the available input context can be used for thread context
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
@@ -47,17 +44,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
# Set/unset by "Hide Non Answers"
ENABLE_DANSWERBOT_REFLEXION = (
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
)
# Currently not support chain of thought, probably will add back later
DANSWER_BOT_DISABLE_COT = True
# if set, will default DanswerBot to use quotes and reference documents
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
# Maximum Questions Per Minute, Default Uncapped
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None

View File

@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible

View File

@@ -2,6 +2,8 @@ import json
import os
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(

View File

@@ -11,11 +11,16 @@ Connectors come in 3 different flows:
- Load Connector:
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
documents via a connector's API or loads the documents from some sort of a dump file.
- Poll connector:
- Poll Connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Slim Connector:
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
- This is used by our pruning job which removes old documents from the index.
- The optional start and end datetimes can be ignored.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.
- Currently not used by the background job, this exists for future design purposes.
@@ -26,8 +31,14 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -1,9 +1,11 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import quote
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
@@ -13,6 +15,7 @@ from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
from danswer.connectors.confluence.utils import validate_attachment_filetype
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
@@ -51,7 +54,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 1000
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
@@ -69,6 +72,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
@@ -104,6 +108,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -204,12 +210,14 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
logger.debug(f"_fetch_document_batches: {page['id']}")
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
@@ -242,10 +250,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
@@ -269,9 +277,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
page_restrictions = page.get("restrictions")
page_space_key = page.get("space", {}).get("key")
page_perm_sync_data = {
"restrictions": page_restrictions or {},
"space_key": page_space_key,
}
doc_metadata_list.append(
@@ -281,7 +291,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
@@ -291,6 +301,21 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
if not validate_attachment_filetype(attachment):
continue
attachment_restrictions = attachment.get("restrictions")
if not attachment_restrictions:
attachment_restrictions = page_restrictions
attachment_space_key = attachment.get("space", {}).get("key")
if not attachment_space_key:
attachment_space_key = page_space_key
attachment_perm_sync_data = {
"restrictions": attachment_restrictions or {},
"space_key": attachment_space_key,
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
@@ -298,8 +323,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
perm_sync_data=attachment_perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list

View File

@@ -120,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence(Confluence):
@@ -134,6 +134,32 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -306,6 +332,13 @@ def _validate_connector_configuration(
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
@@ -335,4 +368,5 @@ def build_confluence_client(
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
cloud=is_cloud,
)

View File

@@ -32,7 +32,11 @@ def get_user_email_from_username__server(
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
email = None
# For now, we'll just return a string that indicates failure
# We may want to revert to returning None in the future
# email = None
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
logger.warning(f"failed to get confluence email for {user_name}")
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
@@ -173,19 +177,23 @@ def extract_text_from_confluence_html(
return format_document_soup(soup)
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if attachment["metadata"]["mediaType"] in [
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
return attachment["metadata"]["mediaType"] not in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
]
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
download_link = confluence_client.url + attachment["_links"]["download"]
@@ -241,7 +249,7 @@ def build_confluence_document_id(
return f"{base_url}{content_url}"
def extract_referenced_attachment_names(page_text: str) -> list[str]:
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachments in use

View File

@@ -0,0 +1,384 @@
import io
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
import requests
from retry import retry
from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
tzinfo=timezone.utc
)
def _process_egnyte_file(
file_metadata: dict[str, Any],
file_content: IO,
base_url: str,
folder_path: str | None = None,
) -> Document | None:
"""Process an Egnyte file into a Document object
Args:
file_data: The file data from Egnyte API
file_content: The raw content of the file in bytes
base_url: The base URL for the Egnyte instance
folder_path: Optional folder path to filter results
"""
# Skip if file path doesn't match folder path filter
if folder_path and not file_metadata["path"].startswith(folder_path):
raise ValueError(
f"File path {file_metadata['path']} does not match folder path {folder_path}"
)
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None
# Extract text content based on file type
if is_text_file_extension(file_name):
encoding = detect_encoding(file_content)
file_content_raw, file_metadata = read_text_file(
file_content, encoding=encoding, ignore_danswer_metadata=False
)
else:
file_content_raw = extract_file_text(
file=file_content,
file_name=file_name,
break_on_unprocessable=True,
)
# Build the web URL for the file
web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"
# Create document metadata
metadata: dict[str, str | list[str]] = {
"file_path": file_metadata["path"],
"last_modified": file_metadata.get("last_modified", ""),
}
# Add lock info if present
if lock_info := file_metadata.get("lock_info"):
metadata[
"lock_owner"
] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"
# Create the document owners
primary_owner = None
if uploaded_by := file_metadata.get("uploaded_by"):
primary_owner = BasicExpertInfo(
email=uploaded_by, # Using username as email since that's what we have
)
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[Section(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,
doc_updated_at=(
_parse_last_modified(file_metadata["last_modified"])
if "last_modified" in file_metadata
else None
),
primary_owners=[primary_owner] if primary_owner else None,
)
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
folder_path: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.domain = "" # will always be set in `load_credentials`
self.folder_path = folder_path or "" # Root folder if not specified
self.batch_size = batch_size
self.access_token: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
f"&state={state}"
f"&response_type=code"
)
@classmethod
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
"scope": "Egnyte.filesystem",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
method="POST",
url=url,
data=data,
headers=headers,
# try a lot faster since this is a realtime flow
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.domain = credentials["domain"]
self.access_token = credentials["access_token"]
return None
def _get_files_list(
self,
path: str,
) -> list[dict[str, Any]]:
if not self.access_token or not self.domain:
raise ConnectorMissingCredentialError("Egnyte")
headers = {
"Authorization": f"Bearer {self.access_token}",
}
params: dict[str, Any] = {
"list_content": True,
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
data = response.json()
all_files: list[dict[str, Any]] = []
# Add files from current directory
all_files.extend(data.get("files", []))
# Recursively traverse folders
for item in data.get("folders", []):
all_files.extend(self._get_files_list(item["path"]))
return all_files
def _filter_files(
self,
files: list[dict[str, Any]],
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> list[dict[str, Any]]:
filtered_files = []
for file in files:
if file["is_folder"]:
continue
file_modified = _parse_last_modified(file["last_modified"])
if start_time and file_modified < start_time:
continue
if end_time and file_modified > end_time:
continue
filtered_files.append(file)
return filtered_files
def _process_files(
self,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> Generator[list[Document], None, None]:
files = self._get_files_list(self.folder_path)
files = self._filter_files(files, start_time, end_time)
current_batch: list[Document] = []
for file in files:
try:
# Set up request with streaming enabled
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
response = _request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)
if not response.ok:
logger.error(
f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
)
continue
# Stream the response content into a BytesIO buffer
buffer = io.BytesIO()
for chunk in response.iter_content(chunk_size=8192):
if chunk:
buffer.write(chunk)
# Reset buffer's position to the start
buffer.seek(0)
# Process the streamed file content
doc = _process_egnyte_file(
file_metadata=file,
file_content=buffer,
base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
folder_path=self.folder_path,
)
if doc is not None:
current_batch.append(doc)
if len(current_batch) >= self.batch_size:
yield current_batch
current_batch = []
except Exception:
logger.exception(f"Failed to process file {file['path']}")
continue
if current_batch:
yield current_batch
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._process_files()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_files(start_time=start_time, end_time=end_time)
if __name__ == "__main__":
connector = EgnyteConnector()
connector.load_credentials(
{
"domain": os.environ["EGNYTE_DOMAIN"],
"access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -15,6 +15,7 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.egnyte.connector import EgnyteConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
@@ -40,7 +41,6 @@ from danswer.connectors.salesforce.connector import SalesforceConnector
from danswer.connectors.sharepoint.connector import SharepointConnector
from danswer.connectors.slab.connector import SlabConnector
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
@@ -63,7 +63,6 @@ def identify_connector_class(
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
},
@@ -103,6 +102,7 @@ def identify_connector_class(
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import load_files_from_zip
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
@@ -50,7 +50,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), file, metadata
elif check_file_ext_is_valid(extension):
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -63,7 +63,7 @@ def _process_file(
pdf_pass: str | None = None,
) -> list[Document]:
extension = get_file_ext(file_name)
if not check_file_ext_is_valid(extension):
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return []

View File

@@ -4,11 +4,13 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (
@@ -452,12 +454,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
return retrieval_method(
drive_files = retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)
return drive_files
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
@@ -473,6 +477,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
if (
file.get("size")
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
):
logger.warning(
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
)
continue
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(

View File

@@ -16,7 +16,7 @@ logger = setup_logger()
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
"shortcutDetails, owners(emailAddress), size)"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "

View File

@@ -2,6 +2,7 @@ import abc
from collections.abc import Iterator
from typing import Any
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import SlimDocument
@@ -64,6 +65,23 @@ class SlimConnector(BaseConnector):
raise NotImplementedError
class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@@ -132,7 +132,6 @@ class LinearConnector(LoadConnector, PollConnector):
branchName
customerTicketCount
description
descriptionData
comments {
nodes {
url
@@ -215,5 +214,6 @@ class LinearConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
connector = LinearConnector()
connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -12,12 +12,15 @@ from dateutil import parser
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
@@ -28,6 +31,8 @@ logger = setup_logger()
SLAB_GRAPHQL_MAX_TRIES = 10
SLAB_API_URL = "https://api.slab.com/v1/graphql"
_SLIM_BATCH_SIZE = 1000
def run_graphql_request(
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
@@ -158,21 +163,26 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector):
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
slab_bot_token: str | None = None,
) -> None:
self.base_url = base_url
self.batch_size = batch_size
self.slab_bot_token = slab_bot_token
self._slab_bot_token: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.slab_bot_token = credentials["slab_bot_token"]
self._slab_bot_token = credentials["slab_bot_token"]
return None
@property
def slab_bot_token(self) -> str:
if self._slab_bot_token is None:
raise ConnectorMissingCredentialError("Slab")
return self._slab_bot_token
def _iterate_posts(
self, time_filter: Callable[[datetime], bool] | None = None
) -> GenerateDocumentsOutput:
@@ -227,3 +237,21 @@ class SlabConnector(LoadConnector, PollConnector):
yield from self._iterate_posts(
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):
slim_doc_batch.append(
SlimDocument(
id=post_id,
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch

View File

@@ -134,7 +134,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
def thread_to_doc(
workspace: str,
channel: ChannelType,
thread: ThreadType,
slack_cleaner: SlackTextCleaner,
@@ -171,15 +170,15 @@ def thread_to_doc(
else first_message
)
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}"
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
"\n", " "
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
sections=[
Section(
link=get_message_link(
event=m, workspace=workspace, channel_id=channel_id
),
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
for m in thread
@@ -263,7 +262,6 @@ def filter_channels(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
@@ -310,7 +308,6 @@ def _get_all_docs(
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
workspace=workspace,
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
@@ -373,14 +370,12 @@ def _get_all_doc_ids(
class SlackPollConnector(PollConnector, SlimConnector):
def __init__(
self,
workspace: str,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
@@ -414,7 +409,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
@@ -438,7 +432,6 @@ if __name__ == "__main__":
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
workspace=os.environ["SLACK_WORKSPACE"],
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})

View File

@@ -1,140 +0,0 @@
import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.connector import filter_channels
from danswer.connectors.slack.utils import get_message_link
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_event_time(event: dict[str, Any]) -> datetime | None:
ts = event.get("ts")
if not ts:
return None
return datetime.fromtimestamp(float(ts), tz=timezone.utc)
class SlackLoadConnector(LoadConnector):
# WARNING: DEPRECATED, DO NOT USE
def __init__(
self,
workspace: str,
export_path_str: str,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.export_path_str = export_path_str
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Slack Load Connector")
return None
@staticmethod
def _process_batch_event(
slack_event: dict[str, Any],
channel: dict[str, Any],
matching_doc: Document | None,
workspace: str,
) -> Document | None:
if (
slack_event["type"] == "message"
and slack_event.get("subtype") != "channel_join"
):
if matching_doc:
return Document(
id=matching_doc.id,
sections=matching_doc.sections
+ [
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=matching_doc.source,
semantic_identifier=matching_doc.semantic_identifier,
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata=matching_doc.metadata,
)
return Document(
id=slack_event["ts"],
sections=[
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=DocumentSource.SLACK,
semantic_identifier=channel["name"],
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata={},
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
export_path = Path(self.export_path_str)
with open(export_path / "channels.json") as f:
all_channels = json.load(f)
filtered_channels = filter_channels(
all_channels, self.channels, self.channel_regex_enabled
)
document_batch: dict[str, Document] = {}
for channel_info in filtered_channels:
channel_dir_path = export_path / cast(str, channel_info["name"])
channel_file_paths = [
channel_dir_path / file_name
for file_name in os.listdir(channel_dir_path)
]
for path in channel_file_paths:
with open(path) as f:
events = cast(list[dict[str, Any]], json.load(f))
for slack_event in events:
doc = self._process_batch_event(
slack_event=slack_event,
channel=channel_info,
matching_doc=document_batch.get(
slack_event.get("thread_ts", "")
),
workspace=self.workspace,
)
if doc:
document_batch[doc.id] = doc
if len(document_batch) >= self.batch_size:
yield list(document_batch.values())
yield list(document_batch.values())

View File

@@ -2,6 +2,7 @@ import re
import time
from collections.abc import Callable
from collections.abc import Generator
from functools import lru_cache
from functools import wraps
from typing import Any
from typing import cast
@@ -21,19 +22,21 @@ basic_retry_wrapper = retry_builder()
_SLACK_LIMIT = 900
@lru_cache()
def get_base_url(token: str) -> str:
"""Retrieve and cache the base URL of the Slack workspace based on the client token."""
client = WebClient(token=token)
return client.auth_test()["url"]
def get_message_link(
event: dict[str, Any], workspace: str, channel_id: str | None = None
event: dict[str, Any], client: WebClient, channel_id: str | None = None
) -> str:
channel_id = channel_id or cast(
str, event["channel"]
) # channel must either be present in the event or passed in
message_ts = cast(str, event["ts"])
message_ts_without_dot = message_ts.replace(".", "")
thread_ts = cast(str | None, event.get("thread_ts"))
return (
f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"
+ (f"?thread_ts={thread_ts}" if thread_ts else "")
)
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
def _make_slack_api_call_logged(

View File

@@ -33,7 +33,7 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = []
members = channel.members.get().execute_query()
members = channel.members.get().execute_query_retry()
for member in members:
channel_members_list.append(BasicExpertInfo(display_name=member.display_name))
return channel_members_list
@@ -51,7 +51,7 @@ def _get_threads_from_channel(
end = end.replace(tzinfo=timezone.utc)
query = channel.messages.get()
base_messages: list[ChatMessage] = query.execute_query()
base_messages: list[ChatMessage] = query.execute_query_retry()
threads: list[list[ChatMessage]] = []
for base_message in base_messages:
@@ -65,7 +65,7 @@ def _get_threads_from_channel(
continue
reply_query = base_message.replies.get_all()
replies = reply_query.execute_query()
replies = reply_query.execute_query_retry()
# start a list containing the base message and its replies
thread: list[ChatMessage] = [base_message]
@@ -82,7 +82,7 @@ def _get_channels_from_teams(
channels_list: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels = query.execute_query_retry()
channels_list.extend(channels)
return channels_list
@@ -210,7 +210,7 @@ class TeamsConnector(LoadConnector, PollConnector):
teams_list: list[Team] = []
teams = self.graph_client.teams.get().execute_query()
teams = self.graph_client.teams.get().execute_query_retry()
if len(self.requested_team_list) > 0:
adjusted_request_strings = [
@@ -234,14 +234,25 @@ class TeamsConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("Teams")
teams = self._get_all_teams()
logger.debug(f"Found available teams: {[str(t) for t in teams]}")
if not teams:
msg = "No teams found."
logger.error(msg)
raise ValueError(msg)
channels = _get_channels_from_teams(
teams=teams,
)
logger.debug(f"Found available channels: {[c.id for c in channels]}")
if not channels:
msg = "No channels found."
logger.error(msg)
raise ValueError(msg)
# goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for channel in channels:
logger.debug(f"Fetching threads from channel: {channel.id}")
thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = _convert_thread_to_document(channel, thread)
@@ -259,8 +270,8 @@ class TeamsConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)

View File

@@ -5,7 +5,11 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prune_and_merge import _merge_sections
from danswer.chat.prune_and_merge import ChunkRange
from danswer.chat.prune_and_merge import merge_chunk_intervals
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
@@ -27,10 +31,6 @@ from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger

View File

@@ -16,24 +16,31 @@ from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
from slack_sdk.models.blocks.block_elements import ImageElement
from danswer.chat.models import DanswerQuote
from danswer.chat.models import ChatDanswerBotResponse
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.context.search.models import SavedSearchDoc
from danswer.danswerbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.icons import source_to_github_img_link
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_continue_in_web_ui_id
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import remove_slack_text_interactions
from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
from danswer.db.chat import get_chat_session_by_message_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChannelConfig
from danswer.utils.text_processing import decode_escapes
from danswer.utils.text_processing import replace_whitespaces_w_space
_MAX_BLURB_LEN = 45
@@ -101,12 +108,12 @@ def _split_text(text: str, limit: int = 3000) -> list[str]:
return chunks
def clean_markdown_link_text(text: str) -> str:
def _clean_markdown_link_text(text: str) -> str:
# Remove any newlines within the text
return text.replace("\n", " ").strip()
def build_qa_feedback_block(
def _build_qa_feedback_block(
message_id: int, feedback_reminder_id: str | None = None
) -> Block:
return ActionsBlock(
@@ -115,7 +122,6 @@ def build_qa_feedback_block(
ButtonElement(
action_id=LIKE_BLOCK_ACTION_ID,
text="👍 Helpful",
style="primary",
value=feedback_reminder_id,
),
ButtonElement(
@@ -155,7 +161,7 @@ def get_document_feedback_blocks() -> Block:
)
def build_doc_feedback_block(
def _build_doc_feedback_block(
message_id: int,
document_id: str,
document_rank: int,
@@ -182,7 +188,7 @@ def get_restate_blocks(
]
def build_documents_blocks(
def _build_documents_blocks(
documents: list[SavedSearchDoc],
message_id: int | None,
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
@@ -198,7 +204,8 @@ def build_documents_blocks(
continue
seen_docs_identifiers.add(d.document_id)
doc_sem_id = d.semantic_identifier
# Strip newlines from the semantic identifier for Slackbot formatting
doc_sem_id = d.semantic_identifier.replace("\n", " ")
if d.source_type == DocumentSource.SLACK.value:
doc_sem_id = "#" + doc_sem_id
@@ -223,7 +230,7 @@ def build_documents_blocks(
feedback: ButtonElement | dict = {}
if message_id is not None:
feedback = build_doc_feedback_block(
feedback = _build_doc_feedback_block(
message_id=message_id,
document_id=d.document_id,
document_rank=rank,
@@ -241,7 +248,7 @@ def build_documents_blocks(
return section_blocks
def build_sources_blocks(
def _build_sources_blocks(
cited_documents: list[tuple[int, SavedSearchDoc]],
num_docs_to_display: int = DANSWER_BOT_NUM_DOCS_TO_DISPLAY,
) -> list[Block]:
@@ -286,7 +293,7 @@ def build_sources_blocks(
+ ([days_ago_str] if days_ago_str else [])
)
document_title = clean_markdown_link_text(doc_sem_id)
document_title = _clean_markdown_link_text(doc_sem_id)
img_link = source_to_github_img_link(d.source_type)
section_blocks.append(
@@ -317,106 +324,105 @@ def build_sources_blocks(
return section_blocks
def build_quotes_block(
quotes: list[DanswerQuote],
def _priority_ordered_documents_blocks(
answer: ChatDanswerBotResponse,
) -> list[Block]:
quote_lines: list[str] = []
doc_to_quotes: dict[str, list[str]] = {}
doc_to_link: dict[str, str] = {}
doc_to_sem_id: dict[str, str] = {}
for q in quotes:
quote = q.quote
doc_id = q.document_id
doc_link = q.link
doc_name = q.semantic_identifier
if doc_link and doc_name and doc_id and quote:
if doc_id not in doc_to_quotes:
doc_to_quotes[doc_id] = [quote]
doc_to_link[doc_id] = doc_link
doc_to_sem_id[doc_id] = (
doc_name
if q.source_type != DocumentSource.SLACK.value
else "#" + doc_name
)
else:
doc_to_quotes[doc_id].append(quote)
for doc_id, quote_strs in doc_to_quotes.items():
quotes_str_clean = [
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
]
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
link = doc_to_link[doc_id]
sem_id = doc_to_sem_id[doc_id]
quote_lines.append(
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
)
if not doc_to_quotes:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
if not priority_ordered_docs:
return []
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
document_blocks = _build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
if document_blocks:
document_blocks = [DividerBlock()] + document_blocks
return document_blocks
def build_qa_response_blocks(
message_id: int | None,
answer: str | None,
quotes: list[DanswerQuote] | None,
source_filters: list[DocumentSource] | None,
time_cutoff: datetime | None,
favor_recent: bool,
skip_quotes: bool = False,
process_message_for_citations: bool = False,
skip_ai_feedback: bool = False,
feedback_reminder_id: str | None = None,
def _build_citations_blocks(
answer: ChatDanswerBotResponse,
) -> list[Block]:
docs_response = answer.docs if answer.docs else None
top_docs = docs_response.top_documents if docs_response else []
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = _build_sources_blocks(cited_documents=cited_docs)
return citations_block
def _build_qa_response_blocks(
answer: ChatDanswerBotResponse,
process_message_for_citations: bool = False,
) -> list[Block]:
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
if DISABLE_GENERATIVE_AI:
return []
quotes_blocks: list[Block] = []
filter_block: Block | None = None
if time_cutoff or favor_recent or source_filters:
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
or retrieval_info.applied_source_filters
):
filter_text = "Filters: "
if source_filters:
sources_str = ", ".join([s.value for s in source_filters])
if retrieval_info.applied_source_filters:
sources_str = ", ".join(
[s.value for s in retrieval_info.applied_source_filters]
)
filter_text += f"`Sources in [{sources_str}]`"
if time_cutoff or favor_recent:
if (
retrieval_info.applied_time_cutoff
or retrieval_info.recency_bias_multiplier > 1
):
filter_text += " and "
if time_cutoff is not None:
time_str = time_cutoff.strftime("%b %d, %Y")
if retrieval_info.applied_time_cutoff is not None:
time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y")
filter_text += f"`Docs Updated >= {time_str}` "
if favor_recent:
if time_cutoff is not None:
if retrieval_info.recency_bias_multiplier > 1:
if retrieval_info.applied_time_cutoff is not None:
filter_text += "+ "
filter_text += "`Prioritize Recently Updated Docs`"
filter_block = SectionBlock(text=f"_{filter_text}_")
if not answer:
if not formatted_answer:
answer_blocks = [
SectionBlock(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
]
else:
answer_processed = decode_escapes(remove_slack_text_interactions(answer))
answer_processed = decode_escapes(
remove_slack_text_interactions(formatted_answer)
)
if process_message_for_citations:
answer_processed = _process_citations_for_slack(answer_processed)
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
if quotes:
quotes_blocks = build_quotes_block(quotes)
# if no quotes OR `build_quotes_block()` did not give back any blocks
if not quotes_blocks:
quotes_blocks = [
SectionBlock(
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
)
]
response_blocks: list[Block] = []
@@ -425,20 +431,34 @@ def build_qa_response_blocks(
response_blocks.extend(answer_blocks)
if message_id is not None and not skip_ai_feedback:
response_blocks.append(
build_qa_feedback_block(
message_id=message_id, feedback_reminder_id=feedback_reminder_id
)
)
if not skip_quotes:
response_blocks.extend(quotes_blocks)
return response_blocks
def build_follow_up_block(message_id: int | None) -> ActionsBlock:
def _build_continue_in_web_ui_block(
tenant_id: str | None,
message_id: int | None,
) -> Block:
if message_id is None:
raise ValueError("No message id provided to build continue in web ui block")
with get_session_with_tenant(tenant_id) as db_session:
chat_session = get_chat_session_by_message_id(
db_session=db_session,
message_id=message_id,
)
return ActionsBlock(
block_id=build_continue_in_web_ui_id(message_id),
elements=[
ButtonElement(
action_id=CONTINUE_IN_WEB_UI_ACTION_ID,
text="Continue Chat in Danswer!",
style="primary",
url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}",
),
],
)
def _build_follow_up_block(message_id: int | None) -> ActionsBlock:
return ActionsBlock(
block_id=build_feedback_id(message_id) if message_id is not None else None,
elements=[
@@ -483,3 +503,75 @@ def build_follow_up_resolved_blocks(
]
)
return [text_block, button_block]
def build_slack_response_blocks(
answer: ChatDanswerBotResponse,
tenant_id: str | None,
message_info: SlackMessageInfo,
channel_conf: ChannelConfig | None,
use_citations: bool,
feedback_reminder_id: str | None,
skip_ai_feedback: bool = False,
) -> list[Block]:
"""
This function is a top level function that builds all the blocks for the Slack response.
It also handles combining all the blocks together.
"""
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
)
answer_blocks = _build_qa_response_blocks(
answer=answer,
process_message_for_citations=use_citations,
)
web_follow_up_block = []
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
web_follow_up_block.append(
_build_continue_in_web_ui_block(
tenant_id=tenant_id,
message_id=answer.chat_message_id,
)
)
follow_up_block = []
if channel_conf and channel_conf.get("follow_up_tags") is not None:
follow_up_block.append(
_build_follow_up_block(message_id=answer.chat_message_id)
)
ai_feedback_block = []
if answer.chat_message_id is not None and not skip_ai_feedback:
ai_feedback_block.append(
_build_qa_feedback_block(
message_id=answer.chat_message_id,
feedback_reminder_id=feedback_reminder_id,
)
)
citations_blocks = []
document_blocks = []
if use_citations and answer.citations:
citations_blocks = _build_citations_blocks(answer)
else:
document_blocks = _priority_ordered_documents_blocks(answer)
citations_divider = [DividerBlock()] if citations_blocks else []
buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else []
all_blocks = (
restate_question_block
+ answer_blocks
+ ai_feedback_block
+ citations_divider
+ citations_blocks
+ document_blocks
+ buttons_divider
+ web_follow_up_block
+ follow_up_block
)
return all_blocks

View File

@@ -2,6 +2,7 @@ from enum import Enum
LIKE_BLOCK_ACTION_ID = "feedback-like"
DISLIKE_BLOCK_ACTION_ID = "feedback-dislike"
CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui"
FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"

View File

@@ -28,7 +28,7 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import build_feedback_id
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import fetch_group_ids_from_names
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
@@ -267,7 +267,7 @@ def handle_followup_button(
tag_names = slack_channel_config.channel_config.get("follow_up_tags")
remaining = None
if tag_names:
tag_ids, remaining = fetch_user_ids_from_emails(
tag_ids, remaining = fetch_slack_user_ids_from_emails(
tag_names, client.web_client
)
if remaining:

View File

@@ -13,7 +13,7 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_slack_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
@@ -184,7 +184,7 @@ def handle_message(
send_to: list[str] | None = None
missing_users: list[str] | None = None
if respond_member_group_list:
send_to, missing_ids = fetch_user_ids_from_emails(
send_to, missing_ids = fetch_slack_user_ids_from_emails(
respond_member_group_list, client
)

View File

@@ -1,60 +1,43 @@
import functools
from collections.abc import Callable
from typing import Any
from typing import cast
from typing import Optional
from typing import TypeVar
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from danswer.chat.chat_utils import prepare_chat_message_request
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.process_message import gather_stream_for_slack
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import Persona
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.persona import fetch_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.users import get_user_by_email
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.utils.logger import DanswerLoggingAdapter
srl = SlackRateLimiter()
RT = TypeVar("RT") # return type
@@ -89,16 +72,14 @@ def handle_regular_answer(
feedback_reminder_id: str | None,
tenant_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
) -> bool:
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
user = None
@@ -108,9 +89,18 @@ def handle_regular_answer(
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
persona = slack_channel_config.persona if slack_channel_config else None
prompt = None
if persona:
# If no persona is specified, use the default search based persona
# This way slack flow always has a persona
persona = slack_channel_config.persona if slack_channel_config else None
if not persona:
with get_session_with_tenant(tenant_id) as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
document_set_names = [
document_set.name for document_set in persona.document_sets
]
prompt = persona.prompts[0] if persona.prompts else None
else:
document_set_names = [
document_set.name for document_set in persona.document_sets
]
@@ -118,6 +108,26 @@ def handle_regular_answer(
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
# TODO: Add in support for Slack to truncate messages based on max LLM context
# llm, _ = get_llms_for_persona(persona)
# llm_tokenizer = get_tokenizer(
# model_name=llm.config.model_name,
# provider_type=llm.config.model_provider,
# )
# # In cases of threads, split the available tokens between docs and thread context
# input_tokens = get_max_input_tokens(
# model_name=llm.config.model_name,
# model_provider=llm.config.model_provider,
# )
# max_history_tokens = int(input_tokens * thread_context_percent)
# combined_message = combine_message_thread(
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
# )
combined_message = slackify_message_thread(messages)
bypass_acl = False
if (
slack_channel_config
@@ -128,13 +138,6 @@ def handle_regular_answer(
# with non-public document sets
bypass_acl = True
# figure out if we want to use citations or quotes
use_citations = (
not DANSWER_BOT_USE_QUOTES
if slack_channel_config is None
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
)
if not message_ts_to_respond_to and not is_bot_msg:
# if the message is not "/danswer" command, then it should have a message ts to respond to
raise RuntimeError(
@@ -147,75 +150,23 @@ def handle_regular_answer(
backoff=2,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
max_document_tokens: int | None = None
max_history_tokens: int | None = None
def _get_slack_answer(
new_message_request: CreateChatMessageRequest, danswer_user: User | None
) -> ChatDanswerBotResponse:
with get_session_with_tenant(tenant_id) as db_session:
if len(new_message_request.messages) > 1:
if new_message_request.persona_config:
raise RuntimeError("Slack bot does not support persona config")
elif new_message_request.persona_id is not None:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
else:
raise RuntimeError(
"No persona id provided, this should never happen."
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name,
model_provider=llm.config.model_provider,
)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
if DISABLE_GENERATIVE_AI:
return None
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
packets = stream_chat_message_objects(
new_msg_req=new_message_request,
user=danswer_user,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
use_citations=use_citations,
danswerbot_flow=True,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
answer = gather_stream_for_slack(packets)
if answer.error_msg:
raise RuntimeError(answer.error_msg)
return answer
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
@@ -245,26 +196,24 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
# Always apply reranking settings if it exists, this is the non-streaming flow
with get_session_with_tenant(tenant_id) as db_session:
saved_search_settings = get_current_search_settings(db_session)
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
multilingual_query_expansion=saved_search_settings.multilingual_expansion
if saved_search_settings
else None,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
if saved_search_settings
else None,
answer_request = prepare_chat_message_request(
message_text=combined_message,
user=user,
persona_id=persona.id,
# This is not used in the Slack flow, only in the answer API
persona_override_config=None,
prompt=prompt,
message_ts_to_respond_to=message_ts_to_respond_to,
retrieval_details=retrieval_details,
rerank_settings=None, # Rerank customization supported in Slack flow
db_session=db_session,
)
answer = _get_slack_answer(
new_message_request=answer_request, danswer_user=user
)
except Exception as e:
logger.exception(
f"Unable to process message - did not successfully answer "
@@ -365,7 +314,7 @@ def handle_regular_answer(
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{answer.rephrase}' - no documents found"
f"Unable to answer question: '{combined_message}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
@@ -386,18 +335,18 @@ def handle_regular_answer(
)
return True
only_respond_with_citations_or_quotes = (
only_respond_if_citations = (
channel_conf
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
)
has_citations_or_quotes = bool(answer.citations or answer.quotes)
if (
only_respond_with_citations_or_quotes
and not has_citations_or_quotes
only_respond_if_citations
and not answer.citations
and not message_info.bypass_filters
):
logger.error(
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
@@ -411,67 +360,22 @@ def handle_regular_answer(
)
return True
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=formatted_answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,
favor_recent=retrieval_info.recency_bias_multiplier > 1,
# currently Personas don't support quotes
# if citations are enabled, also don't use quotes
skip_quotes=persona is not None or use_citations,
process_message_for_citations=use_citations,
all_blocks = build_slack_response_blocks(
tenant_id=tenant_id,
message_info=message_info,
answer=answer,
channel_conf=channel_conf,
use_citations=True, # No longer supporting quotes
feedback_reminder_id=feedback_reminder_id,
)
# Get the chunks fed to the LLM only, then fill with other docs
llm_doc_inds = answer.llm_selected_doc_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = []
citations_block = []
# if citations are enabled, only show cited documents
if use_citations:
citations = answer.citations or []
cited_docs = []
for citation in citations:
matching_doc = next(
(d for d in top_docs if d.document_id == citation.document_id),
None,
)
if matching_doc:
cited_docs.append((citation.citation_num, matching_doc))
cited_docs.sort()
citations_block = build_sources_blocks(cited_documents=cited_docs)
elif priority_ordered_docs:
document_blocks = build_documents_blocks(
documents=priority_ordered_docs,
message_id=answer.chat_message_id,
)
document_blocks = [DividerBlock()] + document_blocks
all_blocks = (
restate_question_block + answer_blocks + citations_block + document_blocks
)
if channel_conf and channel_conf.get("follow_up_tags") is not None:
all_blocks.append(build_follow_up_block(message_id=answer.chat_message_id))
try:
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
receiver_ids=[message_info.sender]
if message_info.is_bot_msg and message_info.sender
else receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,

View File

@@ -1,8 +1,33 @@
from slack_sdk import WebClient
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import MessageType
from danswer.danswerbot.slack.utils import respond_in_thread
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
# Note: this does not handle extremely long threads, every message will be included
# with weaker LLMs, this could cause issues with exceeeding the token limit
if not messages:
return ""
message_strs: list[str] = []
for message in messages:
if message.role == MessageType.USER:
message_text = (
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
)
elif message.role == MessageType.ASSISTANT:
message_text = f"AI said in Slack:\n{message.message}"
else:
message_text = (
f"{message.role.value.upper()} said in Slack:\n{message.message}"
)
message_strs.append(message_text)
return "\n\n".join(message_strs)
def send_team_member_message(
client: WebClient,
channel: str,

View File

@@ -19,6 +19,8 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.chat.models import ThreadMessage
from danswer.configs.app_configs import DEV_MODE
from danswer.configs.app_configs import POD_NAME
from danswer.configs.app_configs import POD_NAMESPACE
from danswer.configs.constants import DanswerRedisLocks
@@ -74,7 +76,6 @@ from danswer.db.slack_bot import fetch_slack_bots
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.redis.redis_pool import get_redis_client
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@@ -250,7 +251,7 @@ class SlackbotHandler:
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
if not acquired:
if not acquired and not DEV_MODE:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue

View File

@@ -1,6 +1,6 @@
from pydantic import BaseModel
from danswer.one_shot_answer.models import ThreadMessage
from danswer.chat.models import ThreadMessage
class SlackMessageInfo(BaseModel):

View File

@@ -3,14 +3,15 @@ import random
import re
import string
import time
import uuid
from typing import Any
from typing import cast
from typing import Optional
from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.metadata import Metadata
from slack_sdk.socket_mode import SocketModeClient
@@ -30,13 +31,13 @@ from danswer.configs.danswerbot_configs import (
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.models import ThreadMessage
from danswer.db.engine import get_session_with_tenant
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.one_shot_answer.models import ThreadMessage
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
@@ -140,6 +141,40 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
def _check_for_url_in_block(block: Block) -> bool:
"""
Check if the block has a key that contains "url" in it
"""
block_dict = block.to_dict()
def check_dict_for_url(d: dict) -> bool:
for key, value in d.items():
if "url" in key.lower():
return True
if isinstance(value, dict):
if check_dict_for_url(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_dict_for_url(item):
return True
return False
return check_dict_for_url(block_dict)
def _build_error_block(error_message: str) -> Block:
"""
Build an error block to display in slack so that the user can see
the error without completely breaking
"""
display_text = (
"There was an error displaying all of the Onyx answers."
f" Please let an admin or an onyx developer know. Error: {error_message}"
)
return SectionBlock(text=display_text)
@retry(
tries=DANSWER_BOT_NUM_RETRIES,
delay=0.25,
@@ -162,24 +197,9 @@ def respond_in_thread(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
response = slack_call(
channel=channel,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
@@ -187,8 +207,68 @@ def respond_in_thread(
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
raise e
blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
channel=channel,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
raise e
blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
message_ids.append(response["message_ts"])
return message_ids
@@ -216,6 +296,13 @@ def build_feedback_id(
return unique_prefix + ID_SEPARATOR + feedback_id
def build_continue_in_web_ui_id(
message_id: int,
) -> str:
unique_prefix = str(uuid.uuid4())[:10]
return unique_prefix + ID_SEPARATOR + str(message_id)
def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]:
"""Decompose into query_id, document_id, document_rank, see above function"""
try:
@@ -313,7 +400,7 @@ def get_channel_name_from_id(
raise e
def fetch_user_ids_from_emails(
def fetch_slack_user_ids_from_emails(
user_emails: list[str], client: WebClient
) -> tuple[list[str], list[str]]:
user_ids: list[str] = []
@@ -522,7 +609,7 @@ class SlackRateLimiter:
self.last_reset_time = time.time()
def notify(
self, client: WebClient, channel: str, position: int, thread_ts: Optional[str]
self, client: WebClient, channel: str, position: int, thread_ts: str | None
) -> None:
respond_in_thread(
client=client,

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
@@ -30,6 +31,7 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_best_persona_id_for_user
from danswer.db.pg_file_store import delete_lobj_by_name
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
@@ -143,16 +145,10 @@ def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
only_one_shot: bool = False,
limit: int = 50,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if only_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(True))
else:
stmt = stmt.where(ChatSession.one_shot.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
@@ -224,12 +220,11 @@ def delete_messages_and_files_from_chat_session(
def create_chat_session(
db_session: Session,
description: str,
description: str | None,
user_id: UUID | None,
persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
danswerbot_flow: bool = False,
slack_thread_id: str | None = None,
) -> ChatSession:
@@ -239,7 +234,6 @@ def create_chat_session(
description=description,
llm_override=llm_override,
prompt_override=prompt_override,
one_shot=one_shot,
danswerbot_flow=danswerbot_flow,
slack_thread_id=slack_thread_id,
)
@@ -250,6 +244,48 @@ def create_chat_session(
return chat_session
def duplicate_chat_session_for_user_from_slack(
db_session: Session,
user: User | None,
chat_session_id: UUID,
) -> ChatSession:
"""
This takes a chat session id for a session in Slack and:
- Creates a new chat session in the DB
- Tries to copy the persona from the original chat session
(if it is available to the user clicking the button)
- Sets the user to the given user (if provided)
"""
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=None, # Ignore user permissions for this
db_session=db_session,
)
if not chat_session:
raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")
# This enforces permissions and sets a default
new_persona_id = get_best_persona_id_for_user(
db_session=db_session,
user=user,
persona_id=chat_session.persona_id,
)
return create_chat_session(
db_session=db_session,
user_id=user.id if user else None,
persona_id=new_persona_id,
# Set this to empty string so the frontend will force a rename
description="",
llm_override=chat_session.llm_override,
prompt_override=chat_session.prompt_override,
# Chat is in UI now so this is false
danswerbot_flow=False,
# Maybe we want this in the future to track if it was created from Slack
slack_thread_id=None,
)
def update_chat_session(
db_session: Session,
user_id: UUID | None,
@@ -336,6 +372,28 @@ def get_chat_message(
return chat_message
def get_chat_session_by_message_id(
db_session: Session,
message_id: int,
) -> ChatSession:
"""
Should only be used for Slack
Get the chat session associated with a specific message ID
Note: this ignores permission checks.
"""
stmt = select(ChatMessage).where(ChatMessage.id == message_id)
result = db_session.execute(stmt)
chat_message = result.scalar_one_or_none()
if chat_message is None:
raise ValueError(
f"Unable to find chat session associated with message ID: {message_id}"
)
return chat_message.chat_session
def get_chat_messages_by_sessions(
chat_session_ids: list[UUID],
user_id: UUID | None,
@@ -355,6 +413,44 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()
def add_chats_to_session_from_slack_thread(
db_session: Session,
slack_chat_session_id: UUID,
new_chat_session_id: UUID,
) -> None:
new_root_message = get_or_create_root_message(
chat_session_id=new_chat_session_id,
db_session=db_session,
)
for chat_message in get_chat_messages_by_sessions(
chat_session_ids=[slack_chat_session_id],
user_id=None, # Ignore user permissions for this
db_session=db_session,
skip_permission_check=True,
):
if chat_message.message_type == MessageType.SYSTEM:
continue
# Duplicate the message
new_root_message = create_new_chat_message(
db_session=db_session,
chat_session_id=new_chat_session_id,
parent_message=new_root_message,
message=chat_message.message,
files=chat_message.files,
rephrased_query=chat_message.rephrased_query,
error=chat_message.error,
citations=chat_message.citations,
reference_docs=chat_message.search_docs,
tool_call=chat_message.tool_call,
prompt_id=chat_message.prompt_id,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.enums import IndexingMode
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -311,3 +312,25 @@ def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int)
# If this changes, we need to update this function.
cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
db_session.commit()
def mark_ccpair_with_indexing_trigger(
cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
"""indexing_mode sets a field which will be picked up by a background task
to trigger indexing. Set to None to disable the trigger."""
try:
cc_pair = db_session.execute(
select(ConnectorCredentialPair)
.where(ConnectorCredentialPair.id == cc_pair_id)
.with_for_update()
).scalar_one()
if cc_pair is None:
raise ValueError(f"No cc_pair with ID: {cc_pair_id}")
cc_pair.indexing_trigger = indexing_mode
db_session.commit()
except Exception:
db_session.rollback()
raise

View File

@@ -324,8 +324,11 @@ def associate_default_cc_pair(db_session: Session) -> None:
def _relate_groups_to_cc_pair__no_commit(
db_session: Session,
cc_pair_id: int,
user_group_ids: list[int],
user_group_ids: list[int] | None = None,
) -> None:
if not user_group_ids:
return
for group_id in user_group_ids:
db_session.add(
UserGroup__ConnectorCredentialPair(
@@ -402,12 +405,11 @@ def add_credential_to_connector(
db_session.flush() # make sure the association has an id
db_session.refresh(association)
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
user_group_ids=groups,
)
db_session.commit()

View File

@@ -20,7 +20,6 @@ from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger
@@ -248,7 +247,6 @@ def create_credential(
)
db_session.commit()
return credential
@@ -263,7 +261,8 @@ def _cleanup_credential__user_group_relationships__no_commit(
def alter_credential(
credential_id: int,
credential_data: CredentialDataUpdateRequest,
name: str,
credential_json: dict[str, Any],
user: User,
db_session: Session,
) -> Credential | None:
@@ -273,11 +272,13 @@ def alter_credential(
if credential is None:
return None
credential.name = credential_data.name
credential.name = name
# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value
# Assign a new dictionary to credential.credential_json
credential.credential_json = {
**credential.credential_json,
**credential_json,
}
credential.user_id = user.id if user is not None else None
db_session.commit()
@@ -310,8 +311,8 @@ def update_credential_json(
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
return None
credential.credential_json = credential_json
credential.credential_json = credential_json
db_session.commit()
return credential

View File

@@ -37,6 +37,7 @@ from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -426,7 +427,9 @@ def get_session() -> Generator[Session, None, None]:
"""Generate a database session with the appropriate tenant schema set."""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise HTTPException(status_code=401, detail="User must authenticate")
raise BasicAuthenticationError(
detail="User must authenticate",
)
engine = get_sqlalchemy_engine()

View File

@@ -19,6 +19,11 @@ class IndexingStatus(str, PyEnum):
return self in terminal_states
class IndexingMode(str, PyEnum):
UPDATE = "update"
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"

View File

@@ -522,12 +522,16 @@ def expire_index_attempts(
search_settings_id: int,
db_session: Session,
) -> None:
delete_query = (
delete(IndexAttempt)
not_started_query = (
update(IndexAttempt)
.where(IndexAttempt.search_settings_id == search_settings_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled, likely due to model swap",
)
)
db_session.execute(delete_query)
db_session.execute(not_started_query)
update_query = (
update(IndexAttempt)
@@ -549,9 +553,14 @@ def cancel_indexing_attempts_for_ccpair(
include_secondary_index: bool = False,
) -> None:
stmt = (
delete(IndexAttempt)
update(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled by user",
time_started=datetime.now(timezone.utc),
)
)
if not include_secondary_index:

View File

@@ -1,202 +0,0 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()
return input_prompt
def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")
input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active
db_session.commit()
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")
db_session.delete(input_prompt)
db_session.commit()
def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")
db_session.delete(input_prompt)
db_session.commit()
def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)
if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa
result = db_session.scalar(query)
if result is None:
raise HTTPException(422, "No input prompt found")
return result
def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())
def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)
if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)
return list(db_session.scalars(query).all())

View File

@@ -1,6 +1,5 @@
import datetime
import json
from enum import Enum as PyEnum
from typing import Any
from typing import Literal
from typing import NotRequired
@@ -42,7 +41,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType
from danswer.db.enums import AccessType, IndexingMode
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -126,6 +125,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -159,9 +159,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -178,31 +175,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
class InputPrompt(Base):
__tablename__ = "inputprompt"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -438,6 +410,10 @@ class ConnectorCredentialPair(Base):
total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
indexing_trigger: Mapped[IndexingMode | None] = mapped_column(
Enum(IndexingMode, native_enum=False), nullable=True
)
connector: Mapped["Connector"] = relationship(
"Connector", back_populates="credentials"
)
@@ -592,6 +568,25 @@ class Connector(Base):
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema
# TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks
# https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html
def validate_refresh_freq(self) -> None:
if self.refresh_freq is not None:
if self.refresh_freq < 60:
raise ValueError(
"refresh_freq must be greater than or equal to 60 seconds."
)
def validate_prune_freq(self) -> None:
if self.prune_freq is not None:
if self.prune_freq < 86400:
raise ValueError(
"prune_freq must be greater than or equal to 86400 seconds."
)
class Credential(Base):
__tablename__ = "credential"
@@ -959,9 +954,8 @@ class ChatSession(Base):
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
# This chat created by DanswerBot
danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
# Only ever set to True if system is set to not hard-delete chats
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -1480,18 +1474,16 @@ class ChannelConfig(TypedDict):
# If None then no follow up
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
show_continue_in_web_ui: NotRequired[bool] # defaults to False
class SlackChannelConfig(Base):
__tablename__ = "slack_channel_config"
id: Mapped[int] = mapped_column(primary_key=True)
slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True)
slack_bot_id: Mapped[int] = mapped_column(
ForeignKey("slack_bot.id"), nullable=False
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
@@ -1499,9 +1491,6 @@ class SlackChannelConfig(Base):
channel_config: Mapped[ChannelConfig] = mapped_column(
postgresql.JSONB(), nullable=False
)
response_type: Mapped[SlackBotResponseType] = mapped_column(
Enum(SlackBotResponseType, native_enum=False), nullable=False
)
enable_auto_filters: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
@@ -1532,6 +1521,7 @@ class SlackBot(Base):
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
back_populates="slack_bot",
cascade="all, delete-orphan",
)

View File

@@ -113,6 +113,31 @@ def fetch_persona_by_id(
return persona
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
if persona_id is not None:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(
stmt=stmt,
user=user,
# We don't want to filter by editable here, we just want to see if the
# persona is usable by the user
get_editable=False,
)
persona = db_session.scalars(stmt).one_or_none()
if persona:
return persona.id
# If the persona is not found, or the slack bot is using doc sets instead of personas,
# we need to find the best persona for the user
# This is the persona with the highest display priority that the user has access to
stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
persona = db_session.scalars(stmt).one_or_none()
return persona.id if persona else None
def _get_persona_by_name(
persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
@@ -160,7 +185,7 @@ def create_update_persona(
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.dict(exclude={"users", "groups"}),
**create_persona_request.model_dump(exclude={"users", "groups"}),
}
persona = upsert_persona(**persona_data)
@@ -390,9 +415,6 @@ def upsert_prompt(
return prompt
# NOTE: This operation cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
def upsert_persona(
user: User | None,
name: str,
@@ -424,10 +446,16 @@ def upsert_persona(
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
"""
NOTE: This operation cannot update persona configuration options that
are core to the persona, such as its display priority and
whether or not the assistant is a built-in / default assistant
"""
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = _get_persona_by_name(
existing_persona = _get_persona_by_name(
persona_name=name, user=user, db_session=db_session
)
@@ -453,57 +481,78 @@ def upsert_persona(
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")
if prompts is not None and len(prompts) == 0:
raise ValueError(
f"Invalid Persona config, no valid prompts "
f"specified. Specified IDs were: '{prompt_ids}'"
)
# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)
if persona:
if persona.builtin_persona and not builtin_persona:
if existing_persona:
# Built-in personas can only be updated through YAML configuration.
# This ensures that core system personas are not modified unintentionally.
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
persona.name = name
persona.description = description
persona.num_chunks = num_chunks
persona.chunks_above = chunks_above
persona.chunks_below = chunks_below
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
persona.icon_color = icon_color
persona.icon_shape = icon_shape
# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
existing_persona.name = name
existing_persona.description = description
existing_persona.num_chunks = num_chunks
existing_persona.chunks_above = chunks_above
existing_persona.chunks_below = chunks_below
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.starter_messages = starter_messages
existing_persona.deleted = False # Un-delete if previously deleted
existing_persona.is_public = is_public
existing_persona.icon_color = icon_color
existing_persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.category_id = category_id
existing_persona.uploaded_image_id = uploaded_image_id
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets or []
existing_persona.document_sets.clear()
existing_persona.document_sets = document_sets or []
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts or []
existing_persona.prompts.clear()
existing_persona.prompts = prompts
if tools is not None:
persona.tools = tools or []
existing_persona.tools = tools or []
persona = existing_persona
else:
persona = Persona(
if not prompts:
raise ValueError(
"Invalid Persona config. "
"Must specify at least one prompt for a new persona."
)
new_persona = Persona(
id=persona_id,
user_id=user.id if user else None,
is_public=is_public,
@@ -516,7 +565,7 @@ def upsert_persona(
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
builtin_persona=builtin_persona,
prompts=prompts or [],
prompts=prompts,
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
@@ -531,8 +580,8 @@ def upsert_persona(
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)
db_session.add(new_persona)
persona = new_persona
if commit:
db_session.commit()
else:
@@ -733,6 +782,8 @@ def get_prompt_by_name(
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result

View File

@@ -143,6 +143,25 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
return latest_settings
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
"""Returns active search settings. The first entry will always be the current search
settings. If there are new search settings that are being migrated to, those will be
the second entry."""
search_settings_list: list[SearchSettings] = []
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings_list.append(primary_search_settings)
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings_list.append(secondary_search_settings)
return search_settings_list
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
query = select(SearchSettings).order_by(SearchSettings.id.desc())
result = db_session.execute(query)

View File

@@ -10,7 +10,6 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.models import ChannelConfig
from danswer.db.models import Persona
from danswer.db.models import Persona__DocumentSet
from danswer.db.models import SlackBotResponseType
from danswer.db.models import SlackChannelConfig
from danswer.db.models import User
from danswer.db.persona import get_default_prompt
@@ -83,7 +82,6 @@ def insert_slack_channel_config(
slack_bot_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
) -> SlackChannelConfig:
@@ -115,7 +113,6 @@ def insert_slack_channel_config(
slack_bot_id=slack_bot_id,
persona_id=persona_id,
channel_config=channel_config,
response_type=response_type,
standard_answer_categories=existing_standard_answer_categories,
enable_auto_filters=enable_auto_filters,
)
@@ -130,7 +127,6 @@ def update_slack_channel_config(
slack_channel_config_id: int,
persona_id: int | None,
channel_config: ChannelConfig,
response_type: SlackBotResponseType,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
) -> SlackChannelConfig:
@@ -170,7 +166,6 @@ def update_slack_channel_config(
# will encounter `violates foreign key constraint` errors
slack_channel_config.persona_id = persona_id
slack_channel_config.channel_config = channel_config
slack_channel_config.response_type = response_type
slack_channel_config.standard_answer_categories = list(
existing_standard_answer_categories
)

View File

@@ -148,6 +148,7 @@ class Indexable(abc.ABC):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
fresh_index: bool = False,
) -> set[DocumentInsertionRecord]:
"""
Takes a list of document chunks and indexes them in the document index
@@ -165,9 +166,14 @@ class Indexable(abc.ABC):
only needs to index chunks into the PRIMARY index. Do not update the secondary index here,
it is done automatically outside of this code.
NOTE: The fresh_index parameter, when set to True, assumes no documents have been previously
indexed for the given index/tenant. This can be used to optimize the indexing process for
new or empty indices.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- fresh_index: Boolean indicating whether this is a fresh index with no existing documents.
Returns:
List of document ids which map to unique documents and are used for deduping chunks

View File

@@ -4,6 +4,8 @@ schema DANSWER_CHUNK_NAME {
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute
attribute: fast-search
rank: filter
}
field chunk_id type int {
indexing: summary | attribute

View File

@@ -306,6 +306,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
fresh_index: bool = False,
) -> set[DocumentInsertionRecord]:
"""Receive a list of chunks from a batch of documents and index the chunks into Vespa along
with updating the associated permissions. Assumes that a document will not be split into
@@ -322,26 +323,29 @@ class VespaIndex(DocumentIndex):
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
get_vespa_http_client() as http_client,
):
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [chunk for chunk in cleaned_chunks if chunk.chunk_id == 0]
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
existing_docs.update(
get_existing_documents_from_chunks(
chunks=chunk_batch,
if not fresh_index:
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [
chunk for chunk in cleaned_chunks if chunk.chunk_id == 0
]
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
existing_docs.update(
get_existing_documents_from_chunks(
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
)
for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
)
for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(

View File

@@ -6,6 +6,7 @@ import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from email.parser import Parser as EmailParser
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
@@ -15,13 +16,17 @@ import chardet
import docx # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from docx import Document
from fastapi import UploadFile
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.configs.constants import FileOrigin
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.file_store.file_store import FileStore
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -65,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
return extension
def check_file_ext_is_valid(ext: str) -> bool:
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
@@ -295,7 +300,7 @@ def pptx_to_text(file: IO[Any]) -> str:
def xlsx_to_text(file: IO[Any]) -> str:
workbook = openpyxl.load_workbook(file)
workbook = openpyxl.load_workbook(file, read_only=True)
text_content = []
for sheet in workbook.worksheets:
sheet_string = "\n".join(
@@ -359,7 +364,7 @@ def extract_file_text(
elif file_name is not None:
final_extension = get_file_ext(file_name)
if check_file_ext_is_valid(final_extension):
if is_valid_file_ext(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)
# Either the file somehow has no name or the extension is not one that we recognize
@@ -375,3 +380,35 @@ def extract_file_text(
) from e
logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
return ""
def convert_docx_to_txt(
file: UploadFile, file_store: FileStore, file_path: str
) -> None:
file.file.seek(0)
docx_content = file.file.read()
doc = Document(BytesIO(docx_content))
# Extract text from the document
full_text = []
for para in doc.paragraphs:
full_text.append(para.text)
# Join the extracted text
text_content = "\n".join(full_text)
txt_file_path = docx_to_txt_filename(file_path)
file_store.save_file(
file_name=txt_file_path,
content=BytesIO(text_content.encode("utf-8")),
display_name=file.filename,
file_origin=FileOrigin.CONNECTOR,
file_type="text/plain",
)
def docx_to_txt_filename(file_path: str) -> str:
"""
Convert a .docx file path to its corresponding .txt file path.
"""
return file_path.rsplit(".", 1)[0] + ".txt"

View File

@@ -59,6 +59,12 @@ class FileStore(ABC):
Contents of the file and metadata dict
"""
@abstractmethod
def read_file_record(self, file_name: str) -> PGFileStore:
"""
Read the file record by the name
"""
@abstractmethod
def delete_file(self, file_name: str) -> None:
"""

View File

@@ -1,6 +1,6 @@
import base64
from collections.abc import Callable
from io import BytesIO
from typing import Any
from typing import cast
from uuid import uuid4
@@ -13,8 +13,8 @@ from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.b64 import get_image_type
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def load_chat_file(
@@ -75,11 +75,58 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
return unique_id
def save_files_from_urls(urls: list[str]) -> list[str]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
with get_session_with_tenant(tenant_id) as db_session:
unique_id = str(uuid4())
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=unique_id,
content=BytesIO(base64.b64decode(base64_string)),
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type=get_image_type(base64_string),
)
return unique_id
funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
(save_file_from_url, (url, tenant_id)) for url in urls
def save_file(
tenant_id: str,
url: str | None = None,
base64_data: str | None = None,
) -> str:
"""Save a file from either a URL or base64 encoded string.
Args:
tenant_id: The tenant ID to save the file under
url: URL to download file from
base64_data: Base64 encoded file data
Returns:
The unique ID of the saved file
Raises:
ValueError: If neither url nor base64_data is provided, or if both are provided
"""
if url is not None and base64_data is not None:
raise ValueError("Cannot specify both url and base64_data")
if url is not None:
return save_file_from_url(url, tenant_id)
elif base64_data is not None:
return save_file_from_base64(base64_data, tenant_id)
else:
raise ValueError("Must specify either url or base64_data")
def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
# NOTE: be explicit about typing so that if we change things, we get notified
funcs: list[
tuple[
Callable[[str, str | None, str | None], str],
tuple[str, str | None, str | None],
]
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
]
# Must pass in tenant_id here, since this is called by multithreading
return run_functions_tuples_in_parallel(funcs)

View File

@@ -1,4 +1,5 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol
@@ -12,6 +13,7 @@ from danswer.access.access import get_access_for_documents
from danswer.access.models import DocumentAccess
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.configs.constants import DEFAULT_BOOST
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@@ -202,40 +204,13 @@ def index_doc_batch_with_handler(
def index_doc_batch_prepare(
document_batch: list[Document],
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
@@ -282,17 +257,64 @@ def index_doc_batch_prepare(
)
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
section_chars = sum(len(section.text) for section in document.sections)
if (
MAX_DOCUMENT_CHARS
and len(document.title or document.semantic_identifier) + section_chars
> MAX_DOCUMENT_CHARS
):
# Skip documents that are too long, later on there are more memory intensive steps done on the text
# and the container will run out of memory and crash. Several other checks are included upstream but
# those are at the connector level so a catchall is still needed.
# Assumption here is that files that are that long, are generated files and not the type users
# generally care for.
logger.warning(
f"Skipping document with ID {document.id} as it is too long."
)
continue
documents.append(document)
return documents
@log_function_time(debug_only=True)
def index_doc_batch(
*,
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@@ -309,8 +331,11 @@ def index_doc_batch(
is_public=False,
)
logger.debug("Filtering Documents")
filtered_documents = filter_fnc(document_batch)
ctx = index_doc_batch_prepare(
document_batch=document_batch,
documents=filtered_documents,
index_attempt_metadata=index_attempt_metadata,
ignore_time_skip=ignore_time_skip,
db_session=db_session,

View File

@@ -1,163 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class PreviousMessage(BaseModel):
"""Simplified version of `ChatMessage`"""
message: str
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
) -> "PreviousMessage":
message_file_ids = (
[file["id"] for file in chat_message.files] if chat_message.files else []
)
return cls(
message=chat_message.message,
token_count=chat_message.token_count,
message_type=chat_message.message_type,
files=[
file
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:
content = build_content_with_imgs(self.message, self.files)
if self.message_type == MessageType.USER:
return HumanMessage(content=content)
elif self.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
else:
return SystemMessage(content=content)
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize
# the document content, which adds to the token count.
using_tool_message: bool = False
class ContextualPruningConfig(DocumentPruningConfig):
num_chunk_multiple: int
@classmethod
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
model_config = ConfigDict(frozen=True)

View File

@@ -1,20 +0,0 @@
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()

View File

@@ -26,7 +26,9 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(
if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
# NOTE: if tool calls are present, then it's an assistant.
# In Ollama, the role will be None for tool-calls
elif role == "assistant" or tool_calls:
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
@@ -236,6 +240,7 @@ class DefaultMultiLLM(LLM):
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
model_kwargs: dict[str, Any] | None = None,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
@@ -263,12 +268,16 @@ class DefaultMultiLLM(LLM):
# NOTE: have to set these as environment variables for Litellm since
# not all are able to passed in but they always support them set as env
# variables
# variables. We'll also try passing them in, since litellm just ignores
# addtional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs: dict[str, Any] = {}
model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:

Some files were not shown because too many files have changed in this diff Show More