Compare commits

..

186 Commits
nit ... bot_nit

Author SHA1 Message Date
pablodanswer
6ff78e077d nit 2024-12-06 12:57:43 -08:00
pablodanswer
c01512f846 fix slackbot 2024-12-06 12:56:46 -08:00
rkuo-danswer
7a3c06c2d2 first cut at slack oauth flow (#3323)
* first cut at slack oauth flow

* fix usage of hooks

* fix button spacing

* add additional error logging

* no dev redirect

* cleanup

* comment work in progress

* move some stuff to ee, add some playwright tests for the oauth callback edge cases

* fix ee, fix test name

* fix tests

* code review fixes
2024-12-06 19:55:21 +00:00
pablodanswer
7a0d823c89 Improved file handling (#3353)
* update props

* update documents

* nit

* update chat processing

* k

* k

* nit

* minor nit

* minor nits

* k

* nits
2024-12-06 19:16:54 +00:00
Yuhong Sun
db69e445d6 k (#3358) 2024-12-06 18:08:44 +00:00
Weves
18e63889b7 Change default log level back to info 2024-12-06 10:07:14 -08:00
Weves
738e60c8ed Increase vespa attempts on startup 2024-12-06 09:46:33 -08:00
hagen-danswer
8aec873e66 Merge pull request #3359 from danswer-ai/conf-logging-filter
Added filter to slim connector and logging for space permissions
2024-12-06 09:03:07 -08:00
hagen-danswer
7c57dde8ab fixed test 2024-12-06 08:33:12 -08:00
hagen-danswer
f30adab853 Merge remote-tracking branch 'origin/main' into conf-logging-filter 2024-12-06 08:30:07 -08:00
hagen-danswer
601687a522 Add test for Confluence permissions 2024-12-06 08:28:42 -08:00
hagen-danswer
350cf407c9 explicitly set page and attachment restrictions and space keys 2024-12-06 08:12:07 -08:00
hagen-danswer
32ec4efc7a tygod for tests 2024-12-06 08:03:34 -08:00
hagen-danswer
7c6981e052 Added filter to slim connector and logging for space permissions 2024-12-06 07:55:54 -08:00
Yuhong Sun
c50cd20156 Fix SlackBot Page Bugs (#3354) 2024-12-05 13:17:04 -08:00
hagen-danswer
14772dee71 Add persona stats (#3282)
* Added a chart to display persona message stats

* polish

* k

* hope this works

* cleanup
2024-12-05 17:15:56 +00:00
pablodanswer
c81e704c95 various niceties (#3348) 2024-12-05 17:12:52 +00:00
Chris Weaver
3266ef6321 Improve chat page performance (#3347)
* Simplify /manage/indexing-status

* Rename endpoint
2024-12-04 20:28:30 -08:00
pablodanswer
c89b98b4f2 update email invites (#3349) 2024-12-05 03:29:07 +00:00
rkuo-danswer
e70e0ab859 Merge pull request #3346 from danswer-ai/bugfix/chromatic-tests-2
Bugfix/chromatic tests 2
2024-12-04 19:44:05 -08:00
Richard Kuo (Danswer)
69b6e9321e Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/chromatic-tests-2
# Conflicts:
#	web/tests/e2e/home.spec.ts
2024-12-04 19:10:25 -08:00
Chris Weaver
7e53af18b6 Add b64 image support for image generation (#3342)
* Add b64 image support

* Fix

* enhance

* Fix mypy

* Fix imports
2024-12-05 02:24:54 +00:00
Richard Kuo (Danswer)
b9eb1ca2ba wait for whole placeholder string 2024-12-04 18:23:06 -08:00
rkuo-danswer
91d44c83d2 fixing chromatic tests (#3344)
* wait for the page to load

* fix up tests

* make sure "Initializing Danswer" is gone
2024-12-05 02:19:43 +00:00
Richard Kuo (Danswer)
4dbc6bb4d1 make sure "Initializing Danswer" is gone 2024-12-04 17:49:59 -08:00
Richard Kuo (Danswer)
4b6a4c6bbf fix up tests 2024-12-04 17:19:16 -08:00
pablodanswer
fd1999454a ensure we can order by doc id (#3343) 2024-12-05 01:10:37 +00:00
Richard Kuo (Danswer)
0a35422d1d wait for the page to load 2024-12-04 16:47:42 -08:00
pablodanswer
69b99056b2 Redirect to chat (#3341)
* k

* nit
2024-12-05 00:08:52 +00:00
Yuhong Sun
2a55696545 Move Answer (#3339) 2024-12-04 16:30:47 -08:00
hagen-danswer
ef9942b751 Related permission docs to cc_pair to prevent orphan docs (#3336)
* Related permission docs to cc_pair to prevent orphan docs

* added script

* group sync deduping

* logging
2024-12-04 21:00:54 +00:00
pablodanswer
993acec5e9 Update memoization + silence unnecessary errors (#3337)
* update memoization + silence unnecessary errors

* proper org
2024-12-04 20:08:15 +00:00
Weves
b01a1b509a Add basic loadtest script 2024-12-04 10:53:48 -08:00
pablodanswer
4f994124ef remove now unnecessary user loading indicatort log (#3333) 2024-12-04 00:09:22 +00:00
rkuo-danswer
14863bd457 try single threaded playwright testing (#3322) 2024-12-03 23:21:46 +00:00
Yuhong Sun
aa1c4c635a Combining Search and Chat Backend (#3273)
* k

* k

* fix slack issues

* rebase

* k
2024-12-03 22:37:14 +00:00
rkuo-danswer
13f6e8a6b4 disable thread local locking in callbacks (#3319) 2024-12-03 22:32:56 +00:00
pablodanswer
66f47d294c Shared filter utility for clarity (#3270)
* shared filter util

* clearer comment
2024-12-03 19:30:42 +00:00
pablodanswer
0a685bda7d add comments for clarity (#3249) 2024-12-03 19:27:28 +00:00
pablodanswer
23dc8b5dad Search flow improvements (#3314)
* untoggle if no docs

* update

* nits

* nit

* typing

* nit
2024-12-03 18:56:27 +00:00
pablodanswer
cd5f2293ad Temperature (#3310)
* fix temperatures for default llm

* ensure anthropic models don't overflow

* minor cleanup

* k

* k

* k

* fix typing
2024-12-03 17:22:22 +00:00
rkuo-danswer
6c2269e565 refactor celery task names to constants (#3296) 2024-12-03 16:02:17 +00:00
Weves
46315cddf1 Adjust default confulence timezone 2024-12-02 22:25:29 -08:00
rkuo-danswer
5f28a1b0e4 Bugfix/confluence time zone (#3265)
* RedisLock typing

* checkpoint

* put in debug logging

* improve comments

* mypy fixes
2024-12-02 22:23:23 -08:00
rkuo-danswer
9e9b7ed61d Bugfix/connector aborted logging (#3309)
* improve error logging on task failure.

* add db exception hardening to the indexing watchdog

* log on db exception
2024-12-03 02:34:40 +00:00
pablodanswer
3fb2bfefec Update Chromatic Tests (#3300)
* remove / update search tests

* minor update
2024-12-02 23:08:54 +00:00
pablodanswer
7c618c9d17 Unified UI (#3308)
* fix typing

* add filters display
2024-12-02 15:12:13 -08:00
pablodanswer
03e2789392 Text embedding (PDF, TXT) (#3113)
* add text embedding

* post rebase cleanup

* fully functional post rebase

* rm logs

* rm '

* quick clean up

* k
2024-12-02 22:43:53 +00:00
Chris Weaver
2783fa08a3 Update openai version in model server (#3306) 2024-12-02 21:39:10 +00:00
pablodanswer
edeaee93a2 hard refresh on auth (#3305)
* hard refresh on auth

* k

* k

* comment for clarity
2024-12-02 20:12:12 +00:00
hagen-danswer
5385bae100 Add slim connector description (#3303)
* added docs example and test

* updated docs

* needed to make the tests run

* updated docs
2024-12-02 19:52:13 +00:00
pablodanswer
813445ab59 Minor JWT Feature (#3290)
* first pass

* k

* k

* finalize

* minor cleanup

* k

* address

* minor typing updates
2024-12-02 19:14:31 +00:00
pablodanswer
af814823c8 display name + model truncation (#3304) 2024-12-02 18:54:08 +00:00
pablodanswer
607f61eaeb Reusable function for search settings spread operation (#3301)
* combine for clarity once and for all

* remove logs

* k
2024-12-02 17:23:01 +00:00
pablodanswer
de66f7adb2 Updated chat flow (#3244)
* proper no assistant typing + no assistant modal

* updated chat flow

* k

* updates

* update

* k

* clean up

* fix mystery reorg

* cleanup

* update scroll

* default

* update logs

* push fade

* scroll nit

* finalize tags

* updates

* k

* various updates

* viewport height update

* source types update

* clean up unused components

* minor cleanup

* cleanup complete

* finalize changes

* badge up

* update filters

* small nit

* k

* k

* address comments

* quick unification of icons

* minor date range clarity

* minor nit

* k

* update sidebar line

* update for all screen sizes

* k

* k

* k

* k

* rm shs

* fix memoization

* fix memoization

* slack chat

* k

* k

* build org
2024-12-02 01:58:28 +00:00
Yuhong Sun
3432d932d1 Citation code comments 2024-12-01 14:10:11 -08:00
Yuhong Sun
9bd0cb9eb5 Fix Citation Minor Bugs (#3294) 2024-12-01 13:55:24 -08:00
Chris Weaver
f12eb4a5cf Fix assistant prompt zero-ing (#3293) 2024-11-30 04:45:40 +00:00
Chris Weaver
16863de0aa Improve model token limit detection (#3292)
* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Ugprade OpenAI as well

* Fix mypy
2024-11-30 04:42:56 +00:00
Weves
63d1eefee5 Add read_only=True for xlsx parsing 2024-11-28 16:02:02 -08:00
pablodanswer
e338677896 order seeding 2024-11-28 15:41:10 -08:00
hagen-danswer
7be80c4af9 increased the pagination limit for confluence spaces (#3288) 2024-11-28 19:04:38 +00:00
rkuo-danswer
7f1e4a02bf Feature/kill indexing (#3213)
* checkpoint

* add celery termination of the task

* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings

* rename payload

* pretty sure these weren't named correctly

* testing in progress

* cleanup

* remove space

* merge fix

* three dots animation on Pausing

* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-28 05:32:45 +00:00
rkuo-danswer
5be7d27285 use indexing flag in db for manually triggering indexing (#3264)
* use indexing flag in db for manually trigger indexing

* add comment.

* only try to release the lock if we actually succeeded with the lock

* ensure we don't trigger manual indexing on anything but the primary search settings

* comment usage of primary search settings

* run check for indexing immediately after indexing triggers are set

* reorder fix
2024-11-28 01:34:34 +00:00
Weves
fd84b7a768 Remove duplicate API key router 2024-11-27 16:30:59 -08:00
Subash-Mohan
36941ae663 fix: Cannot configure API keys #3191 2024-11-27 16:25:00 -08:00
Matthew Holland
212353ed4a Fixed default feedback options 2024-11-27 16:23:52 -08:00
Richard Kuo (Danswer)
eb8708f770 the word "error" might be throwing off sentry 2024-11-27 14:31:21 -08:00
Chris Weaver
ac448956e9 Add handling for rate limiting (#3280) 2024-11-27 14:22:15 -08:00
pablodanswer
634a0b9398 no stack by default (#3278) 2024-11-27 20:58:21 +00:00
hagen-danswer
09d3e47c03 Perm sync behavior change (#3262)
* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
2024-11-27 20:04:15 +00:00
pablodanswer
9c0cc94f15 refresh router -> refresh assistants (#3271) 2024-11-27 19:11:58 +00:00
hagen-danswer
07dfde2209 add continue in danswer button to slack bot responses (#3239)
* all done except routing

* fixed initial changes

* added backend endpoint for duplicating a chat session from Slack

* got chat duplication routing done

* got login routing working

* improved answer handling

* finished all checks

* finished all!

* made sure it works with google oauth

* dont remove that lol

* fixed weird thing

* bad comments
2024-11-27 18:25:38 +00:00
pablodanswer
28e2b78b2e Fix search dropdown (#3269)
* validate dropdown

* validate

* update organization

* move to utils
2024-11-27 16:10:07 +00:00
Emerson Gomes
0553062ac6 Adds icons for Google Gemini models and custom model icons for L… (#3218)
* Add description for Google Gemini models and custom model icons for LiteLLM (OpenAI) proxied models

* Adds Vertex AI aliases for Claude

---------

Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2024-11-26 10:13:21 -08:00
hagen-danswer
284e375ba3 Merge pull request #3257 from danswer-ai/minor-perm-sync
Improved logging for confluence doc sync and robust user creation
2024-11-26 09:59:38 -08:00
hagen-danswer
1f2f7d0ac2 Improved logging for confluence doc sync and robust user creation 2024-11-26 08:51:15 -08:00
pablodanswer
2ecc28b57d remove unused stripe promise (#3248) 2024-11-26 01:50:39 +00:00
rkuo-danswer
77cf9b3539 improve messaging and UI around cleanup of leftover index attempts (#3247)
* improve messaging and UI around cleanup of leftover index attempts

* add tag on init
2024-11-25 22:27:14 +00:00
Weves
076ce2ebd0 Saml fix 2024-11-25 09:12:43 -08:00
pablodanswer
b625ee32a7 File handling cleanup (#3240)
* fix google sites connector

* minior cleanup

* rm comments
2024-11-25 04:06:47 +00:00
Richard Kuo (Danswer)
c32b93fcc3 increase indexing worker concurrency to 3 2024-11-24 18:11:58 -08:00
pablodanswer
1c8476072e Assistant cleanup (#3236)
* minor cleanup

* ensure users don't modify built-in attributes of assistants

* update sidebar

* k

* update update flow + assistant creation
2024-11-25 00:13:34 +00:00
Chris Weaver
7573416ca1 Fix API keys for MIT users (#3237) 2024-11-24 16:55:19 -08:00
Yuhong Sun
86d8666481 Add Test Case 2024-11-24 15:42:14 -08:00
Yuhong Sun
8abcde91d4 Fix Test (#3242) 2024-11-24 14:31:28 -08:00
Yuhong Sun
3466451d51 Fix Prompt for Non Function Calling LLMs (#3241) 2024-11-24 14:16:57 -08:00
Yuhong Sun
413891f143 Token Level Log (#3238) 2024-11-23 18:41:50 -08:00
Yuhong Sun
7a0a4d4b79 Remove Deprecated Endpoints (#3235) 2024-11-23 14:44:23 -08:00
Yuhong Sun
a3439605a5 Remove Dead Code (#3234) 2024-11-23 14:31:59 -08:00
pablodanswer
694e79f5e1 minor enforcement of CSV length for internal processing (#3109) 2024-11-23 21:05:30 +00:00
pablodanswer
5dfafc8612 minor calendar cleanup (#3219) 2024-11-23 21:01:05 +00:00
Yuhong Sun
62a4aa10db Refactor Search (#3233) 2024-11-23 13:42:54 -08:00
Yuhong Sun
a357cdc4c9 Remove Dead Code (#3232) 2024-11-23 13:21:27 -08:00
Yuhong Sun
84615abfdd Seeding (#3231) 2024-11-23 13:12:42 -08:00
pablodanswer
8ae6b1960b Bugfix/usage report (#3075)
* fix pagination

* update side

* fixed query history

* minor update

* minor update

* typing
2024-11-23 20:11:39 +00:00
James Jordan
d9b87bbbc2 Fixed 400 error when author of ticket is no longer an active user in a Zendesk account. (#3168) 2024-11-23 12:15:38 -08:00
Sanju Lokuhitige
a0065b01af Update CONTRIBUTING.md (#3112)
fix Formatting and Linting hyperlink
2024-11-23 12:13:23 -08:00
pablodanswer
c5306148a3 Ensure daterange not consistently re rendered (#3229)
* ensure daterange not consistently re rendered

* minor clean up
2024-11-23 19:35:00 +00:00
hagen-danswer
1e17934de4 Merge pull request #3214 from danswer-ai/fix-slack-ui
cleaned up new slack bot creation
2024-11-23 10:53:47 -08:00
pablodanswer
93add96ccc Various Nits (#3228) 2024-11-23 10:53:24 -08:00
rkuo-danswer
3a466a4b08 add minimal retries to confluence probe (#3222)
* add minimal retries to confluence probe

* name variable correctly
2024-11-23 17:11:15 +00:00
hagen-danswer
85cbd9caed Increased slim doc batch size for confluence connector (#3221) 2024-11-23 00:42:15 +00:00
pablodanswer
9dc23bf3e7 revert to previous doc select logic (#3217)
* revert to previous doc select logic

* k
2024-11-22 23:26:53 +00:00
hagen-danswer
e32809f7ca moved it outside 2024-11-22 14:59:58 -08:00
hagen-danswer
3e58f9f8ab fixed ugly stuff 2024-11-22 14:39:55 -08:00
pablodanswer
2381c8d498 Refresh all assistants on assistant refresh (#3216)
* k

* k
2024-11-22 22:38:23 +00:00
hagen-danswer
c6dadb24dc cleaned up new slack bot creation 2024-11-22 11:53:51 -08:00
hagen-danswer
5dc07d4178 Each section is now cleaned before being chunked (#3210)
* Each section is now cleaned before being chunked

* k

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-11-22 19:06:19 +00:00
Chris Weaver
129c8f8faf Add start/end date ability for query history as CSV endpoint (#3211) 2024-11-22 18:29:13 +00:00
pablodanswer
67bfcabbc5 llm provider causing re render in effect (#3205)
* llm provider causing re render in effect

* clean

* unused

* k
2024-11-22 16:53:24 +00:00
rkuo-danswer
9819aa977a implement double check pattern for error conditions (#3201)
* Move unfenced check to check_for_indexing. implement a double check pattern for all indexing error checks

* improved commenting

* exclusions
2024-11-22 04:21:02 +00:00
hagen-danswer
8d5b8a4028 Merge pull request #3202 from danswer-ai/toggled_chat_default
Update default sidebar toggle
2024-11-21 19:53:05 -08:00
pablodanswer
682319d2e9 Bugfix/curator interface (#3198)
* mystery solved

* update config

* update

* update

* update user role

* remove values
2024-11-22 02:33:09 +00:00
hagen-danswer
fe1400aa36 replace deprecated confluence group api endpoint (#3197)
* replace deprecated confluence group api endpoint

* reworked it

* properly escaped the user query

* less passing around is_cloud

* done
2024-11-22 01:51:29 +00:00
pablodanswer
e3573b2bc1 add comment 2024-11-21 17:11:11 -08:00
pablodanswer
35b5c44cc7 update default sidebar toggle 2024-11-21 17:09:56 -08:00
rkuo-danswer
5eddc89b5a merge indexing and heartbeat callbacks (and associated lock reacquisi… (#3178)
* merge indexing and heartbeat callbacks (and associated lock reacquisition). no db updates

* review fixes
2024-11-21 23:48:58 +00:00
hagen-danswer
9a492ceb6d admins cant be set as curator on backend (#3194)
* set-curator

* updated error
2024-11-21 23:33:29 +00:00
rkuo-danswer
3c54ae9de9 Bugfix/redis wait (#3169)
* rename to payload

* log redis info replication on primary worker startup

* fix mypy

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-21 23:11:00 +00:00
pablodanswer
13f08f3ebb Horizontal scrollbar (#3195)
* clean horizontal scrollbar

* account for additional edge case
2024-11-21 22:08:21 +00:00
pablodanswer
bd9f15854f provider fix (#3187)
* clean horizontal scrollbar

* provider fix

* ensure proper migration

* k

* update migration

* Revert "clean horizontal scrollbar"

This reverts commit fa592a1b7a.
2024-11-21 22:08:16 +00:00
pablodanswer
366aa2a8ea quick fix (#3200) 2024-11-21 14:07:55 -08:00
pablodanswer
deee237c7e Sheet update (#3189)
* quick pass

* k

* update sheet

* add multiple sheet stuff

* k

* finalized

* update configuration
2024-11-21 18:07:00 +00:00
hagen-danswer
100b4a0d16 Added Slim connector for Jira (#3181)
* Added Slim connector for Jira

* fixed testing

* more cleanup of Jira connector

* cleanup
2024-11-21 17:00:20 +00:00
rkuo-danswer
70207b4b39 improve web testing (#3162)
* shared admin level test dependency

* change to on - push (recommended by chromatic)

* change playwright reporter to list, name test jobs

* use test tags ... much cleaner

* test vs prod

* try copying templates

* run with localhost?

* revert to dev

* new tests and a bit of refactoring

* add additional checks so that page snapshots reflect loaded state

* more admin tests

* User Management tests

* remaining admin pages

* test search and chat

* await fix and exclude UI that changes with dates.
2024-11-21 04:01:15 +00:00
pablodanswer
50826b6bef Formatting Niceties (#3183)
* search bar formatting

* update styling
2024-11-21 03:11:26 +00:00
pablodanswer
3f648cbc31 Folder clarity (#3180)
* folder clarity

* k
2024-11-21 03:11:17 +00:00
pablodanswer
c875a4774f valid props (#3186) 2024-11-21 01:13:54 +00:00
hagen-danswer
049091eb01 decreased confluence retry times and added more logging (#3184)
* decreased confluence retry times and added more logging

* added check on connector startup

* no retries!

* fr no retries
2024-11-21 00:00:14 +00:00
pablodanswer
3dac24542b silence small error (#3182) 2024-11-20 22:46:38 +00:00
pablodanswer
194dcb593d update slack redirect + token missing check (#3179)
* update slack redirect + token missing check

* reset time
2024-11-20 21:42:54 +00:00
pablodanswer
bf291d0c0a Fix missing json (#3177)
* initial steps

* k

* remove logs

* k

* k
2024-11-20 21:24:43 +00:00
rkuo-danswer
8309f4a802 test overlapping connectors (but using a source that is way too big a… (#3152)
* test overlapping connectors (but using a source that is way too big and slow, fix that next)

* pass thru secrets

* rename

* rename again

* now we are fixing it

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-20 21:12:01 +00:00
pablodanswer
0ff2565125 ensure margin properly applied (#3176)
* ensure margin properly applied

* formatting
2024-11-20 20:04:45 +00:00
hagen-danswer
e89dcd7f84 added logging and bugfixing to conf (#3167)
* standardized escaping of CQL strings

* think i found it

* fix

* should be fixed

* added handling for special linking behavior in confluence

* Update onyx_confluence.py

* Update onyx_confluence.py

---------

Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-11-20 18:40:21 +00:00
pablodanswer
645e7e828e Add Google Tag Manager for Web Cloud Build (#3173)
* add gtm for cloud build

* update github workflow
2024-11-20 17:38:33 +00:00
pablodanswer
2a54f14195 ensure everythigng has a default max height in selectorformfield (#3174) 2024-11-20 17:26:22 +00:00
hagen-danswer
9209fc804b multiple slackbot support (#3077)
* multiple slackbot support

* app_id + tenant_id key

* removed kv store stuff

* fixed up mypy and migration

* got frontend working for multiple slack bots

* some frontend stuff

* alembic fix

* might be valid

* refactor dun

* alembic stuff

* temp frontend stuff

* alembic stuff

* maybe fixed alembic

* maybe dis fix

* im getting mad

* api names changed

* tested

* almost done

* done

* routing nonsense

* done!

* done!!

* fr done

* doneski

* fix alembic migration

* getting mad again

* PLEASE IM BEGGING YOU
2024-11-20 01:49:43 +00:00
rkuo-danswer
b712877701 Merge pull request #3165 from danswer-ai/bugfix/pruning_logs
improve logging around pruning
2024-11-19 13:19:31 -08:00
Richard Kuo (Danswer)
e6df32dcc3 improve logging around pruning 2024-11-19 12:41:21 -08:00
Chris Weaver
eb81258a23 Update README.md
Fix slack link
2024-11-19 08:02:35 -08:00
hagen-danswer
487ef4acc0 Merge pull request #3160 from danswer-ai/add-to-admin-chat-sessions-api
Extend query history API
2024-11-19 07:28:12 -08:00
pablodanswer
9b7cc83eae add new date search filter (#3065)
* add new complicated filters

* clarity updates

* update date range filter
2024-11-19 03:42:42 +00:00
Weves
ce3124f9e4 Extend query history API 2024-11-18 17:50:21 -08:00
rkuo-danswer
e69303e309 add helpful hint on 507 (#3157)
* add helpful hint on 507

* add helpful hint to the direct exception in _index_vespa_chunk
2024-11-19 01:08:32 +00:00
rkuo-danswer
6e698ac84a Hardening deletion when cc pair relationships are left over (#3154)
* more logs

* this fence should be set to None

* type hinting

* reset deletion attempt if conditions are inconsistent

* always clean up in db if we reach reconciliation

* add reset method

* more logging

* harden up error checking
2024-11-19 01:07:59 +00:00
pablodanswer
d69180aeb8 add additional theming options (#3155)
* add additional theming options

* nit

* Update Filters.tsx
2024-11-18 22:56:48 +00:00
rkuo-danswer
aa37051be9 Bugfix/indexing redux (#3151)
* raise indexing lock timeout

* refactor unknown index attempts and redis lock
2024-11-18 22:47:31 +00:00
pablodanswer
a7d95661b3 Add assistant categories (#3064)
* add assistant categories v1

* functionality finalized

* finalize

* update assistant category display

* nit

* add tests

* post rebase update

* minor update to tests

* update typing

* finalize

* typing

* nit

* alembic

* alembic (once again)
2024-11-18 20:33:48 +00:00
Chris Weaver
33ee899408 Long term logs (#3150) 2024-11-18 10:48:03 -08:00
hagen-danswer
954b5b2a56 Made external permissioned users and slack users show diff (#3147)
* Made external permissioned users and slack users show diff

* finished

* Fix typing

* k

* Fix

* k

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-11-17 01:13:47 +00:00
pablodanswer
521425a4f2 nits + pricing 2024-11-16 16:28:37 -08:00
hagen-danswer
618bc02d54 Fixed int test (#3148) 2024-11-16 18:13:06 +00:00
rkuo-danswer
b7de74fdf8 Feature/playwright tests (#3129)
* initial PoC

* preliminary working config

* first cut at chromatic tests

* first cut at chromatic tests

* fix yaml

* fix yaml again

* use workingDir

* adapt playwright example

* remove env

* fix working directory

* fix more paths

* fix dir

* add playwright setup

* accidentally deleted a step

* update test

* think we don't need home.png right now

* remove unused home.png

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-16 04:26:17 +00:00
hagen-danswer
6e83fe3a39 reworked drive+confluence frontend and implied backend changes (#3143)
* reworked drive+confluence frontend and implied backend changes

* fixed oauth admin tests

* fixed service account tests

* frontend cleanup

* copy change

* details!

* added key

* so good

* whoops!

* fixed mnore treljsertjoslijt

* has issue with boolean form

* should be done
2024-11-16 03:38:30 +00:00
Weves
259fc049b7 Add error message on JSON decode error in CustomTool 2024-11-15 20:00:12 -08:00
rkuo-danswer
7015e6f2ab Bugfix/overlapping connectors (#3138)
* fix tenant logging

* upsert only new/updated docs, but always upsert document to cc pair relationship

* better logging and rough cut at testing
2024-11-16 00:47:52 +00:00
pablodanswer
24be13c015 Improved tokenizer fallback (#3132)
* silence warning

* improved fallback logic

* k

* minor cosmetic update

* minor logic update

* nit
2024-11-14 20:13:29 -08:00
pablodanswer
ddff7ecc3f minor configuration updates (#3134) 2024-11-14 18:09:30 -08:00
Yuhong Sun
97932dc44b Fix Quotes Prompting (#3137) 2024-11-14 17:28:03 -08:00
rkuo-danswer
637b6d9e75 Merge pull request #3135 from danswer-ai/bugfix/helm_ct_python_setup
unnecessary python setup
2024-11-14 14:57:12 -08:00
Richard Kuo (Danswer)
54dc1ac917 unnecessary python setup 2024-11-14 11:14:12 -08:00
rkuo-danswer
21d5cc43f8 Merge pull request #3131 from danswer-ai/bugfix/session_text
use text()
2024-11-13 20:24:14 -08:00
pablodanswer
7c841051ed Cohere (#3111)
* add cohere default

* finalize

* minor improvement

* update

* update

* update configs

* ensure we properly expose name(space) for slackbot

* update config

* config
2024-11-14 01:58:54 +00:00
pablodanswer
6e91964924 minor clarity (#3116) 2024-11-14 01:42:21 +00:00
pablodanswer
facf1d55a0 Cloud improvements (#3099)
* add improved cloud configuration

* fix typing

* finalize slackbot improvements

* minor update

* finalized keda

* moderate slackbot switch

* update some configs

* revert

* include reset engine!
2024-11-13 23:52:52 +00:00
rkuo-danswer
d68f8d6fbc scale indexing sql pool based on concurrency (#3130) 2024-11-13 23:26:13 +00:00
Richard Kuo (Danswer)
65a205d488 use text() 2024-11-13 15:03:21 -08:00
hagen-danswer
485f3f72fa Updated google copy and added non admin oauth support (#3120)
* Updated google copy and added non admin oauth support

* backend update

* accounted for oauth

* further removed class variables

* updated sets
2024-11-13 20:07:10 +00:00
rkuo-danswer
dcbea883ae add creator id to cc pair (#3121)
* add creator id to cc pair

* fix alembic head

* show email instead of UUID

* safer check on email

* make foreign key relationships optional

* always allow creator to edit (per hagen)

* use primary join

* no index_doc_batch spam

* try this again

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-13 19:35:08 +00:00
hagen-danswer
a50a3944b3 Make curators able to create permission synced connectors (#3126)
* Make curators able to create permission synced connectors

* removed editing permission synced connectors for curators

* updated tests to use access type instead of is_public

* update copy
2024-11-13 18:58:23 +00:00
hagen-danswer
60471b6a73 Added support for page within a page in Confluence (#3125) 2024-11-13 16:39:00 +00:00
rkuo-danswer
d703e694ce limited role api keys (#3115)
* in progress PoC

* working limited user, needs routes to be marked next

* make selected endpoint available to limited user role

* xfail on test_slack_prune

* add comment to sync function

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-13 16:15:43 +00:00
hagen-danswer
6066042fef Merge pull request #3124 from danswer-ai/fix-doc-sync
quick fix for google doc sync
2024-11-13 07:30:52 -08:00
hagen-danswer
eb0e20b9e4 quick fix for google doc sync 2024-11-13 07:24:29 -08:00
pablodanswer
490a68773b update organization (#3118)
* update organization

* minor clean up

* add minor clarity

* k

* slight rejigger

* alembic fix

* update paradigm

* delete code!

* delete code

* minor update
2024-11-13 06:45:32 +00:00
rkuo-danswer
227aff1e47 clean up logging in light worker (#3072) 2024-11-13 03:42:02 +00:00
Weves
6e29d1944c Fix widget example 2024-11-12 18:48:44 -08:00
pablodanswer
22189f02c6 Add referral source to cloud on data plane (#3096)
* cloud auth referral source

* minor clarity

* k

* minor modification to be best practice

* typing

* Update ReferralSourceSelector.tsx

* Update ReferralSourceSelector.tsx

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-11-13 00:42:25 +00:00
hagen-danswer
fdc4811fce doc sync celery refactor (#3084)
* doc_sync is refactored

* maybe this works

* tested to work!

* mypy fixes

* enabled integration tests

* fixed the test

* added external group sync

* testing should work now

* mypy

* confluence doc id fix

* got group sync working

* addressed feedback

* renamed some vars and fixed mypy

* conf fix?

* added wiki handling to confluence connector

* test fixes

* revert google drive connector

* fixed groups

* hotfix
2024-11-12 23:57:14 +00:00
Chris Weaver
021d0cf314 Support LITELLM_EXTRA_BODY env variable (#3119)
* Support LITELLM_EXTRA_BODY env variable

* Remove unused param

* Add comment
2024-11-12 23:17:44 +00:00
pablodanswer
942e47db29 improved mobile scroll (#3110) 2024-11-12 01:57:49 +00:00
pablodanswer
f4a020b599 moderate component fixes (#3095)
* moderate component fixes

* nit

* nit

* update colors

* k
2024-11-12 00:47:35 +00:00
pablodanswer
5166649eae Cleaner EE fallback for no op (#3106)
* treat async values differently

* cleaner approach

* spacing

* typing
2024-11-11 17:42:14 +00:00
Chris Weaver
ba805f766f New assistants api (#3097) 2024-11-11 07:55:23 -08:00
591 changed files with 29560 additions and 21788 deletions

View File

@@ -65,6 +65,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

225
.github/workflows/pr-chromatic-tests.yml vendored Normal file
View File

@@ -0,0 +1,225 @@
name: Run Chromatic Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Install playwright browsers
working-directory: ./web
run: npx playwright install --with-deps
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Web Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./web
file: ./web/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-web-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/web-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: danswer/danswer-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run pytest playwright test init
working-directory: ./backend
env:
PYTEST_IGNORE_SKIP: true
run: pytest -s tests/integration/tests/playwright/test_playwright.py
- name: Run Playwright tests
working-directory: ./web
run: npx playwright test
- uses: actions/upload-artifact@v4
if: always()
with:
# Chromatic automatically defaults to the test-results directory.
# Replace with the path to your custom directory and adjust the CHROMATIC_ARCHIVE_LOCATION environment variable accordingly.
name: test-results
path: ./web/test-results
retention-days: 30
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -23,21 +23,6 @@ jobs:
with:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
@@ -52,6 +37,22 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -13,7 +13,10 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -195,9 +198,13 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests

View File

@@ -20,9 +20,12 @@ env:
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

1
.gitignore vendored
View File

@@ -7,3 +7,4 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/

View File

@@ -203,7 +203,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
],
"presentation": {
"group": "2",
@@ -232,7 +232,7 @@
"--loglevel=INFO",
"--hostname=heavy@%n",
"-Q",
"connector_pruning",
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
],
"presentation": {
"group": "2",

View File

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

View File

@@ -12,7 +12,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -135,7 +135,7 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>

View File

@@ -73,6 +73,7 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

View File

@@ -1,5 +1,5 @@
from sqlalchemy.engine.base import Connection
from typing import Any
from typing import Literal
import asyncio
from logging.config import fileConfig
import logging
@@ -8,6 +8,7 @@ from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
@@ -35,7 +36,18 @@ logger = logging.getLogger(__name__)
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.

View File

@@ -0,0 +1,59 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

View File

@@ -0,0 +1,68 @@
"""default chosen assistants to none
Revision ID: 26b931506ecb
Revises: 2daa494a0851
Create Date: 2024-11-12 13:23:29.858995
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "26b931506ecb"
down_revision = "2daa494a0851"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True)
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_new =
CASE
WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_new", new_column_name="chosen_assistants"
)
def downgrade() -> None:
op.add_column(
"user",
sa.Column(
"chosen_assistants_old",
postgresql.JSONB(),
nullable=False,
server_default="[-2, -1, 0]",
),
)
op.execute(
"""
UPDATE "user"
SET chosen_assistants_old =
CASE
WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb
ELSE chosen_assistants
END
"""
)
op.drop_column("user", "chosen_assistants")
op.alter_column(
"user", "chosen_assistants_old", new_column_name="chosen_assistants"
)

View File

@@ -0,0 +1,30 @@
"""add-group-sync-time
Revision ID: 2daa494a0851
Revises: c0fd6e4da83a
Create Date: 2024-11-11 10:57:22.991157
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2daa494a0851"
down_revision = "c0fd6e4da83a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"last_time_external_group_sync",
sa.DateTime(timezone=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "last_time_external_group_sync")

View File

@@ -0,0 +1,45 @@
"""add persona categories
Revision ID: 47e5bef3a1d7
Revises: dfbe9e93d3c7
Create Date: 2024-11-05 18:55:02.221064
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "47e5bef3a1d7"
down_revision = "dfbe9e93d3c7"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the persona_category table
op.create_table(
"persona_category",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name"),
)
# Add category_id to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_persona_category",
"persona",
"persona_category",
["category_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -0,0 +1,280 @@
"""add_multiple_slack_bot_support
Revision ID: 4ee1287bd26a
Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from danswer.key_value_store.factory import get_kv_store
from danswer.db.models import SlackBot
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4ee1287bd26a"
down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
sa.Column("app_token", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("bot_token"),
sa.UniqueConstraint("app_token"),
)
# # Create new slack_channel_config table
op.create_table(
"slack_channel_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column(
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
),
sa.ForeignKeyConstraint(
["slack_bot_id"],
["slack_bot.id"],
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
enabled=True,
bot_token=bot_token,
app_token=app_token,
)
session.add(new_slack_bot)
session.commit()
first_row_id = new_slack_bot.id
# Create a default bot if none exists
# This is in case there are no slack tokens but there are channels configured
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)
# Get the bot ID to use (either from existing migration or newly created)
bot_id_query = sa.text(
"""
SELECT COALESCE(
:first_row_id,
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
) as bot_id;
"""
)
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
bot_id = result.scalar()
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
# This splits up the channel_names into their own rows
channel_names_cte = """
WITH channel_names AS (
SELECT
sbc.id as config_id,
sbc.persona_id,
sbc.response_type,
sbc.enable_auto_filters,
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
sbc.channel_config->'answer_filters' as answer_filters,
sbc.channel_config->'follow_up_tags' as follow_up_tags
FROM slack_bot_config sbc
)
"""
# Insert the channel names into the new slack_channel_config table
insert_statement = """
INSERT INTO slack_channel_config (
slack_bot_id,
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT
:bot_id,
channel_name.persona_id,
jsonb_build_object(
'channel_name', channel_name.channel_name,
'respond_tag_only',
COALESCE((channel_name.respond_tag_only)::boolean, false),
'respond_to_bots',
COALESCE((channel_name.respond_to_bots)::boolean, false),
'respond_member_group_list',
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
'answer_filters',
COALESCE(channel_name.answer_filters, '[]'::jsonb),
'follow_up_tags',
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
),
channel_name.response_type,
channel_name.enable_auto_filters
FROM channel_names channel_name;
"""
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
"slack_channel_config__standard_answer_category",
)
# Rename the column
op.alter_column(
"slack_channel_config__standard_answer_category",
"slack_bot_config_id",
new_column_name="slack_channel_config_id",
)
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
op.create_table(
"slack_bot_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate data back to the old format
# Group by persona_id to combine channel names back into arrays
op.execute(
sa.text(
"""
INSERT INTO slack_bot_config (
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT DISTINCT ON (persona_id)
persona_id,
jsonb_build_object(
'channel_names', (
SELECT jsonb_agg(c.channel_config->>'channel_name')
FROM slack_channel_config c
WHERE c.persona_id = scc.persona_id
),
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
'respond_member_group_list', channel_config->'respond_member_group_list',
'answer_filters', channel_config->'answer_filters',
'follow_up_tags', channel_config->'follow_up_tags'
),
response_type,
enable_auto_filters
FROM slack_channel_config scc
WHERE persona_id IS NOT NULL;
"""
)
)
# Rename the table back
op.rename_table(
"slack_channel_config__standard_answer_category",
"slack_bot_config__standard_answer_category",
)
# Rename the column back
op.alter_column(
"slack_bot_config__standard_answer_category",
"slack_channel_config_id",
new_column_name="slack_bot_config_id",
)
# Try to save the first bot's tokens back to KV store
try:
first_bot = (
op.get_bind()
.execute(
sa.text(
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
)
)
.first()
)
if first_bot and first_bot.bot_token and first_bot.app_token:
tokens = {
"bot_token": first_bot.bot_token,
"app_token": first_bot.app_token,
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
op.drop_table("slack_bot")

View File

@@ -0,0 +1,45 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

View File

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -0,0 +1,35 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -7,6 +7,7 @@ Create Date: 2024-10-26 13:06:06.937969
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
# Import your models and constants
from danswer.db.models import (
@@ -15,7 +16,6 @@ from danswer.db.models import (
Credential,
IndexAttempt,
)
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
@@ -30,13 +30,11 @@ def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
connectors_to_delete = (
session.query(Connector)
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
.all()
# Get connectors using raw SQL
result = bind.execute(
text("SELECT id FROM connector WHERE source = 'requesttracker'")
)
connector_ids = [connector.id for connector in connectors_to_delete]
connector_ids = [row[0] for row in result]
if connector_ids:
cc_pairs_to_delete = (

View File

@@ -0,0 +1,30 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"creator_id",
sa.UUID(as_uuid=True),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "creator_id")

View File

@@ -0,0 +1,36 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")
def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)

View File

@@ -0,0 +1,27 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -0,0 +1,30 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -288,6 +288,15 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -23,6 +23,56 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -0,0 +1,42 @@
"""extended_role_for_non_web
Revision ID: dfbe9e93d3c7
Revises: 9cf5c00f72fe
Create Date: 2024-11-16 07:54:18.727906
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "dfbe9e93d3c7"
down_revision = "9cf5c00f72fe"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE "user"
SET role = 'EXT_PERM_USER'
WHERE has_web_login = false
"""
)
op.drop_column("user", "has_web_login")
def downgrade() -> None:
op.add_column(
"user",
sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
)
op.execute(
"""
UPDATE "user"
SET has_web_login = false,
role = 'BASIC'
WHERE role IN ('SLACK_USER', 'EXT_PERM_USER')
"""
)

View File

@@ -0,0 +1,40 @@
"""non-nullbale slack bot id in channel config
Revision ID: f7a894b06d02
Revises: 9f696734098f
Create Date: 2024-12-06 12:55:42.845723
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f7a894b06d02"
down_revision = "9f696734098f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Delete all rows with null slack_bot_id
op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")
# Make slack_bot_id non-nullable
op.alter_column(
"slack_channel_config",
"slack_bot_id",
existing_type=sa.Integer(),
nullable=False,
)
def downgrade() -> None:
# Make slack_bot_id nullable again
op.alter_column(
"slack_channel_config",
"slack_bot_id",
existing_type=sa.Integer(),
nullable=True,
)

View File

@@ -1,5 +1,6 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:

View File

@@ -16,6 +16,46 @@ class ExternalAccess:
is_public: bool
@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""
external_access: ExternalAccess
# The document ID
doc_id: str
def to_dict(self) -> dict:
return {
"external_access": {
"external_user_emails": list(self.external_access.external_user_emails),
"external_user_group_ids": list(
self.external_access.external_user_group_ids
),
"is_public": self.external_access.is_public,
},
"doc_id": self.doc_id,
}
@classmethod
def from_dict(cls, data: dict) -> "DocExternalAccess":
external_access = ExternalAccess(
external_user_emails=set(
data["external_access"].get("external_user_emails", [])
),
external_user_group_ids=set(
data["external_access"].get("external_user_group_ids", [])
),
is_public=data["external_access"]["is_public"],
)
return cls(
external_access=external_access,
doc_id=data["doc_id"],
)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin

View File

@@ -2,8 +2,8 @@ from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.key_value_store.factory import get_kv_store
from danswer.key_value_store.interface import JSON_ro
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:

View File

@@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -13,12 +13,24 @@ class UserRole(str, Enum):
groups they are curators of
- Global Curator can perform admin actions
for all groups they are a member of
- Limited can access a limited set of basic api endpoints
- Slack are users that have used danswer via slack but dont have a web login
- External permissioned users that have been picked up during the external permissions sync process but don't have a web login
"""
LIMITED = "limited"
BASIC = "basic"
ADMIN = "admin"
CURATOR = "curator"
GLOBAL_CURATOR = "global_curator"
SLACK_USER = "slack_user"
EXT_PERM_USER = "ext_perm_user"
def is_web_login(self) -> bool:
return self not in [
UserRole.SLACK_USER,
UserRole.EXT_PERM_USER,
]
class UserStatus(str, Enum):
@@ -33,10 +45,8 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
has_web_login: bool | None = True
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
has_web_login: bool | None = True

View File

@@ -49,8 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
@@ -59,7 +58,6 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -81,13 +79,14 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -100,11 +99,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -137,11 +131,12 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
# all other auth types besides basic should require users to be
# verified
return not DISABLE_VERIFICATION and (
AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
)
if AUTH_TYPE == AuthType.BASIC:
return REQUIRE_EMAIL_VERIFICATION
# For other auth types, if the user is authenticated it's assumed that
# the user is already verified via the external IDP
return False
def verify_email_is_invited(email: str) -> None:
@@ -222,18 +217,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
async def create(
self,
user_create: schemas.UC | UserCreate,
safe: bool = False,
request: Optional[Request] = None,
) -> User:
referral_source = None
if request is not None:
referral_source = request.cookies.get("referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
referral_source=referral_source,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
@@ -242,7 +244,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
self.user_db = tenant_user_db
self.database = tenant_user_db
@@ -261,14 +265,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
@@ -282,7 +281,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def oauth_callback(
self: "BaseUserManager[models.UOAP, models.ID]",
self,
oauth_name: str,
access_token: str,
account_id: str,
@@ -293,13 +292,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
*,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
) -> User:
referral_source = None
if request:
referral_source = getattr(request.state, "referral_source", None)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
referral_source=referral_source,
)
if not tenant_id:
@@ -314,9 +318,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
db_session, User, OAuthAccount
)
self.user_db = tenant_user_db
self.database = tenant_user_db # type: ignore
self.database = tenant_user_db
oauth_account_dict = {
"oauth_name": oauth_name,
@@ -368,7 +374,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user, existing_oauth_account, oauth_account_dict
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
@@ -381,16 +391,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.has_web_login: # type: ignore
if not user.role.is_web_login():
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"has_web_login": True,
"role": UserRole.BASIC,
},
)
user.is_verified = is_verified_by_default
user.has_web_login = True # type: ignore
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -465,9 +474,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
if not user.role.is_web_login():
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -598,7 +605,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -607,13 +614,21 @@ async def optional_user_(
async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
async def double_check_user(
@@ -652,12 +667,26 @@ async def current_user_with_expired_token(
return await double_check_user(user, include_expired=True)
async def current_user(
async def current_limited_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(user)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:
user = await double_check_user(user)
if not user:
return None
if user.role == UserRole.LIMITED:
raise BasicAuthenticationError(
detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
)
return user
async def current_curator_or_admin_user(
user: User | None = Depends(current_user),
) -> User | None:
@@ -711,8 +740,6 @@ def generate_state_token(
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
@@ -762,15 +789,22 @@ def get_oauth_router(
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
request: Request,
scopes: List[str] = Query(None),
) -> OAuth2AuthorizeResponse:
referral_source = request.cookies.get("referral_source", None)
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
state_data: Dict[str, str] = {
"next_url": next_url,
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
@@ -829,8 +863,11 @@ def get_oauth_router(
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
# Authenticate user
request.state.referral_source = referral_source
# Proceed to authenticate or create the user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
@@ -872,14 +909,13 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
@@ -889,7 +925,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

View File

@@ -11,6 +11,7 @@ from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
@@ -24,6 +25,8 @@ from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
from danswer.redis.redis_pool import get_redis_client
@@ -136,6 +139,22 @@ def on_task_postrun(
RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r)
return
if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorPermissionSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX):
cc_pair_id = RedisConnector.get_id_from_task_id(task_id)
if cc_pair_id is not None:
RedisConnectorExternalGroupSync.remove_from_taskset(
int(cc_pair_id), task_id, r
)
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
@@ -314,16 +333,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
lock: RedisLock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
except Exception:
logger.exception("Failed to release primary worker lock")
except Exception:
logger.exception("Failed to check if primary worker lock is owned")
def on_setup_logging(

View File

@@ -12,6 +12,7 @@ from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -72,6 +73,15 @@ class DynamicTenantScheduler(PersistentScheduler):
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new tenant: {tenant_id}")

View File

@@ -91,5 +91,7 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
]
)

View File

@@ -6,6 +6,7 @@ from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
@@ -59,7 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
@@ -81,6 +82,11 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None:
SqlEngine.reset_engine()
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any

View File

@@ -92,5 +92,6 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.doc_permission_syncing",
]
)

View File

@@ -1,5 +1,6 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -10,16 +11,25 @@ from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from redis.lock import Lock as RedisLock
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_connector_stop import RedisConnectorStop
@@ -29,7 +39,6 @@ from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -89,6 +98,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
@@ -98,9 +116,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
# set thread_local=False since we don't control what thread the periodic task might
# reacquire the lock with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
thread_local=False,
)
logger.info("Primary worker lock: Acquire starting.")
@@ -134,6 +156,27 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
@@ -188,7 +231,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not hasattr(worker, "primary_worker_lock"):
return
lock = worker.primary_worker_lock
lock: RedisLock = worker.primary_worker_lock
r = get_redis_client(tenant_id=None)
@@ -233,6 +276,8 @@ celery_app.autodiscover_tasks(
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.doc_permission_syncing",
"danswer.background.celery.tasks.external_group_syncing",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
"danswer.background.celery.tasks.vespa",

View File

@@ -1,96 +0,0 @@
from datetime import timedelta
from typing import Any
from celery.beat import PersistentScheduler # type: ignore
from celery.utils.log import get_task_logger
from danswer.db.engine import get_all_tenant_ids
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = get_task_logger(__name__)
class DynamicTenantScheduler(PersistentScheduler):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=1)
self._last_reload = self.app.now() - self._reload_interval
def setup_schedule(self) -> None:
super().setup_schedule()
def tick(self) -> float:
retval = super().tick()
now = self.app.now()
if (
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reloading schedule to check for new tenants...")
self._update_tenant_tasks()
self._last_reload = now
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Checking for tenant task updates...")
try:
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = fetch_versioned_implementation(
"danswer.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
current_schedule = getattr(self, "_store", {"entries": {}}).get(
"entries", {}
)
existing_tenants = set()
for task_name in current_schedule.keys():
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
for tenant_id in tenant_ids:
if tenant_id not in existing_tenants:
logger.info(f"Found new tenant: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
logger.info(
"Updating schedule",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
)
if not hasattr(self, "_store"):
self._store: dict[str, dict] = {"entries": {}}
self.update_from_dict(new_beat_schedule)
logger.info(f"New schedule: {new_beat_schedule}")
logger.info("Tenant tasks updated successfully")
else:
logger.debug("No schedule updates needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
current_tasks = set(current_schedule.keys())
new_tasks = set(new_schedule.keys())
return current_tasks != new_tasks

View File

@@ -4,7 +4,6 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -17,6 +16,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,10 +78,10 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the PruneConnector hasnt been implemented for the given connector, just pull
If the SlimConnector hasnt been implemented for the given connector, just pull
all docs using the load_from_state and grab out the IDs.
Optionally, a callback can be passed to handle the length of each document batch.
@@ -111,10 +111,15 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

View File

@@ -2,45 +2,58 @@ from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"schedule": timedelta(seconds=5),
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"schedule": timedelta(seconds=10),
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
]

View File

@@ -1,17 +1,17 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -19,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_pool import get_redis_client
@@ -29,7 +29,7 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name="check_for_connector_deletion_task",
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -37,7 +37,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,8 +86,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
@@ -118,7 +117,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletionFenceData(
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)
@@ -143,6 +142,12 @@ def try_generate_document_cc_pair_cleanup_tasks(
f"cc_pair={cc_pair_id}"
)
if redis_connector.permissions.fenced:
raise TaskDependencyError(
f"Connector deletion - Delayed (permissions in progress): "
f"cc_pair={cc_pair_id}"
)
# add tasks to celery and build up the task set to monitor in redis
redis_connector.delete.taskset_clear()

View File

@@ -0,0 +1,345 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.document import upsert_document_external_perms
from ee.danswer.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import DOC_PERMISSIONS_FUNC_MAP
logger = setup_logger()
DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external doc permissions sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip doc permissions sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If the last sync is None, it has never been run so we run the sync
last_perm_sync = cc_pair.last_time_perm_sync
if last_perm_sync is None:
return True
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
if redis_connector.permissions.fenced:
return None
if redis_connector.delete.fenced:
return None
if redis_connector.prune.fenced:
return None
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
result = app.send_task(
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
)
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
redis_connector.permissions.generator_clear()
redis_connector.permissions.taskset_clear()
redis_connector.permissions.set_fence(None)
raise e
finally:
if lock.owned():
lock.release()
@shared_task(
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
bind=True,
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
doc_id = document_external_access.doc_id
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
if created_new_doc:
# If a new document was created, we associate it with the cc_pair
upsert_document_by_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
return False

View File

@@ -0,0 +1,298 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
logger = setup_logger()
EXTERNAL_GROUPS_UPDATE_MAX_RETRIES = 3
# 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT
LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""
if cc_pair.access_type != AccessType.SYNC:
return False
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return False
# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
return False
# If the last sync is None, it has never been run so we run the sync
last_ext_group_sync = cc_pair.last_time_external_group_sync
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
if not source_sync_period:
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period)
if datetime.now(timezone.utc) >= next_sync:
return True
return False
@shared_task(
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
for cc_pair in cc_pairs
if cc_pair.id != cc_pair_to_remove.id
]
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
return None
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()
return 1
@shared_task(
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
bind=True,
)
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
try:
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
replace_user__ext_group_for_cc_pair(
db_session=db_session,
cc_pair_id=cc_pair.id,
group_defs=external_user_groups,
source=cc_pair.connector.source,
)
logger.info(
f"Synced {len(external_user_groups)} external user groups for {source_type}"
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
)
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()

View File

@@ -10,41 +10,50 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -56,41 +65,108 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class RunIndexingCallback(RunIndexingCallbackInterface):
class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name="check_for_indexing",
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -100,6 +176,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -118,26 +197,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -153,33 +230,80 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
reindex,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Indexing queued: index_attempt={attempt_id} "
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -187,8 +311,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -197,6 +327,7 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -261,6 +392,11 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -293,10 +429,11 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -325,7 +462,7 @@ def try_creating_indexing_task(
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
@@ -347,8 +484,10 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -366,15 +505,17 @@ def try_creating_indexing_task(
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(payload)
task_logger.exception(
f"Unexpected exception: "
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -383,8 +524,14 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -392,15 +539,19 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -411,7 +562,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -419,31 +570,78 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
while True:
sleep(10)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
while True:
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
finally:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not index_attempt:
continue
job.cancel()
if not index_attempt.is_finished():
continue
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -455,7 +653,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -463,6 +661,38 @@ def connector_indexing_proxy_task(
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -499,7 +729,8 @@ def connector_indexing_task(
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
logger.info(
f"Indexing spawned task starting: attempt={index_attempt_id} "
f"Indexing spawned task starting: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -516,6 +747,7 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -523,18 +755,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -557,16 +789,19 @@ def connector_indexing_task(
)
break
lock = r.lock(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -601,7 +836,7 @@ def connector_indexing_task(
)
# define a callback class
callback = RunIndexingCallback(
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,

View File

@@ -13,12 +13,13 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
@shared_task(
name="kombu_message_cleanup_task",
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,

View File

@@ -8,11 +8,12 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -20,6 +21,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
@@ -38,8 +40,44 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due.
Next pruning time is calculated as a delta from the last successful prune, or the
last successful indexing if pruning has never succeeded.
TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled
if pruning fails (which is what it does now). A backoff could be reasonable.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
@shared_task(
name="check_for_pruning",
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -69,7 +107,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if not cc_pair:
continue
if not is_pruning_due(cc_pair, db_session, r):
if not _is_pruning_due(cc_pair):
continue
tasks_created = try_creating_prune_generator_task(
@@ -90,47 +128,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def is_pruning_due(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
) -> bool:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one
because the task is a long running generator task.)
Returns None if no pruning is triggered (due to not being needed or
other reasons such as simultaneous pruning restrictions.
Checks for scheduling related conditions, then delegates the rest of the checks to
try_creating_prune_generator_task.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
def try_creating_prune_generator_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
@@ -166,10 +163,16 @@ def try_creating_prune_generator_task(
return None
try:
if redis_connector.prune.fenced: # skip pruning if already pruning
# skip pruning if already pruning
if redis_connector.prune.fenced:
return None
if redis_connector.delete.fenced: # skip pruning if the cc_pair is deleting
# skip pruning if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
# skip pruning if doc permissions sync is running
if redis_connector.permissions.fenced:
return None
db_session.refresh(cc_pair)
@@ -183,7 +186,7 @@ def try_creating_prune_generator_task(
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
@@ -208,7 +211,7 @@ def try_creating_prune_generator_task(
@shared_task(
name="connector_pruning_generator_task",
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -231,13 +234,18 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
@@ -261,6 +269,11 @@ def connector_pruning_generator_task(
)
return
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -269,12 +282,13 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = RunIndexingCallback(
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, callback
@@ -296,8 +310,8 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
@@ -320,10 +334,10 @@ def connector_pruning_generator_task(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
redis_connector.prune.reset()
raise e
finally:
if lock.owned():
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")

View File

@@ -9,6 +9,7 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
@@ -31,7 +32,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
@shared_task(
name="document_by_cc_pair_cleanup_task",
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
@@ -59,7 +60,7 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"tenant={tenant_id} doc={document_id}")
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
@@ -141,7 +142,9 @@ def document_by_cc_pair_cleanup_task(
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.info(f"Retry failed: {ex.last_attempt.attempt_number}")
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
@@ -171,11 +174,21 @@ def document_by_cc_pair_cleanup_task(
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.info(
f"Max retries reached. Marking doc as dirty for reconciliation: "
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"tenant={tenant_id} doc={document_id}"
)
with get_session_with_tenant(tenant_id):
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
return False

View File

@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -13,6 +12,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
@@ -25,8 +25,10 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
from danswer.db.connector import mark_ccpair_as_pruned
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
@@ -47,17 +49,19 @@ from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
from danswer.redis.redis_document_set import RedisDocumentSet
@@ -77,7 +81,7 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -162,7 +166,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -180,7 +184,12 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
@@ -188,22 +197,21 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
if result is None:
continue
if tasks_generated == 0:
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += tasks_generated
total_tasks_generated += result[0]
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -218,7 +226,7 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -246,12 +254,11 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -260,7 +267,7 @@ def try_generate_document_set_sync_tasks(
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -273,7 +280,7 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -302,12 +309,11 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -316,7 +322,7 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -436,11 +442,22 @@ def monitor_connector_deletion_taskset(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# if this happens, documents somehow got added while deletion was in progress. Likely a bug
# gating off pruning and indexing work before deletion starts
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)
# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)
# clean up the rest of the related Postgres entities
@@ -504,8 +521,7 @@ def monitor_connector_deletion_taskset(
f"docs_deleted={fence_data.num_tasks}"
)
redis_connector.delete.taskset_clear()
redis_connector.delete.set_fence(None)
redis_connector.delete.reset()
def monitor_ccpair_pruning_taskset(
@@ -546,6 +562,45 @@ def monitor_ccpair_pruning_taskset(
redis_connector.prune.set_fence(False)
def monitor_ccpair_permissions_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
if not redis_connector.permissions.fenced:
return
initial = redis_connector.permissions.generator_complete
if initial is None:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -580,8 +635,8 @@ def monitor_ccpair_indexing_taskset(
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -590,39 +645,62 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
if status_int is None: # inner signal not set ... possible error
task_state = result.state
if (
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
)
task_logger.warning(msg)
redis_connector_index.reset()
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -630,7 +708,7 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
@@ -643,7 +721,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -655,7 +733,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
@@ -668,41 +746,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_pruning = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_PRUNING, r_celery
)
n_permissions_sync = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning}"
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
@@ -741,6 +797,12 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@@ -757,7 +819,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
@shared_task(
name="vespa_metadata_sync_task",
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
@@ -811,7 +873,9 @@ def vespa_metadata_sync_task(
)
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(f"Retry failed: {ex.last_attempt.attempt_number}")
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app: Celery = celery_app

View File

@@ -1,8 +1,10 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
app: Celery = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -29,18 +29,26 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Ensure the parent proc's database connections are not touched
in the new connection pool
"""Initialize the child process with a fresh SQLAlchemy Engine.
Based on the recommended approach in the SQLAlchemy docs found:
Based on SQLAlchemy's recommendations to handle multiprocessing:
https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
"""
if kwargs is None:
kwargs = {}
logger.info("Initializing spawned worker child process.")
# Reset the engine in the child process
SqlEngine.reset_engine()
# Optionally set a custom app name for database logging purposes
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
# Proceed with executing the target function
return func(*args, **kwargs)

View File

@@ -1,7 +1,5 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -21,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -31,10 +30,10 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import IndexAttemptSingleton
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@@ -42,19 +41,6 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -102,11 +88,15 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -138,13 +128,7 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
callback=callback,
)
indexing_pipeline = build_indexing_pipeline(
@@ -157,6 +141,7 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -228,7 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError("Connector stop signal detected")
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -289,7 +274,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress(len(doc_batch))
callback.progress("_run_indexing", len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -322,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -353,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
@@ -419,7 +425,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
try:
if is_ee:
@@ -427,17 +433,19 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_cc_and_index_id(
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -445,10 +453,8 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

View File

@@ -1,4 +0,0 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -14,15 +14,6 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:
task_name = f"prune_connector_credential_pair_{connector_id}_{credential_id}"
if not connector_id or not credential_id:
task_name = "prune_connector_credential_pair"
return task_name
T = TypeVar("T", bound=Callable)

View File

@@ -6,33 +6,27 @@ from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from danswer.chat.llm_response_handler import LLMResponseHandlerManager
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.build import default_build_system_message
from danswer.chat.prompt_builder.build import default_build_user_message
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
QuotesResponseHandler,
)
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.chat.stream_processing.utils import map_document_id_order
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
@@ -214,18 +208,23 @@ class Answer:
search_result = SearchTool.get_search_result(current_llm_call) or []
answer_handler: AnswerResponseHandler
if self.answer_style_config.citation_config:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
elif self.answer_style_config.quotes_config:
answer_handler = QuotesResponseHandler(
context_docs=search_result,
)
else:
raise ValueError("No answer style config provided")
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
)
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
@@ -233,6 +232,8 @@ class Answer:
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
@@ -263,6 +264,7 @@ class Answer:
message_history=self.message_history,
llm_config=self.llm.config,
single_message_history=self.single_message_history,
raw_user_text=self.question,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)

View File

@@ -2,20 +2,79 @@ import re
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceSection
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
def prepare_chat_message_request(
message_text: str,
user: User | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
db_session=db_session,
description=None,
user_id=user.id if user else None,
# If using an override, this id will be ignored later on
persona_id=persona_id or DEFAULT_PERSONA_ID,
danswerbot_flow=True,
slack_thread_id=message_ts_to_respond_to,
)
return CreateChatMessageRequest(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
prompt_id=prompt.id if prompt else None,
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
@@ -31,9 +90,49 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
match_highlights=inference_section.center_chunk.match_highlights,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)
def create_chat_chain(
chat_session_id: UUID,
db_session: Session,
@@ -196,3 +295,71 @@ def extract_headers(
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@@ -1,60 +1,22 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING
from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import ResponsePart
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
if TYPE_CHECKING:
from danswer.llm.answering.stream_processing.answer_response_handler import (
AnswerResponseHandler,
)
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| DanswerQuotes
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
class LLMResponseHandlerManager:
def __init__(
self,
tool_handler: "ToolResponseHandler",
answer_handler: "AnswerResponseHandler",
tool_handler: ToolResponseHandler,
answer_handler: AnswerResponseHandler,
is_cancelled: Callable[[], bool],
):
self.tool_handler = tool_handler

View File

@@ -1,17 +1,30 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from danswer.configs.constants import DocumentSource
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.configs.constants import MessageType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
if TYPE_CHECKING:
from danswer.db.models import Prompt
class LlmDoc(BaseModel):
"""This contains the minimal set information for the LLM portion including citations"""
@@ -25,6 +38,7 @@ class LlmDoc(BaseModel):
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
match_highlights: list[str] | None
# First chunk of info for streaming QA
@@ -117,20 +131,6 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerContext(BaseModel):
content: str
document_id: str
@@ -146,14 +146,20 @@ class DanswerAnswer(BaseModel):
answer: str | None
class QAResponse(SearchResponse, DanswerAnswer):
quotes: list[DanswerQuote] | None
contexts: list[DanswerContexts] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class ChatDanswerBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
class FileChatDisplay(BaseModel):
@@ -165,9 +171,41 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay
@@ -183,3 +221,109 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
class DocumentPruningConfig(BaseModel):
max_chunks: int | None = None
max_window_percentage: float | None = None
max_tokens: int | None = None
# different pruning behavior is expected when the
# user manually selects documents they want to chat with
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If user specifies to include additional context Chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
# If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
# Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
use_sections: bool = True
# If using tools, then we need to consider the tool length
tool_num_tokens: int = 0
# If using a tool message to represent the docs, then we have to JSON serialize
# the document content, which adds to the token count.
using_tool_message: bool = False
class ContextualPruningConfig(DocumentPruningConfig):
num_chunk_multiple: int
@classmethod
def from_doc_pruning_config(
cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
) -> "ContextualPruningConfig":
return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
model_config = ConfigDict(frozen=True)
ResponsePart = (
DanswerAnswerPiece
| CitationInfo
| ToolCallKickoff
| ToolResponse
| ToolCallFinalResult
| StreamStopInfo
)

View File

@@ -6,29 +6,41 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.answer import Answer
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import create_temporary_persona
from danswer.chat.models import AllCitations
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import CitationConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import PromptConfig
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.context.search.utils import drop_llm_indices
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -41,7 +53,6 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -51,40 +62,24 @@ from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.file_store.utils import load_all_chat_files
from danswer.file_store.utils import save_files_from_urls
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.file_store.utils import save_files
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
@@ -95,9 +90,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
@@ -113,6 +105,7 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -122,11 +115,12 @@ from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -269,6 +263,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| DanswerContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -295,11 +290,12 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -307,6 +303,10 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -328,17 +328,36 @@ def stream_chat_message_objects(
retrieval_options = new_msg_req.retrieval_options
alternate_assistant_id = new_msg_req.alternate_assistant_id
# use alternate persona if alternative assistant id is passed in
# permanent "log" store, used primarily for debugging
long_term_logger = LongTermLogger(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
if alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = chat_session.persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
# If a prompt override is specified via the API, use that with highest priority
# but for saving it, we are just mapping it to an existing prompt
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
@@ -353,6 +372,7 @@ def stream_chat_message_objects(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
long_term_logger=long_term_logger,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
@@ -428,12 +448,20 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -504,13 +532,19 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -525,7 +559,13 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -537,21 +577,37 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
system_prompt=new_msg_req.persona_override_config.prompts[
0
].system_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
datetime_aware=new_msg_req.persona_override_config.prompts[
0
].datetime_aware,
include_citations=new_msg_req.persona_override_config.prompts[
0
].include_citations,
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
elif prompt_override:
if not final_msg.prompt:
raise ValueError(
"Prompt override cannot be applied, no base prompt found."
)
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=prompt_override,
)
elif final_msg.prompt:
prompt_config = PromptConfig.from_model(final_msg.prompt)
else:
prompt_config = PromptConfig.from_model(persona.prompts[0])
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
@@ -560,142 +616,42 @@ def stream_chat_message_objects(
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
rerank_settings=new_msg_req.rerank_settings,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
bypass_acl=bypass_acl,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -724,7 +680,8 @@ def stream_chat_message_objects(
reference_db_search_docs = None
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
# any files to associate with the AI message e.g. dall-e generated images
ai_message_files = []
dropped_indices = None
tool_result = None
@@ -779,8 +736,14 @@ def stream_chat_message_objects(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
file_ids = save_files(
urls=[img.url for img in img_generation_response if img.url],
base64_files=[
img.image_data
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
@@ -806,15 +769,19 @@ def stream_chat_message_objects(
or custom_tool_response.response_type == "csv"
):
file_ids = custom_tool_response.tool_result.file_ids
ai_message_files = [
FileDescriptor(
id=str(file_id),
type=ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV,
)
for file_id in file_ids
]
ai_message_files.extend(
[
FileDescriptor(
id=str(file_id),
type=(
ChatFileType.IMAGE
if custom_tool_response.response_type == "image"
else ChatFileType.CSV
),
)
for file_id in file_ids
]
)
yield FileChatDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
@@ -823,6 +790,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(DanswerContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
pass
@@ -862,7 +831,8 @@ def stream_chat_message_objects(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
yield AllCitations(citations=answer.citations)
if not answer.is_cancelled():
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
@@ -871,7 +841,6 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -879,9 +848,11 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
@@ -915,7 +886,6 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -925,10 +895,36 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
@log_function_time()
def gather_stream_for_slack(
packets: ChatPacketStream,
) -> ChatDanswerBotResponse:
response = ChatDanswerBotResponse()
answer = ""
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.docs = packet
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.chat_message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
if answer:
response.answer = answer
return response

View File

@@ -4,20 +4,26 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
def default_build_system_message(
@@ -58,6 +64,7 @@ class AnswerPromptBuilder:
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_text: str,
single_message_history: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -88,6 +95,8 @@ class AnswerPromptBuilder:
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.raw_user_message = raw_user_text
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
self.system_message_and_token_cnt = None
@@ -136,3 +145,15 @@ class AnswerPromptBuilder:
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens
)
class LLMCall(BaseModel__v1):
prompt_builder: AnswerPromptBuilder
tools: list[Tool]
force_use_tool: ForceUseTool
files: list[InMemoryChatFile]
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
using_tool_calling_llm: bool
class Config:
arbitrary_types_allowed = True

View File

@@ -2,11 +2,12 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
@@ -29,7 +30,6 @@ from danswer.prompts.token_counts import (
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -1,46 +1,16 @@
from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as small version, low float precision, quantized,
or distilled models. It only uses one context document and has very weak requirements of
output format.
"""
context_block = ""
if context_docs:
context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs[0].content)
prompt_str = WEAK_LLM_PROMPT.format(
system_prompt=prompt.system_prompt,
context_block=context_block,
task_prompt=prompt.task_prompt,
user_query=question,
)
if prompt.datetime_aware:
prompt_str = add_date_time_to_prompt(prompt_str=prompt_str)
return HumanMessage(content=prompt_str)
def _build_strong_llm_quotes_prompt(
@@ -81,15 +51,9 @@ def build_quotes_user_message(
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
if QA_PROMPT_OVERRIDE == "weak"
else _build_strong_llm_quotes_prompt
)
query, _ = message_to_prompt_and_imgs(message)
return prompt_builder(
return _build_strong_llm_quotes_prompt(
question=query,
context_docs=context_docs,
history_str=history_str,

View File

@@ -0,0 +1,62 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def translate_danswer_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now.
if not isinstance(msg, ChatMessage):
files = msg.files
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
return HumanMessage(content=content)
raise ValueError(f"New message type {msg.message_type} not handled")
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
return history_basemessages, history_token_counts

View File

@@ -5,20 +5,20 @@ from typing import TypeVar
from pydantic import BaseModel
from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import (
LlmDoc,
)
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
from danswer.utils.logger import setup_logger

View File

@@ -3,16 +3,14 @@ from collections.abc import Generator
from langchain_core.messages import BaseMessage
from danswer.chat.llm_response_handler import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
CitationProcessor,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
QuotesProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
class AnswerResponseHandler(abc.ABC):
@@ -48,6 +46,9 @@ class CitationResponseHandler(AnswerResponseHandler):
self.processed_text = ""
self.citations: list[CitationInfo] = []
# TODO remove this after citation issue is resolved
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
def handle_response_part(
self,
response_item: BaseMessage | None,
@@ -64,28 +65,29 @@ class CitationResponseHandler(AnswerResponseHandler):
yield from self.citation_processor.process_token(content)
class QuotesResponseHandler(AnswerResponseHandler):
def __init__(
self,
context_docs: list[LlmDoc],
is_json_prompt: bool = True,
):
self.quotes_processor = QuotesProcessor(
context_docs=context_docs,
is_json_prompt=is_json_prompt,
)
# No longer in use, remove later
# class QuotesResponseHandler(AnswerResponseHandler):
# def __init__(
# self,
# context_docs: list[LlmDoc],
# is_json_prompt: bool = True,
# ):
# self.quotes_processor = QuotesProcessor(
# context_docs=context_docs,
# is_json_prompt=is_json_prompt,
# )
def handle_response_part(
self,
response_item: BaseMessage | None,
previous_response_items: list[BaseMessage],
) -> Generator[ResponsePart, None, None]:
if response_item is None:
yield from self.quotes_processor.process_token(None)
return
# def handle_response_part(
# self,
# response_item: BaseMessage | None,
# previous_response_items: list[BaseMessage],
# ) -> Generator[ResponsePart, None, None]:
# if response_item is None:
# yield from self.quotes_processor.process_token(None)
# return
content = (
response_item.content if isinstance(response_item.content, str) else ""
)
# content = (
# response_item.content if isinstance(response_item.content, str) else ""
# )
yield from self.quotes_processor.process_token(content)
# yield from self.quotes_processor.process_token(content)

View File

@@ -4,8 +4,8 @@ from collections.abc import Generator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger
@@ -67,9 +67,9 @@ class CitationProcessor:
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
possible_citation_found = re.search(
possible_citation_pattern, self.curr_segment
)
@@ -77,13 +77,15 @@ class CitationProcessor:
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
self.current_citations = []
result = "" # Initialize result here
result = ""
if citations_found and not in_code_block(self.llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))
numerical_value = int(
next(group for group in citation.groups() if group is not None)
)
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
@@ -131,14 +133,6 @@ class CitationProcessor:
link = context_llm_doc.link
# Replace the citation in the current segment
start, end = citation.span()
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ self.curr_segment[end + length_to_add :]
)
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
@@ -149,6 +143,7 @@ class CitationProcessor:
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (

View File

@@ -1,3 +1,4 @@
# THIS IS NO LONGER IN USE
import math
import re
from collections.abc import Generator
@@ -5,16 +6,15 @@ from json import JSONDecodeError
from typing import Optional
import regex
from pydantic import BaseModel
from danswer.chat.models import DanswerAnswer
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.context.search.models import InferenceChunk
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import clean_up_code_blocks
@@ -26,6 +26,20 @@ logger = setup_logger()
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
def _extract_answer_quotes_freeform(
answer_raw: str,
) -> tuple[Optional[str], Optional[list[str]]]:
@@ -231,16 +245,16 @@ class QuotesProcessor:
model_previous = self.model_output
self.model_output += token
if not self.found_answer_start:
m = answer_pattern.search(self.model_output)
if m:
self.found_answer_start = True
# Prevent heavy cases of hallucinations
if self.is_json_prompt and len(self.model_output) > 70:
logger.warning("LLM did not produce json as prompted")
if self.is_json_prompt and len(self.model_output) > 400:
self.found_answer_end = True
logger.warning("LLM did not produce json as prompted")
logger.debug("Model output thus far:", self.model_output)
return
remaining = self.model_output[m.end() :]

View File

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel
from danswer.chat.models import LlmDoc
from danswer.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunk
class DocumentIdOrderMapping(BaseModel):

View File

@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.chat.models import ResponsePart
from danswer.chat.prompt_builder.build import LLMCall
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message
@@ -62,7 +62,7 @@ class ToolResponseHandler:
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
@@ -76,7 +76,7 @@ class ToolResponseHandler:
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
@@ -95,7 +95,7 @@ class ToolResponseHandler:
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.get_user_message_content(),
query=llm_call.prompt_builder.raw_user_message,
llm=llm,
)
if available_tools_and_args

View File

@@ -1,115 +0,0 @@
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
from danswer.prompts.chat_tools import TOOL_TEMPLATE
from danswer.prompts.chat_tools import USER_INPUT
class ToolInfo(TypedDict):
name: str
description: str
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
def call_tool(
model_actions: DanswerChatModelOut,
) -> str:
raise NotImplementedError("There are no additional tool integrations right now")
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"
# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -84,7 +81,14 @@ OAUTH_CLIENT_SECRET = (
or ""
)
# for future OAuth connector support
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -118,6 +122,8 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# the number of times to try and connect to vespa on startup before giving up
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
@@ -234,7 +240,7 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
@@ -308,6 +314,22 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
# Due to breakages in the confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.
# The current state of affairs:
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
# no API retrieves the user's timezone
# All data is returned in UTC, so we can't derive the user's timezone from that
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector as some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -422,6 +444,9 @@ LOG_ALL_MODEL_INTERACTIONS = (
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
LOG_INDIVIDUAL_MODEL_TOKENS = (
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -490,10 +515,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
@@ -503,3 +524,10 @@ _API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

View File

@@ -1,9 +1,9 @@
import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -17,9 +17,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
@@ -27,8 +24,6 @@ DOC_TIME_DECAY = float(
)
BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat

View File

@@ -31,6 +31,8 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
DEFAULT_PERSONA_ID = 0
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -60,7 +62,6 @@ KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
@@ -74,12 +75,16 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
@@ -209,9 +214,17 @@ class PostgresAdvisoryLocks(Enum):
class DanswerCeleryQueues:
# Light queue
VESPA_METADATA_SYNC = "vespa_metadata_sync"
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync"
CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync"
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
@@ -221,8 +234,18 @@ class DanswerRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
@@ -238,6 +261,32 @@ class DanswerCeleryPriority(int, Enum):
LOWEST = auto()
class DanswerCeleryTask:
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
CHECK_FOR_PRUNING = "check_for_pruning"
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
)
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
"update_external_document_permissions_task"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3

View File

@@ -4,11 +4,8 @@ import os
# Danswer Slack Bot Configs
#####
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
# How much of the available input context can be used for thread context
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
@@ -47,17 +44,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
# Set/unset by "Hide Non Answers"
ENABLE_DANSWERBOT_REFLEXION = (
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
)
# Currently not support chain of thought, probably will add back later
DANSWER_BOT_DISABLE_COT = True
# if set, will default DanswerBot to use quotes and reference documents
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
# Maximum Questions Per Minute, Default Uncapped
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None

View File

@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
@@ -119,3 +121,14 @@ if _LITELLM_PASS_THROUGH_HEADERS_RAW:
logger.error(
"Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
# if specified, will merge the specified JSON with the existing body of the
# request before sending it to the LLM
LITELLM_EXTRA_BODY: dict | None = None
_LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY")
if _LITELLM_EXTRA_BODY_RAW:
try:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass

View File

@@ -2,6 +2,8 @@ import json
import os
IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")
# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(

View File

@@ -11,11 +11,16 @@ Connectors come in 3 different flows:
- Load Connector:
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
documents via a connector's API or loads the documents from some sort of a dump file.
- Poll connector:
- Poll Connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Slim Connector:
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
- This is used by our pruning job which removes old documents from the index.
- The optional start and end datetimes can be ignored.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.
- Currently not used by the background job, this exists for future design purposes.
@@ -26,8 +31,14 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -5,9 +5,9 @@ from io import BytesIO
from typing import Any
from typing import Optional
import boto3
from botocore.client import Config
from mypy_boto3_s3 import S3Client
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from mypy_boto3_s3 import S3Client # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import BlobType

View File

@@ -1,18 +1,21 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import quote
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
from danswer.connectors.confluence.utils import validate_attachment_filetype
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
@@ -51,6 +54,8 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -67,10 +72,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.confluence_client: OnyxConfluence | None = None
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -81,15 +87,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
elif space:
# if no cql_query is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
else:
# if neither a space nor a cql_query is provided, we will use the page_id to fetch the page
cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
@@ -97,39 +103,46 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = build_confluence_client(
credentials_json=credentials,
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comments in self.confluence_client.paginated_cql_page_retrieval(
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
):
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)
return comment_string
@@ -141,28 +154,28 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"]
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
)
object_text = None
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
self.confluence_client, confluence_object
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
self.confluence_client, confluence_object
confluence_client=self.confluence_client, attachment=confluence_object
)
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
@@ -193,44 +206,41 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
logger.debug(f"_fetch_document_batches: {page['id']}")
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachments in self.confluence_client.paginated_cql_page_retrieval(
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -240,10 +250,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
@@ -255,48 +265,69 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
for pages in self.confluence_client.cql_paginate_all_expansions(
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
page_restrictions = page.get("restrictions")
page_space_key = page.get("space", {}).get("key")
page_perm_sync_data = {
"restrictions": page_restrictions or {},
"space_key": page_space_key,
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
if not validate_attachment_filetype(attachment):
continue
attachment_restrictions = attachment.get("restrictions")
if not attachment_restrictions:
attachment_restrictions = page_restrictions
attachment_space_key = attachment.get("space", {}).get("key")
if not attachment_space_key:
attachment_space_key = page_space_key
attachment_perm_sync_data = {
"restrictions": attachment_restrictions or {},
"space_key": attachment_space_key,
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"]
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
perm_sync_data=attachment_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"]
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list

View File

@@ -20,6 +20,10 @@ F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
pass
@@ -80,7 +84,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
@@ -95,6 +99,10 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
@@ -112,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence(Confluence):
@@ -126,6 +134,32 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -141,7 +175,7 @@ class OnyxConfluence(Confluence):
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
"""
@@ -153,46 +187,43 @@ class OnyxConfluence(Confluence):
while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
# yield the results individually
yield from next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
def paginated_cql_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
yield from self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
@@ -201,7 +232,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -221,6 +252,120 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object
def paginated_cql_user_retrieval(
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)

View File

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
import bs4
@@ -31,7 +32,11 @@ def get_user_email_from_username__server(
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
email = None
# For now, we'll just return a string that indicates failure
# We may want to revert to returning None in the future
# email = None
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
logger.warning(f"failed to get confluence email for {user_name}")
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
@@ -71,7 +76,9 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -79,7 +86,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
@@ -100,22 +107,93 @@ def extract_text_from_confluence_html(
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
return attachment["metadata"]["mediaType"] not in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if attachment["metadata"]["mediaType"] in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]:
if not validate_attachment_filetype(attachment):
return None
download_link = confluence_client.url + attachment["_links"]["download"]
@@ -153,7 +231,9 @@ def attachment_to_content(
return extracted_text
def build_confluence_document_id(base_url: str, content_url: str) -> str:
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
"""For confluence, the document id is the page url for a page based document
or the attachment download url for an attachment based document
@@ -164,10 +244,12 @@ def build_confluence_document_id(base_url: str, content_url: str) -> str:
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
def extract_referenced_attachment_names(page_text: str) -> list[str]:
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachments in use
@@ -195,20 +277,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)

View File

@@ -1,8 +1,8 @@
import os
from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import Issue
@@ -12,129 +12,93 @@ from danswer.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.danswer_jira.utils import best_effort_basic_expert_info
from danswer.connectors.danswer_jira.utils import best_effort_get_field_from_issue
from danswer.connectors.danswer_jira.utils import build_jira_client
from danswer.connectors.danswer_jira.utils import build_jira_url
from danswer.connectors.danswer_jira.utils import extract_jira_project
from danswer.connectors.danswer_jira.utils import extract_text_from_adf
from danswer.connectors.danswer_jira.utils import get_comment_strs
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
def _paginate_jql_search(
jira_client: JIRA,
jql: str,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
return jira_base, jira_project
if len(issues) < max_results:
break
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in jira.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
start += max_results
def fetch_jira_issues_batch(
jql: str,
start_index: int,
jira_client: JIRA,
batch_size: int = INDEX_BATCH_SIZE,
jql: str,
batch_size: int,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> tuple[list[Document], int]:
doc_batch = []
batch = jira_client.search_issues(
jql,
startAt=start_index,
maxResults=batch_size,
)
for jira in batch:
if type(jira) != Issue:
logger.warning(f"Found Jira object not of type Issue {jira}")
continue
if labels_to_skip and any(
label in jira.fields.labels for label in labels_to_skip
):
logger.info(
f"Skipping {jira.key} because it has a label to skip. Found "
f"labels: {jira.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
jira.fields.description
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(jira.raw["fields"]["description"])
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
comments = _get_comment_strs(jira, comment_email_blacklist)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
@@ -142,66 +106,53 @@ def fetch_jira_issues_batch(
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {jira.key} because it exceeds the maximum size of "
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
continue
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
people = set()
try:
people.add(
BasicExpertInfo(
display_name=jira.fields.creator.displayName,
email=jira.fields.creator.emailAddress,
)
)
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
people.add(
BasicExpertInfo(
display_name=jira.fields.assignee.displayName, # type: ignore
email=jira.fields.assignee.emailAddress, # type: ignore
)
)
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["label"] = labels
doc_batch.append(
Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=jira.fields.summary,
doc_updated_at=time_str_to_utc(jira.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
)
yield Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=issue.fields.summary,
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
)
return doc_batch, len(batch)
class JiraConnector(LoadConnector, PollConnector):
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
jira_project_url: str,
@@ -213,8 +164,8 @@ class JiraConnector(LoadConnector, PollConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base, self.jira_project = extract_jira_project(jira_project_url)
self.jira_client: JIRA | None = None
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
@@ -223,54 +174,45 @@ class JiraConnector(LoadConnector, PollConnector):
def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist)
@property
def jira_client(self) -> JIRA:
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
return self._jira_client
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
self.jira_client = JIRA(
basic_auth=(email, api_token),
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
self.jira_client = JIRA(
token_auth=api_token,
server=self.jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
self._jira_client = build_jira_client(
credentials=credentials,
jira_base=self.jira_base,
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
jql = f"project = {self.quoted_jira_project}"
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=f"project = {quoted_project}",
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
)
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
yield document_batch
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.jira_client is None:
raise ConnectorMissingCredentialError("Jira")
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@@ -278,31 +220,54 @@ class JiraConnector(LoadConnector, PollConnector):
"%Y-%m-%d %H:%M"
)
# Quote the project name to handle reserved words
quoted_project = f'"{self.jira_project}"'
jql = (
f"project = {quoted_project} AND "
f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
start_ind = 0
while True:
doc_batch, fetched_batch_size = fetch_jira_issues_batch(
jql=jql,
start_index=start_ind,
jira_client=self.jira_client,
batch_size=self.batch_size,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
jira_client=self.jira_client,
jql=jql,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
issue_key = best_effort_get_field_from_issue(issue, "key")
id = build_jira_url(self.jira_client, issue_key)
slim_doc_batch.append(
SlimDocument(
id=id,
perm_sync_data=None,
)
)
if len(slim_doc_batch) >= _JIRA_SLIM_PAGE_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if doc_batch:
yield doc_batch
start_ind += fetched_batch_size
if fetched_batch_size < self.batch_size:
break
yield slim_doc_batch
if __name__ == "__main__":

View File

@@ -1,17 +1,136 @@
"""Module with custom fields processing functions"""
import os
from typing import Any
from typing import List
from urllib.parse import urlparse
from jira import JIRA
from jira.resources import CustomFieldOption
from jira.resources import Issue
from jira.resources import User
from danswer.connectors.models import BasicExpertInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
PROJECT_URL_PAT = "projects"
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None:
display_name = None
email = None
if hasattr(obj, "display_name"):
display_name = obj.display_name
else:
display_name = obj.get("displayName")
if hasattr(obj, "emailAddress"):
email = obj.emailAddress
else:
email = obj.get("emailAddress")
if not email and not display_name:
return None
return BasicExpertInfo(display_name=display_name, email=email)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def extract_text_from_adf(adf: dict | None) -> str:
"""Extracts plain text from Atlassian Document Format:
https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/
WARNING: This function is incomplete and will e.g. skip lists!
"""
texts = []
if adf is not None and "content" in adf:
for block in adf["content"]:
if "content" in block:
for item in block["content"]:
if item["type"] == "text":
texts.append(item["text"])
return " ".join(texts)
def build_jira_url(jira_client: JIRA, issue_key: str) -> str:
return f"{jira_client.client_info()}/browse/{issue_key}"
def build_jira_client(credentials: dict[str, Any], jira_base: str) -> JIRA:
api_token = credentials["jira_api_token"]
# if user provide an email we assume it's cloud
if "jira_user_email" in credentials:
email = credentials["jira_user_email"]
return JIRA(
basic_auth=(email, api_token),
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
else:
return JIRA(
token_auth=api_token,
server=jira_base,
options={"rest_api_version": JIRA_API_VERSION},
)
def extract_jira_project(url: str) -> tuple[str, str]:
parsed_url = urlparse(url)
jira_base = parsed_url.scheme + "://" + parsed_url.netloc
# Split the path by '/' and find the position of 'projects' to get the project name
split_path = parsed_url.path.split("/")
if PROJECT_URL_PAT in split_path:
project_pos = split_path.index(PROJECT_URL_PAT)
if len(split_path) > project_pos + 1:
jira_project = split_path[project_pos + 1]
else:
raise ValueError("No project name found in the URL")
else:
raise ValueError("'projects' not found in the URL")
return jira_base, jira_project
def get_comment_strs(
issue: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
comment_strs = []
for comment in issue.fields.comment.comments:
try:
body_text = (
comment.body
if JIRA_API_VERSION == "2"
else extract_text_from_adf(comment.raw["body"])
)
if (
hasattr(comment, "author")
and hasattr(comment.author, "emailAddress")
and comment.author.emailAddress in comment_email_blacklist
):
continue # Skip adding comment if author's email is in blacklist
comment_strs.append(body_text)
except Exception as e:
logger.error(f"Failed to process comment due to an error: {e}")
continue
return comment_strs
class CustomFieldExtractor:
@staticmethod
def _process_custom_field_value(value: Any) -> str:

View File

@@ -305,6 +305,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
query = _build_time_range_query(time_range_start, time_range_end)
doc_batch = []
for user_email in self._get_all_user_emails():
logger.info(f"Fetching slim threads for user: {user_email}")
gmail_service = get_gmail_service(self.creds, user_email)
for thread in execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().list,

View File

@@ -15,6 +15,7 @@ from danswer.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from danswer.connectors.google_drive.file_retrieval import crawl_folders_for_files
from danswer.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from danswer.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from danswer.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from danswer.connectors.google_drive.models import GoogleDriveFileType
@@ -82,12 +83,31 @@ def _process_files_batch(
yield doc_batch
def _clean_requested_drive_ids(
requested_drive_ids: set[str],
requested_folder_ids: set[str],
all_drive_ids_available: set[str],
) -> tuple[set[str], set[str]]:
invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available
filtered_folder_ids = requested_folder_ids - all_drive_ids_available
if invalid_requested_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}"
)
logger.warning("Checking for folder access instead...")
filtered_folder_ids.update(invalid_requested_drive_ids)
valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = True,
include_shared_drives: bool = False,
include_my_drives: bool = False,
include_files_shared_with_me: bool = False,
shared_drive_urls: str | None = None,
include_my_drives: bool = True,
my_drive_emails: str | None = None,
shared_folder_urls: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
@@ -120,22 +140,36 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if (
not include_shared_drives
and not include_my_drives
and not include_files_shared_with_me
and not shared_folder_urls
and not my_drive_emails
and not shared_drive_urls
):
raise ValueError(
"At least one of include_shared_drives, include_my_drives,"
" or shared_folder_urls must be true"
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
)
self.batch_size = batch_size
self.include_shared_drives = include_shared_drives
specific_requests_made = False
if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls):
specific_requests_made = True
self.include_files_shared_with_me = (
False if specific_requests_made else include_files_shared_with_me
)
self.include_my_drives = False if specific_requests_made else include_my_drives
self.include_shared_drives = (
False if specific_requests_made else include_shared_drives
)
shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls)
self._requested_shared_drive_ids = set(
_extract_ids_from_urls(shared_drive_url_list)
)
self.include_my_drives = include_my_drives
self._requested_my_drive_emails = set(
_extract_str_list_from_comma_str(my_drive_emails)
)
@@ -192,80 +226,72 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def _update_traversed_parent_ids(self, folder_id: str) -> None:
self._retrieved_ids.add(folder_id)
def _get_all_user_emails(self, admins_only: bool) -> list[str]:
def _get_all_user_emails(self) -> list[str]:
# Start with primary admin email
user_emails = [self.primary_admin_email]
# Only fetch additional users if using service account
if isinstance(self.creds, OAuthCredentials):
return user_emails
admin_service = get_admin_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
query = "isAdmin=true" if admins_only else "isAdmin=false"
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
# Get admins first since they're more likely to have access to most files
for is_admin in [True, False]:
query = "isAdmin=true" if is_admin else "isAdmin=false"
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
query=query,
):
if email := user.get("primaryEmail"):
if email not in user_emails:
user_emails.append(email)
return user_emails
def _get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
)
is_service_account = isinstance(self.creds, ServiceAccountCredentials)
all_drive_ids = set()
for drive in execute_paginated_retrieval(
retrieval_function=primary_drive_service.drives().list,
list_key="drives",
useDomainAdminAccess=True,
useDomainAdminAccess=is_service_account,
fields="drives(id)",
):
all_drive_ids.add(drive["id"])
return all_drive_ids
def _initialize_all_class_variables(self) -> None:
# Get all user emails
# Get admins first becuase they are more likely to have access to the most files
user_emails = [self.primary_admin_email]
for admins_only in [True, False]:
for email in self._get_all_user_emails(admins_only=admins_only):
if email not in user_emails:
user_emails.append(email)
self._all_org_emails = user_emails
self._all_drive_ids: set[str] = self._get_all_drive_ids()
# remove drive ids from the folder ids because they are queried differently
self._requested_folder_ids -= self._all_drive_ids
# Remove drive_ids that are not in the all_drive_ids and check them as folders instead
invalid_drive_ids = self._requested_shared_drive_ids - self._all_drive_ids
if invalid_drive_ids:
if not all_drive_ids:
logger.warning(
f"Some shared drive IDs were not found. IDs: {invalid_drive_ids}"
"No drives found even though we are indexing shared drives was requested."
)
logger.warning("Checking for folder access instead...")
self._requested_folder_ids.update(invalid_drive_ids)
if not self.include_shared_drives:
self._requested_shared_drive_ids = set()
elif not self._requested_shared_drive_ids:
self._requested_shared_drive_ids = self._all_drive_ids
return all_drive_ids
def _impersonate_user_for_retrieval(
self,
user_email: str,
is_slim: bool,
filtered_drive_ids: set[str],
filtered_folder_ids: set[str],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, user_email)
if self.include_my_drives and (
not self._requested_my_drive_emails
or user_email in self._requested_my_drive_emails
):
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if self.include_my_drives or user_email in self._requested_my_drive_emails:
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -274,7 +300,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_drive_ids = self._requested_shared_drive_ids - self._retrieved_ids
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
yield from get_files_in_shared_drive(
service=drive_service,
@@ -285,7 +311,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
@@ -296,33 +322,142 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
end=end,
)
def _fetch_drive_items(
def _manage_service_account_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
self._initialize_all_class_variables()
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
executor.submit(
self._impersonate_user_for_retrieval, email, is_slim, start, end
self._impersonate_user_for_retrieval,
email,
is_slim,
drive_ids_to_retrieve,
folder_ids_to_retrieve,
start,
end,
): email
for email in self._all_org_emails
for email in all_org_emails
}
# Yield results as they complete
for future in as_completed(future_to_email):
yield from future.result()
remaining_folders = self._requested_folder_ids - self._retrieved_ids
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _manage_oauth_retrieval(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
include_my_drives=self.include_my_drives,
include_shared_drives=self.include_shared_drives,
is_slim=is_slim,
start=start,
end=end,
)
all_requested = (
self.include_files_shared_with_me
and self.include_my_drives
and self.include_shared_drives
)
if all_requested:
# If all 3 are true, we already yielded from get_all_files_for_oauth
return
all_drive_ids = self._get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:
drive_ids_to_retrieve, folder_ids_to_retrieve = _clean_requested_drive_ids(
requested_drive_ids=self._requested_shared_drive_ids,
requested_folder_ids=self._requested_folder_ids,
all_drive_ids_available=all_drive_ids,
)
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
for drive_id in drive_ids_to_retrieve:
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
is_slim=is_slim,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
# Even if no folders were requested, we still check if any drives were requested
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
traversed_parent_ids=self._retrieved_ids,
update_traversed_ids_func=self._update_traversed_parent_ids,
start=start,
end=end,
)
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
if remaining_folders:
logger.warning(
f"Some folders/drives were not retrieved. IDs: {remaining_folders}"
)
def _fetch_drive_items(
self,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
retrieval_method = (
self._manage_service_account_retrieval
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
return retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,

View File

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -48,6 +49,67 @@ def _extract_sections_basic(
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -65,6 +127,7 @@ def _extract_sections_basic(
.decode("utf-8")
)
return [Section(link=link, text=text)]
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,

View File

@@ -140,8 +140,8 @@ def get_files_in_shared_drive(
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -152,7 +152,7 @@ def get_files_in_shared_drive(
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields="nextPageToken, files(id)",
q=query,
q=folder_query,
):
update_traversed_ids_func(file["id"])
found_folders = True
@@ -160,9 +160,9 @@ def get_files_in_shared_drive(
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
query += " and trashed = false"
query += _generate_time_range_filter(start, end)
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
@@ -172,7 +172,7 @@ def get_files_in_shared_drive(
supportsAllDrives=True,
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
q=file_query,
)
@@ -185,14 +185,16 @@ def get_all_files_in_my_drive(
) -> Iterator[GoogleDriveFileType]:
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
query = "trashed = false and 'me' in owners"
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
q=folder_query,
):
update_traversed_ids_func(file["id"])
found_folders = True
@@ -200,18 +202,52 @@ def get_all_files_in_my_drive(
update_traversed_ids_func(get_root_folder_id(service))
# Then get the files
query = "trashed = false and 'me' in owners"
query += _generate_time_range_filter(start, end)
fields = "files(id, name, mimeType, webViewLink, modifiedTime, createdTime)"
if not is_slim:
fields += ", files(permissions, permissionIds, owners)"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
q=file_query,
)
def get_all_files_for_oauth(
service: Any,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
is_slim: bool = False,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
should_get_all = (
include_shared_drives and include_my_drives and include_files_shared_with_me
)
corpora = "allDrives" if should_get_all else "user"
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
if not should_get_all:
if include_files_shared_with_me and not include_my_drives:
file_query += " and not 'me' in owners"
if not include_files_shared_with_me and include_my_drives:
file_query += " and 'me' in owners"
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
)

View File

@@ -105,7 +105,7 @@ def execute_paginated_retrieval(
)()
elif e.resp.status == 404 or e.resp.status == 403:
if continue_on_404_or_403:
logger.warning(f"Error executing request: {e}")
logger.debug(f"Error executing request: {e}")
results = {}
else:
raise e

View File

@@ -12,12 +12,15 @@ from dateutil import parser
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.interfaces import SlimConnector
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.models import SlimDocument
from danswer.utils.logger import setup_logger
@@ -28,6 +31,8 @@ logger = setup_logger()
SLAB_GRAPHQL_MAX_TRIES = 10
SLAB_API_URL = "https://api.slab.com/v1/graphql"
_SLIM_BATCH_SIZE = 1000
def run_graphql_request(
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
@@ -158,21 +163,26 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
return urljoin(urljoin(base_url, "posts/"), url_id)
class SlabConnector(LoadConnector, PollConnector):
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
slab_bot_token: str | None = None,
) -> None:
self.base_url = base_url
self.batch_size = batch_size
self.slab_bot_token = slab_bot_token
self._slab_bot_token: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.slab_bot_token = credentials["slab_bot_token"]
self._slab_bot_token = credentials["slab_bot_token"]
return None
@property
def slab_bot_token(self) -> str:
if self._slab_bot_token is None:
raise ConnectorMissingCredentialError("Slab")
return self._slab_bot_token
def _iterate_posts(
self, time_filter: Callable[[datetime], bool] | None = None
) -> GenerateDocumentsOutput:
@@ -227,3 +237,21 @@ class SlabConnector(LoadConnector, PollConnector):
yield from self._iterate_posts(
time_filter=lambda t: start_time <= t <= end_time
)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
slim_doc_batch: list[SlimDocument] = []
for post_id in get_all_post_ids(self.slab_bot_token):
slim_doc_batch.append(
SlimDocument(
id=post_id,
)
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
if slim_doc_batch:
yield slim_doc_batch

View File

@@ -102,13 +102,21 @@ def _get_tickets(
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
return None
try:
author_data = client.make_request(f"users/{author_id}", {})
user = author_data.get("user")
return (
BasicExpertInfo(display_name=user.get("name"), email=user.get("email"))
if user and user.get("name") and user.get("email")
else None
)
except requests.exceptions.HTTPError:
# Handle any API errors gracefully
return None
def _article_to_document(

View File

@@ -8,13 +8,13 @@ from pydantic import field_validator
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import SearchType
from danswer.db.models import Persona
from danswer.db.models import SearchSettings
from danswer.indexing.models import BaseChunk
from danswer.indexing.models import IndexingSetting
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.enums import RerankerProvider

View File

@@ -5,33 +5,33 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prune_and_merge import _merge_sections
from danswer.chat.prune_and_merge import ChunkRange
from danswer.chat.prune_and_merge import merge_chunk_intervals
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import RetrievalMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.context.search.models import SearchRequest
from danswer.context.search.postprocessing.postprocessing import cleanup_chunks
from danswer.context.search.postprocessing.postprocessing import search_postprocessing
from danswer.context.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.context.search.retrieval.search_runner import retrieve_chunks
from danswer.context.search.utils import inference_section_from_chunks
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.search.utils import relevant_sections_to_indices
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -9,19 +9,19 @@ from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.models import ChunkMetric
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceChunkUncleaned
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import MAX_METRICS_CONTENT
from danswer.context.search.models import RerankMetricsContainer
from danswer.context.search.models import SearchQuery
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import RerankingModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall

View File

@@ -1,8 +1,8 @@
from sqlalchemy.orm import Session
from danswer.access.access import get_acl_for_user
from danswer.context.search.models import IndexFilters
from danswer.db.models import User
from danswer.search.models import IndexFilters
def build_access_filters_for_user(user: User | None, session: Session) -> list[str]:

View File

@@ -9,21 +9,25 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import BaseFilters
from danswer.context.search.models import IndexFilters
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import SearchQuery
from danswer.context.search.models import SearchRequest
from danswer.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from danswer.context.search.retrieval.search_runner import (
remove_stop_words_and_punctuation,
)
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import QueryAnalysisModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import BaseFilters
from danswer.search.models import IndexFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.utils.logger import setup_logger

Some files were not shown because too many files have changed in this diff Show More