Compare commits

...

124 Commits

Author SHA1 Message Date
joachim-danswer
0d848fa9dd more experimentation changes 2024-12-13 12:57:58 -08:00
joachim-danswer
1dbb6f3d69 experimentation 2024-12-09 11:14:18 -08:00
joachim-danswer
91cf9a5472 Fit Score & more suitable rewrite 2024-12-08 09:20:03 -08:00
joachim-danswer
ebb0e56a30 initial fixes for core_qa_graph 2024-12-07 22:07:48 -08:00
hagen-danswer
091cb136c4 got core qa graph working 2024-12-07 12:25:54 -08:00
hagen-danswer
56052c5b4b imports 2024-12-07 06:09:57 -08:00
hagen-danswer
617726207b all 3 graphs r done 2024-12-07 06:06:22 -08:00
hagen-danswer
1be58e74b3 Finished primary graph 2024-12-06 11:01:03 -08:00
hagen-danswer
a693c991d7 Merge remote-tracking branch 'origin/agent-search-a' into initial-implementation 2024-12-04 15:42:58 -08:00
hagen-danswer
ef9942b751 Related permission docs to cc_pair to prevent orphan docs (#3336)
* Related permission docs to cc_pair to prevent orphan docs

* added script

* group sync deduping

* logging
2024-12-04 21:00:54 +00:00
pablodanswer
993acec5e9 Update memoization + silence unnecessary errors (#3337)
* update memoization + silence unnecessary errors

* proper org
2024-12-04 20:08:15 +00:00
Weves
b01a1b509a Add basic loadtest script 2024-12-04 10:53:48 -08:00
pablodanswer
4f994124ef remove now unnecessary user loading indicatort log (#3333) 2024-12-04 00:09:22 +00:00
rkuo-danswer
14863bd457 try single threaded playwright testing (#3322) 2024-12-03 23:21:46 +00:00
Yuhong Sun
aa1c4c635a Combining Search and Chat Backend (#3273)
* k

* k

* fix slack issues

* rebase

* k
2024-12-03 22:37:14 +00:00
rkuo-danswer
13f6e8a6b4 disable thread local locking in callbacks (#3319) 2024-12-03 22:32:56 +00:00
pablodanswer
66f47d294c Shared filter utility for clarity (#3270)
* shared filter util

* clearer comment
2024-12-03 19:30:42 +00:00
pablodanswer
0a685bda7d add comments for clarity (#3249) 2024-12-03 19:27:28 +00:00
pablodanswer
23dc8b5dad Search flow improvements (#3314)
* untoggle if no docs

* update

* nits

* nit

* typing

* nit
2024-12-03 18:56:27 +00:00
pablodanswer
cd5f2293ad Temperature (#3310)
* fix temperatures for default llm

* ensure anthropic models don't overflow

* minor cleanup

* k

* k

* k

* fix typing
2024-12-03 17:22:22 +00:00
rkuo-danswer
6c2269e565 refactor celery task names to constants (#3296) 2024-12-03 16:02:17 +00:00
Weves
46315cddf1 Adjust default confulence timezone 2024-12-02 22:25:29 -08:00
rkuo-danswer
5f28a1b0e4 Bugfix/confluence time zone (#3265)
* RedisLock typing

* checkpoint

* put in debug logging

* improve comments

* mypy fixes
2024-12-02 22:23:23 -08:00
rkuo-danswer
9e9b7ed61d Bugfix/connector aborted logging (#3309)
* improve error logging on task failure.

* add db exception hardening to the indexing watchdog

* log on db exception
2024-12-03 02:34:40 +00:00
pablodanswer
3fb2bfefec Update Chromatic Tests (#3300)
* remove / update search tests

* minor update
2024-12-02 23:08:54 +00:00
pablodanswer
7c618c9d17 Unified UI (#3308)
* fix typing

* add filters display
2024-12-02 15:12:13 -08:00
pablodanswer
03e2789392 Text embedding (PDF, TXT) (#3113)
* add text embedding

* post rebase cleanup

* fully functional post rebase

* rm logs

* rm '

* quick clean up

* k
2024-12-02 22:43:53 +00:00
Chris Weaver
2783fa08a3 Update openai version in model server (#3306) 2024-12-02 21:39:10 +00:00
pablodanswer
edeaee93a2 hard refresh on auth (#3305)
* hard refresh on auth

* k

* k

* comment for clarity
2024-12-02 20:12:12 +00:00
hagen-danswer
5385bae100 Add slim connector description (#3303)
* added docs example and test

* updated docs

* needed to make the tests run

* updated docs
2024-12-02 19:52:13 +00:00
pablodanswer
813445ab59 Minor JWT Feature (#3290)
* first pass

* k

* k

* finalize

* minor cleanup

* k

* address

* minor typing updates
2024-12-02 19:14:31 +00:00
pablodanswer
af814823c8 display name + model truncation (#3304) 2024-12-02 18:54:08 +00:00
pablodanswer
607f61eaeb Reusable function for search settings spread operation (#3301)
* combine for clarity once and for all

* remove logs

* k
2024-12-02 17:23:01 +00:00
hagen-danswer
4b28686721 Added Initial Implementation of the Agent Search Graph 2024-12-02 07:16:08 -08:00
pablodanswer
de66f7adb2 Updated chat flow (#3244)
* proper no assistant typing + no assistant modal

* updated chat flow

* k

* updates

* update

* k

* clean up

* fix mystery reorg

* cleanup

* update scroll

* default

* update logs

* push fade

* scroll nit

* finalize tags

* updates

* k

* various updates

* viewport height update

* source types update

* clean up unused components

* minor cleanup

* cleanup complete

* finalize changes

* badge up

* update filters

* small nit

* k

* k

* address comments

* quick unification of icons

* minor date range clarity

* minor nit

* k

* update sidebar line

* update for all screen sizes

* k

* k

* k

* k

* rm shs

* fix memoization

* fix memoization

* slack chat

* k

* k

* build org
2024-12-02 01:58:28 +00:00
Yuhong Sun
3432d932d1 Citation code comments 2024-12-01 14:10:11 -08:00
Yuhong Sun
9bd0cb9eb5 Fix Citation Minor Bugs (#3294) 2024-12-01 13:55:24 -08:00
Chris Weaver
f12eb4a5cf Fix assistant prompt zero-ing (#3293) 2024-11-30 04:45:40 +00:00
Chris Weaver
16863de0aa Improve model token limit detection (#3292)
* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Ugprade OpenAI as well

* Fix mypy
2024-11-30 04:42:56 +00:00
Weves
63d1eefee5 Add read_only=True for xlsx parsing 2024-11-28 16:02:02 -08:00
pablodanswer
e338677896 order seeding 2024-11-28 15:41:10 -08:00
hagen-danswer
7be80c4af9 increased the pagination limit for confluence spaces (#3288) 2024-11-28 19:04:38 +00:00
rkuo-danswer
7f1e4a02bf Feature/kill indexing (#3213)
* checkpoint

* add celery termination of the task

* rename to RedisConnectorPermissionSyncPayload, add RedisLock to more places, add get_active_search_settings

* rename payload

* pretty sure these weren't named correctly

* testing in progress

* cleanup

* remove space

* merge fix

* three dots animation on Pausing

* improve messaging when connector is stopped or killed and animate buttons

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-28 05:32:45 +00:00
rkuo-danswer
5be7d27285 use indexing flag in db for manually triggering indexing (#3264)
* use indexing flag in db for manually trigger indexing

* add comment.

* only try to release the lock if we actually succeeded with the lock

* ensure we don't trigger manual indexing on anything but the primary search settings

* comment usage of primary search settings

* run check for indexing immediately after indexing triggers are set

* reorder fix
2024-11-28 01:34:34 +00:00
Weves
fd84b7a768 Remove duplicate API key router 2024-11-27 16:30:59 -08:00
Subash-Mohan
36941ae663 fix: Cannot configure API keys #3191 2024-11-27 16:25:00 -08:00
Matthew Holland
212353ed4a Fixed default feedback options 2024-11-27 16:23:52 -08:00
Richard Kuo (Danswer)
eb8708f770 the word "error" might be throwing off sentry 2024-11-27 14:31:21 -08:00
Chris Weaver
ac448956e9 Add handling for rate limiting (#3280) 2024-11-27 14:22:15 -08:00
pablodanswer
634a0b9398 no stack by default (#3278) 2024-11-27 20:58:21 +00:00
hagen-danswer
09d3e47c03 Perm sync behavior change (#3262)
* Change external permissions behavior

* fixed behavior

* added error handling

* LLM the goat

* comment

* simplify

* fixed

* done

* limits increased

* added a ton of logging

* uhhhh
2024-11-27 20:04:15 +00:00
pablodanswer
9c0cc94f15 refresh router -> refresh assistants (#3271) 2024-11-27 19:11:58 +00:00
hagen-danswer
07dfde2209 add continue in danswer button to slack bot responses (#3239)
* all done except routing

* fixed initial changes

* added backend endpoint for duplicating a chat session from Slack

* got chat duplication routing done

* got login routing working

* improved answer handling

* finished all checks

* finished all!

* made sure it works with google oauth

* dont remove that lol

* fixed weird thing

* bad comments
2024-11-27 18:25:38 +00:00
pablodanswer
28e2b78b2e Fix search dropdown (#3269)
* validate dropdown

* validate

* update organization

* move to utils
2024-11-27 16:10:07 +00:00
Emerson Gomes
0553062ac6 Adds icons for Google Gemini models and custom model icons for L… (#3218)
* Add description for Google Gemini models and custom model icons for LiteLLM (OpenAI) proxied models

* Adds Vertex AI aliases for Claude

---------

Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2024-11-26 10:13:21 -08:00
hagen-danswer
284e375ba3 Merge pull request #3257 from danswer-ai/minor-perm-sync
Improved logging for confluence doc sync and robust user creation
2024-11-26 09:59:38 -08:00
hagen-danswer
1f2f7d0ac2 Improved logging for confluence doc sync and robust user creation 2024-11-26 08:51:15 -08:00
pablodanswer
2ecc28b57d remove unused stripe promise (#3248) 2024-11-26 01:50:39 +00:00
rkuo-danswer
77cf9b3539 improve messaging and UI around cleanup of leftover index attempts (#3247)
* improve messaging and UI around cleanup of leftover index attempts

* add tag on init
2024-11-25 22:27:14 +00:00
Weves
076ce2ebd0 Saml fix 2024-11-25 09:12:43 -08:00
pablodanswer
b625ee32a7 File handling cleanup (#3240)
* fix google sites connector

* minior cleanup

* rm comments
2024-11-25 04:06:47 +00:00
Richard Kuo (Danswer)
c32b93fcc3 increase indexing worker concurrency to 3 2024-11-24 18:11:58 -08:00
pablodanswer
1c8476072e Assistant cleanup (#3236)
* minor cleanup

* ensure users don't modify built-in attributes of assistants

* update sidebar

* k

* update update flow + assistant creation
2024-11-25 00:13:34 +00:00
Chris Weaver
7573416ca1 Fix API keys for MIT users (#3237) 2024-11-24 16:55:19 -08:00
Yuhong Sun
86d8666481 Add Test Case 2024-11-24 15:42:14 -08:00
Yuhong Sun
8abcde91d4 Fix Test (#3242) 2024-11-24 14:31:28 -08:00
Yuhong Sun
3466451d51 Fix Prompt for Non Function Calling LLMs (#3241) 2024-11-24 14:16:57 -08:00
Yuhong Sun
413891f143 Token Level Log (#3238) 2024-11-23 18:41:50 -08:00
Yuhong Sun
7a0a4d4b79 Remove Deprecated Endpoints (#3235) 2024-11-23 14:44:23 -08:00
Yuhong Sun
a3439605a5 Remove Dead Code (#3234) 2024-11-23 14:31:59 -08:00
pablodanswer
694e79f5e1 minor enforcement of CSV length for internal processing (#3109) 2024-11-23 21:05:30 +00:00
pablodanswer
5dfafc8612 minor calendar cleanup (#3219) 2024-11-23 21:01:05 +00:00
Yuhong Sun
62a4aa10db Refactor Search (#3233) 2024-11-23 13:42:54 -08:00
Yuhong Sun
a357cdc4c9 Remove Dead Code (#3232) 2024-11-23 13:21:27 -08:00
Yuhong Sun
84615abfdd Seeding (#3231) 2024-11-23 13:12:42 -08:00
pablodanswer
8ae6b1960b Bugfix/usage report (#3075)
* fix pagination

* update side

* fixed query history

* minor update

* minor update

* typing
2024-11-23 20:11:39 +00:00
James Jordan
d9b87bbbc2 Fixed 400 error when author of ticket is no longer an active user in a Zendesk account. (#3168) 2024-11-23 12:15:38 -08:00
Sanju Lokuhitige
a0065b01af Update CONTRIBUTING.md (#3112)
fix Formatting and Linting hyperlink
2024-11-23 12:13:23 -08:00
pablodanswer
c5306148a3 Ensure daterange not consistently re rendered (#3229)
* ensure daterange not consistently re rendered

* minor clean up
2024-11-23 19:35:00 +00:00
hagen-danswer
1e17934de4 Merge pull request #3214 from danswer-ai/fix-slack-ui
cleaned up new slack bot creation
2024-11-23 10:53:47 -08:00
pablodanswer
93add96ccc Various Nits (#3228) 2024-11-23 10:53:24 -08:00
rkuo-danswer
3a466a4b08 add minimal retries to confluence probe (#3222)
* add minimal retries to confluence probe

* name variable correctly
2024-11-23 17:11:15 +00:00
hagen-danswer
85cbd9caed Increased slim doc batch size for confluence connector (#3221) 2024-11-23 00:42:15 +00:00
pablodanswer
9dc23bf3e7 revert to previous doc select logic (#3217)
* revert to previous doc select logic

* k
2024-11-22 23:26:53 +00:00
hagen-danswer
e32809f7ca moved it outside 2024-11-22 14:59:58 -08:00
hagen-danswer
3e58f9f8ab fixed ugly stuff 2024-11-22 14:39:55 -08:00
pablodanswer
2381c8d498 Refresh all assistants on assistant refresh (#3216)
* k

* k
2024-11-22 22:38:23 +00:00
hagen-danswer
c6dadb24dc cleaned up new slack bot creation 2024-11-22 11:53:51 -08:00
hagen-danswer
5dc07d4178 Each section is now cleaned before being chunked (#3210)
* Each section is now cleaned before being chunked

* k

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-11-22 19:06:19 +00:00
Chris Weaver
129c8f8faf Add start/end date ability for query history as CSV endpoint (#3211) 2024-11-22 18:29:13 +00:00
pablodanswer
67bfcabbc5 llm provider causing re render in effect (#3205)
* llm provider causing re render in effect

* clean

* unused

* k
2024-11-22 16:53:24 +00:00
rkuo-danswer
9819aa977a implement double check pattern for error conditions (#3201)
* Move unfenced check to check_for_indexing. implement a double check pattern for all indexing error checks

* improved commenting

* exclusions
2024-11-22 04:21:02 +00:00
hagen-danswer
8d5b8a4028 Merge pull request #3202 from danswer-ai/toggled_chat_default
Update default sidebar toggle
2024-11-21 19:53:05 -08:00
pablodanswer
682319d2e9 Bugfix/curator interface (#3198)
* mystery solved

* update config

* update

* update

* update user role

* remove values
2024-11-22 02:33:09 +00:00
hagen-danswer
fe1400aa36 replace deprecated confluence group api endpoint (#3197)
* replace deprecated confluence group api endpoint

* reworked it

* properly escaped the user query

* less passing around is_cloud

* done
2024-11-22 01:51:29 +00:00
pablodanswer
e3573b2bc1 add comment 2024-11-21 17:11:11 -08:00
pablodanswer
35b5c44cc7 update default sidebar toggle 2024-11-21 17:09:56 -08:00
rkuo-danswer
5eddc89b5a merge indexing and heartbeat callbacks (and associated lock reacquisi… (#3178)
* merge indexing and heartbeat callbacks (and associated lock reacquisition). no db updates

* review fixes
2024-11-21 23:48:58 +00:00
hagen-danswer
9a492ceb6d admins cant be set as curator on backend (#3194)
* set-curator

* updated error
2024-11-21 23:33:29 +00:00
rkuo-danswer
3c54ae9de9 Bugfix/redis wait (#3169)
* rename to payload

* log redis info replication on primary worker startup

* fix mypy

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-21 23:11:00 +00:00
pablodanswer
13f08f3ebb Horizontal scrollbar (#3195)
* clean horizontal scrollbar

* account for additional edge case
2024-11-21 22:08:21 +00:00
pablodanswer
bd9f15854f provider fix (#3187)
* clean horizontal scrollbar

* provider fix

* ensure proper migration

* k

* update migration

* Revert "clean horizontal scrollbar"

This reverts commit fa592a1b7a.
2024-11-21 22:08:16 +00:00
pablodanswer
366aa2a8ea quick fix (#3200) 2024-11-21 14:07:55 -08:00
pablodanswer
deee237c7e Sheet update (#3189)
* quick pass

* k

* update sheet

* add multiple sheet stuff

* k

* finalized

* update configuration
2024-11-21 18:07:00 +00:00
hagen-danswer
100b4a0d16 Added Slim connector for Jira (#3181)
* Added Slim connector for Jira

* fixed testing

* more cleanup of Jira connector

* cleanup
2024-11-21 17:00:20 +00:00
rkuo-danswer
70207b4b39 improve web testing (#3162)
* shared admin level test dependency

* change to on - push (recommended by chromatic)

* change playwright reporter to list, name test jobs

* use test tags ... much cleaner

* test vs prod

* try copying templates

* run with localhost?

* revert to dev

* new tests and a bit of refactoring

* add additional checks so that page snapshots reflect loaded state

* more admin tests

* User Management tests

* remaining admin pages

* test search and chat

* await fix and exclude UI that changes with dates.
2024-11-21 04:01:15 +00:00
pablodanswer
50826b6bef Formatting Niceties (#3183)
* search bar formatting

* update styling
2024-11-21 03:11:26 +00:00
pablodanswer
3f648cbc31 Folder clarity (#3180)
* folder clarity

* k
2024-11-21 03:11:17 +00:00
pablodanswer
c875a4774f valid props (#3186) 2024-11-21 01:13:54 +00:00
hagen-danswer
049091eb01 decreased confluence retry times and added more logging (#3184)
* decreased confluence retry times and added more logging

* added check on connector startup

* no retries!

* fr no retries
2024-11-21 00:00:14 +00:00
pablodanswer
3dac24542b silence small error (#3182) 2024-11-20 22:46:38 +00:00
pablodanswer
194dcb593d update slack redirect + token missing check (#3179)
* update slack redirect + token missing check

* reset time
2024-11-20 21:42:54 +00:00
pablodanswer
bf291d0c0a Fix missing json (#3177)
* initial steps

* k

* remove logs

* k

* k
2024-11-20 21:24:43 +00:00
rkuo-danswer
8309f4a802 test overlapping connectors (but using a source that is way too big a… (#3152)
* test overlapping connectors (but using a source that is way too big and slow, fix that next)

* pass thru secrets

* rename

* rename again

* now we are fixing it

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-20 21:12:01 +00:00
pablodanswer
0ff2565125 ensure margin properly applied (#3176)
* ensure margin properly applied

* formatting
2024-11-20 20:04:45 +00:00
hagen-danswer
e89dcd7f84 added logging and bugfixing to conf (#3167)
* standardized escaping of CQL strings

* think i found it

* fix

* should be fixed

* added handling for special linking behavior in confluence

* Update onyx_confluence.py

* Update onyx_confluence.py

---------

Co-authored-by: rkuo-danswer <rkuo@danswer.ai>
2024-11-20 18:40:21 +00:00
pablodanswer
645e7e828e Add Google Tag Manager for Web Cloud Build (#3173)
* add gtm for cloud build

* update github workflow
2024-11-20 17:38:33 +00:00
pablodanswer
2a54f14195 ensure everythigng has a default max height in selectorformfield (#3174) 2024-11-20 17:26:22 +00:00
hagen-danswer
9209fc804b multiple slackbot support (#3077)
* multiple slackbot support

* app_id + tenant_id key

* removed kv store stuff

* fixed up mypy and migration

* got frontend working for multiple slack bots

* some frontend stuff

* alembic fix

* might be valid

* refactor dun

* alembic stuff

* temp frontend stuff

* alembic stuff

* maybe fixed alembic

* maybe dis fix

* im getting mad

* api names changed

* tested

* almost done

* done

* routing nonsense

* done!

* done!!

* fr done

* doneski

* fix alembic migration

* getting mad again

* PLEASE IM BEGGING YOU
2024-11-20 01:49:43 +00:00
rkuo-danswer
b712877701 Merge pull request #3165 from danswer-ai/bugfix/pruning_logs
improve logging around pruning
2024-11-19 13:19:31 -08:00
Richard Kuo (Danswer)
e6df32dcc3 improve logging around pruning 2024-11-19 12:41:21 -08:00
Chris Weaver
eb81258a23 Update README.md
Fix slack link
2024-11-19 08:02:35 -08:00
hagen-danswer
487ef4acc0 Merge pull request #3160 from danswer-ai/add-to-admin-chat-sessions-api
Extend query history API
2024-11-19 07:28:12 -08:00
Weves
ce3124f9e4 Extend query history API 2024-11-18 17:50:21 -08:00
491 changed files with 15902 additions and 8622 deletions

View File

@@ -65,6 +65,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -3,12 +3,7 @@ concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- 'release/**'
on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -16,6 +11,8 @@ env:
jobs:
playwright-tests:
name: Playwright Tests
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:
@@ -108,7 +105,7 @@ jobs:
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Start Docker containers
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
@@ -193,7 +190,8 @@ jobs:
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Run Chromatic
name: Chromatic Tests
needs: playwright-tests
runs-on: [runs-on,runner=8cpu-linux-x64,ram=16,"run-id=${{ github.run_id }}"]
steps:

View File

@@ -13,7 +13,10 @@ on:
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
@@ -195,6 +198,9 @@ jobs:
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
danswer/danswer-integration:test \
/app/tests/integration/tests \

View File

@@ -24,6 +24,8 @@ env:
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
jobs:
connectors-check:

View File

@@ -32,7 +32,7 @@ To contribute to this project, please follow the
When opening a pull request, mention related issues and feel free to tag relevant maintainers.
Before creating a pull request please make sure that the new changes conform to the formatting and linting requirements.
See the [Formatting and Linting](#-formatting-and-linting) section for how to run these checks locally.
See the [Formatting and Linting](#formatting-and-linting) section for how to run these checks locally.
### Getting Help 🙋

View File

@@ -12,7 +12,7 @@
<a href="https://docs.danswer.dev/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2lcmqw703-071hBuZBfNEOGUsLa5PXvQ" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -135,7 +135,7 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ✨Contributors
<a href="https://github.com/aryn-ai/sycamore/graphs/contributors">
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
</a>

View File

@@ -73,6 +73,7 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

View File

@@ -1,5 +1,5 @@
from sqlalchemy.engine.base import Connection
from typing import Any
from typing import Literal
import asyncio
from logging.config import fileConfig
import logging
@@ -8,6 +8,7 @@ from alembic import context
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
from sqlalchemy.sql.schema import SchemaItem
from shared_configs.configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
@@ -35,7 +36,18 @@ logger = logging.getLogger(__name__)
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
object: SchemaItem,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:
"""
Determines whether a database object should be included in migrations.

View File

@@ -0,0 +1,59 @@
"""display custom llm models
Revision ID: 177de57c21c9
Revises: 4ee1287bd26a
Create Date: 2024-11-21 11:49:04.488677
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy import and_
revision = "177de57c21c9"
down_revision = "4ee1287bd26a"
branch_labels = None
depends_on = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
llm_provider = sa.table(
"llm_provider",
sa.column("id", sa.Integer),
sa.column("provider", sa.String),
sa.column("model_names", postgresql.ARRAY(sa.String)),
sa.column("display_model_names", postgresql.ARRAY(sa.String)),
)
excluded_providers = ["openai", "bedrock", "anthropic", "azure"]
providers_to_update = sa.select(
llm_provider.c.id,
llm_provider.c.model_names,
llm_provider.c.display_model_names,
).where(
and_(
~llm_provider.c.provider.in_(excluded_providers),
llm_provider.c.model_names.isnot(None),
)
)
results = conn.execute(providers_to_update).fetchall()
for provider_id, model_names, display_model_names in results:
if display_model_names is None:
display_model_names = []
combined_model_names = list(set(display_model_names + model_names))
update_stmt = (
llm_provider.update()
.where(llm_provider.c.id == provider_id)
.values(display_model_names=combined_model_names)
)
conn.execute(update_stmt)
def downgrade() -> None:
pass

View File

@@ -0,0 +1,280 @@
"""add_multiple_slack_bot_support
Revision ID: 4ee1287bd26a
Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644
"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
from sqlalchemy.orm import Session
from danswer.key_value_store.factory import get_kv_store
from danswer.db.models import SlackBot
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "4ee1287bd26a"
down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None
# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)
def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"),
sa.Column("bot_token", sa.LargeBinary(), nullable=False),
sa.Column("app_token", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("bot_token"),
sa.UniqueConstraint("app_token"),
)
# # Create new slack_channel_config table
op.create_table(
"slack_channel_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("slack_bot_id", sa.Integer(), nullable=True),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column(
"enable_auto_filters", sa.Boolean(), nullable=False, server_default="false"
),
sa.ForeignKeyConstraint(
["slack_bot_id"],
["slack_bot.id"],
),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}
bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")
if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")
session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
enabled=True,
bot_token=bot_token,
app_token=app_token,
)
session.add(new_slack_bot)
session.commit()
first_row_id = new_slack_bot.id
# Create a default bot if none exists
# This is in case there are no slack tokens but there are channels configured
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)
# Get the bot ID to use (either from existing migration or newly created)
bot_id_query = sa.text(
"""
SELECT COALESCE(
:first_row_id,
(SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1)
) as bot_id;
"""
)
result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id})
bot_id = result.scalar()
# CTE (Common Table Expression) that transforms the old slack_bot_config table data
# This splits up the channel_names into their own rows
channel_names_cte = """
WITH channel_names AS (
SELECT
sbc.id as config_id,
sbc.persona_id,
sbc.response_type,
sbc.enable_auto_filters,
jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name,
sbc.channel_config->>'respond_tag_only' as respond_tag_only,
sbc.channel_config->>'respond_to_bots' as respond_to_bots,
sbc.channel_config->'respond_member_group_list' as respond_member_group_list,
sbc.channel_config->'answer_filters' as answer_filters,
sbc.channel_config->'follow_up_tags' as follow_up_tags
FROM slack_bot_config sbc
)
"""
# Insert the channel names into the new slack_channel_config table
insert_statement = """
INSERT INTO slack_channel_config (
slack_bot_id,
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT
:bot_id,
channel_name.persona_id,
jsonb_build_object(
'channel_name', channel_name.channel_name,
'respond_tag_only',
COALESCE((channel_name.respond_tag_only)::boolean, false),
'respond_to_bots',
COALESCE((channel_name.respond_to_bots)::boolean, false),
'respond_member_group_list',
COALESCE(channel_name.respond_member_group_list, '[]'::jsonb),
'answer_filters',
COALESCE(channel_name.answer_filters, '[]'::jsonb),
'follow_up_tags',
COALESCE(channel_name.follow_up_tags, '[]'::jsonb)
),
channel_name.response_type,
channel_name.enable_auto_filters
FROM channel_names channel_name;
"""
op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id))
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
"slack_channel_config__standard_answer_category",
)
# Rename the column
op.alter_column(
"slack_channel_config__standard_answer_category",
"slack_bot_config_id",
new_column_name="slack_channel_config_id",
)
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")
logger.info(f"{revision}: Migration complete.")
def downgrade() -> None:
# Recreate the old slack_bot_config table
op.create_table(
"slack_bot_config",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("persona_id", sa.Integer(), nullable=True),
sa.Column("channel_config", postgresql.JSONB(), nullable=False),
sa.Column("response_type", sa.String(), nullable=False),
sa.Column("enable_auto_filters", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate data back to the old format
# Group by persona_id to combine channel names back into arrays
op.execute(
sa.text(
"""
INSERT INTO slack_bot_config (
persona_id,
channel_config,
response_type,
enable_auto_filters
)
SELECT DISTINCT ON (persona_id)
persona_id,
jsonb_build_object(
'channel_names', (
SELECT jsonb_agg(c.channel_config->>'channel_name')
FROM slack_channel_config c
WHERE c.persona_id = scc.persona_id
),
'respond_tag_only', (channel_config->>'respond_tag_only')::boolean,
'respond_to_bots', (channel_config->>'respond_to_bots')::boolean,
'respond_member_group_list', channel_config->'respond_member_group_list',
'answer_filters', channel_config->'answer_filters',
'follow_up_tags', channel_config->'follow_up_tags'
),
response_type,
enable_auto_filters
FROM slack_channel_config scc
WHERE persona_id IS NOT NULL;
"""
)
)
# Rename the table back
op.rename_table(
"slack_channel_config__standard_answer_category",
"slack_bot_config__standard_answer_category",
)
# Rename the column back
op.alter_column(
"slack_bot_config__standard_answer_category",
"slack_channel_config_id",
new_column_name="slack_bot_config_id",
)
# Try to save the first bot's tokens back to KV store
try:
first_bot = (
op.get_bind()
.execute(
sa.text(
"SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1"
)
)
.first()
)
if first_bot and first_bot.bot_token and first_bot.app_token:
tokens = {
"bot_token": first_bot.bot_token,
"app_token": first_bot.app_token,
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
op.drop_table("slack_bot")

View File

@@ -0,0 +1,45 @@
"""remove default bot
Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
sa.text(
"""
DELETE FROM slack_bot
WHERE name = 'Default Bot'
AND bot_token = ''
AND app_token = ''
AND NOT EXISTS (
SELECT 1 FROM slack_channel_config
WHERE slack_channel_config.slack_bot_id = slack_bot.id
)
"""
)
)
def downgrade() -> None:
op.execute(
sa.text(
"""
INSERT INTO slack_bot (name, enabled, bot_token, app_token)
SELECT 'Default Bot', true, '', ''
WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
RETURNING id;
"""
)
)

View File

@@ -9,8 +9,8 @@ from alembic import op
import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -0,0 +1,35 @@
"""add web ui option to slack config
Revision ID: 93560ba1b118
Revises: 6d562f86c78b
Create Date: 2024-11-24 06:36:17.490612
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "93560ba1b118"
down_revision = "6d562f86c78b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add show_continue_in_web_ui with default False to all existing channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb
WHERE NOT channel_config ? 'show_continue_in_web_ui'
"""
)
def downgrade() -> None:
# Remove show_continue_in_web_ui from all channel_configs
op.execute(
"""
UPDATE slack_channel_config
SET channel_config = channel_config - 'show_continue_in_web_ui'
"""
)

View File

@@ -7,6 +7,7 @@ Create Date: 2024-10-26 13:06:06.937969
"""
from alembic import op
from sqlalchemy.orm import Session
from sqlalchemy import text
# Import your models and constants
from danswer.db.models import (
@@ -15,7 +16,6 @@ from danswer.db.models import (
Credential,
IndexAttempt,
)
from danswer.configs.constants import DocumentSource
# revision identifiers, used by Alembic.
@@ -30,13 +30,11 @@ def upgrade() -> None:
bind = op.get_bind()
session = Session(bind=bind)
connectors_to_delete = (
session.query(Connector)
.filter(Connector.source == DocumentSource.REQUESTTRACKER)
.all()
# Get connectors using raw SQL
result = bind.execute(
text("SELECT id FROM connector WHERE source = 'requesttracker'")
)
connector_ids = [connector.id for connector in connectors_to_delete]
connector_ids = [row[0] for row in result]
if connector_ids:
cc_pairs_to_delete = (

View File

@@ -1,7 +1,7 @@
"""add creator to cc pair
Revision ID: 9cf5c00f72fe
Revises: c0fd6e4da83a
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902
"""

View File

@@ -0,0 +1,36 @@
"""Combine Search and Chat
Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("chat_session", "description", nullable=True)
op.drop_column("chat_session", "one_shot")
op.drop_column("slack_channel_config", "response_type")
def downgrade() -> None:
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
op.alter_column("chat_session", "description", nullable=False)
op.add_column(
"chat_session",
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
)
op.add_column(
"slack_channel_config",
sa.Column(
"response_type", sa.String(), nullable=False, server_default="citations"
),
)

View File

@@ -0,0 +1,27 @@
"""add auto scroll to user model
Revision ID: a8c2065484e6
Revises: abe7378b8217
Create Date: 2024-11-22 17:34:09.690295
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "a8c2065484e6"
down_revision = "abe7378b8217"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
)
def downgrade() -> None:
op.drop_column("user", "auto_scroll")

View File

@@ -0,0 +1,30 @@
"""add indexing trigger to cc_pair
Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"indexing_trigger",
sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("connector_credential_pair", "indexing_trigger")

View File

@@ -1,5 +1,6 @@
import asyncio
from logging.config import fileConfig
from typing import Literal
from sqlalchemy import pool
from sqlalchemy.engine import Connection
@@ -37,8 +38,15 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: SchemaItem,
name: str,
type_: str,
name: str | None,
type_: Literal[
"schema",
"table",
"column",
"index",
"unique_constraint",
"foreign_key_constraint",
],
reflected: bool,
compare_to: SchemaItem | None,
) -> bool:

View File

@@ -18,6 +18,11 @@ class ExternalAccess:
@dataclass(frozen=True)
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
"""
external_access: ExternalAccess
# The document ID
doc_id: str

View File

@@ -0,0 +1,42 @@
from collections.abc import Hashable
from typing import Union
from langgraph.types import Send
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState
def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
# Routes each de-douped retrieved doc to the verifier step - in parallel
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_verifier",
VerifierState(
document=doc,
#question=state["original_question"],
question=state["sub_question_str"],
graph_start_time=state["graph_start_time"],
),
)
for doc in state["sub_question_deduped_retrieval_docs"]
]
def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
rewritten_queries = state["sub_question_search_queries"].rewritten_queries + [state["sub_question_str"]]
return [
Send(
"sub_custom_retrieve",
RetrieverState(
rewritten_query=query,
graph_start_time=state["graph_start_time"],
),
)
for query in rewritten_queries
]

View File

@@ -0,0 +1,132 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
sub_combine_retrieved_docs,
)
from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
sub_custom_retrieve,
)
from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.core_qa_graph.nodes.final_format import (
sub_final_format,
)
from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.core_qa_graph.states import CoreQAInputState
def build_core_qa_graph() -> StateGraph:
sub_answers_initial = StateGraph(
state_schema=BaseQAState,
output=BaseQAOutputState,
)
### Add nodes ###
sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
sub_answers_initial.add_node(
node="sub_custom_retrieve",
action=sub_custom_retrieve,
)
sub_answers_initial.add_node(
node="sub_combine_retrieved_docs",
action=sub_combine_retrieved_docs,
)
sub_answers_initial.add_node(
node="sub_verifier",
action=sub_verifier,
)
sub_answers_initial.add_node(
node="sub_generate",
action=sub_generate,
)
sub_answers_initial.add_node(
node="sub_qa_check",
action=sub_qa_check,
)
sub_answers_initial.add_node(
node="sub_final_format",
action=sub_final_format,
)
### Add edges ###
sub_answers_initial.add_edge(START, "sub_dummy")
sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")
sub_answers_initial.add_conditional_edges(
source="sub_rewrite",
path=sub_continue_to_retrieval,
)
sub_answers_initial.add_edge(
start_key="sub_custom_retrieve",
end_key="sub_combine_retrieved_docs",
)
sub_answers_initial.add_conditional_edges(
source="sub_combine_retrieved_docs",
path=sub_continue_to_verifier,
path_map=["sub_verifier"],
)
sub_answers_initial.add_edge(
start_key="sub_verifier",
end_key="sub_generate",
)
sub_answers_initial.add_edge(
start_key="sub_generate",
end_key="sub_qa_check",
)
sub_answers_initial.add_edge(
start_key="sub_qa_check",
end_key="sub_final_format",
)
sub_answers_initial.add_edge(
start_key="sub_final_format",
end_key=END,
)
# sub_answers_graph = sub_answers_initial.compile()
return sub_answers_initial
if __name__ == "__main__":
# q = "Whose music is kind of hard to easily enjoy?"
# q = "What is voice leading?"
# q = "What are the types of motions in music?"
# q = "What are key elements of music theory?"
# q = "How can I best understand music theory using voice leading?"
q = "What makes good music?"
# q = "types of motions in music"
# q = "What is the relationship between music and physics?"
# q = "Can you compare various grunge styles?"
# q = "Why is quantum gravity so hard?"
inputs = CoreQAInputState(
original_question=q,
sub_question_str=q,
)
sub_answers_graph = build_core_qa_graph()
compiled_sub_answers = sub_answers_graph.compile()
output = compiled_sub_answers.invoke(inputs)
print("\nOUTPUT:")
print(output.keys())
for key, value in output.items():
if key in [
"sub_question_answer",
"sub_question_str",
"sub_qas",
"initial_sub_qas",
"sub_question_answer",
]:
print(f"{key}: {value}")

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
dedupe_docs: list[InferenceSection] = []
for base_retrieval_doc in sub_question_base_retrieval_docs:
if not any(
base_retrieval_doc.center_chunk.chunk_id == doc.center_chunk.chunk_id
for doc in dedupe_docs
):
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"sub_question_deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="sub - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,66 @@
import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE SUB---")
node_start_time = datetime.datetime.now()
rewritten_query = state["rewritten_query"]
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
documents: list[InferenceSection] = []
llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
documents = SearchPipeline(
search_request=SearchRequest(
query=rewritten_query,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
)
reranked_docs = documents.reranked_sections
# initial metric to measure fit TODO: implement metric properly
top_1_score = reranked_docs[0].center_chunk.score
top_5_score = sum([doc.center_chunk.score for doc in reranked_docs[:5]]) / 5
top_10_score = sum([doc.center_chunk.score for doc in reranked_docs[:10]]) / 10
fit_score = 1/3 * (top_1_score + top_5_score + top_10_score)
chunk_ids = {'query': rewritten_query,
'chunk_ids': [doc.center_chunk.chunk_id for doc in reranked_docs]}
return {
"sub_question_base_retrieval_docs": reranked_docs,
"sub_chunk_ids": [chunk_ids],
"log_messages": generate_log_message(
message=f"sub - custom_retrieve, fit_score: {fit_score}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,24 @@
import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
"""
Dummy step
"""
print("---Sub Dummy---")
node_start_time = datetime.datetime.now()
return {
"graph_start_time": node_start_time,
"log_messages": generate_log_message(
message="sub - dummy",
node_start_time=node_start_time,
graph_start_time=node_start_time,
),
}

View File

@@ -0,0 +1,22 @@
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
def sub_final_format(state: BaseQAState) -> dict[str, Any]:
"""
Create the final output for the QA subgraph
"""
print("---BASE FINAL FORMAT---")
return {
"sub_qas": [
{
"sub_question": state["sub_question_str"],
"sub_answer": state["sub_question_answer"],
"sub_answer_check": state["sub_question_answer_check"],
}
],
"log_messages": state["log_messages"],
}

View File

@@ -0,0 +1,91 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_generate(state: BaseQAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE---")
# Create sub-query results
verified_chunks = [chunk.center_chunk.chunk_id for chunk in state["sub_question_verified_retrieval_docs"]]
result_dict = {}
chunk_id_dicts = state["sub_chunk_ids"]
expanded_chunks = []
original_chunks = []
for chunk_id_dict in chunk_id_dicts:
sub_question = chunk_id_dict['query']
verified_sq_chunks = [chunk_id for chunk_id in chunk_id_dict['chunk_ids'] if chunk_id in verified_chunks]
if sub_question != state["original_question"]:
expanded_chunks += verified_sq_chunks
else:
result_dict['ORIGINAL'] = len(verified_sq_chunks)
original_chunks += verified_sq_chunks
result_dict[sub_question[:30]] = len(verified_sq_chunks)
expansion_chunks = set(expanded_chunks)
num_expansion_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id in verified_chunks])
num_original_relevant_chunks = len(original_chunks)
num_missed_relevant_chunks = sum([1 for chunk_id in original_chunks if chunk_id not in expansion_chunks])
num_gained_relevant_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id not in original_chunks])
result_dict['expansion_chunks'] = num_expansion_chunks
print(result_dict)
node_start_time = datetime.now()
question = state["sub_question_str"]
docs = state["sub_question_verified_retrieval_docs"]
print(f"Number of verified retrieval docs: {len(docs)}")
# Only take the top 10 docs.
# TODO: Make this dynamic or use config param?
top_10_docs = docs[-10:]
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(top_10_docs))
)
]
# Grader
_, fast_llm = get_default_llms()
response = list(
fast_llm.stream(
prompt=msg,
# structured_response_format=None,
)
)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
return {
"sub_question_answer": answer_str,
"log_messages": generate_log_message(
message="base - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,51 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
"""
Check if the sub-question answer is satisfactory.
Args:
state: The current SubQAState containing the sub-question and its answer
Returns:
dict containing the check result and log message
"""
node_start_time = datetime.datetime.now()
msg = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(
question=state["sub_question_str"],
base_answer=state["sub_question_answer"],
)
)
]
_, fast_llm = get_default_llms()
response = list(
fast_llm.stream(
prompt=msg,
# structured_response_format=None,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
return {
"sub_question_answer_check": response_str,
"base_answer_messages": generate_log_message(
message="sub - qa_check",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,74 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB TRANSFORM QUERY---")
node_start_time = datetime.datetime.now()
# messages = state["base_answer_messages"]
question = state["sub_question_str"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
"""
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
"""
_, fast_llm = get_default_llms()
llm_response_list = list(
fast_llm.stream(
prompt=msg,
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
# structured_response_format=RewrittenQueries.model_json_schema(),
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
print(f"llm_response: {llm_response}")
rewritten_queries = llm_response.split("--")
# rewritten_queries = [llm_response.split("\n")[0]]
print(f"rewritten_queries: {rewritten_queries}")
rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)
return {
"sub_question_search_queries": rewritten_queries,
"log_messages": generate_log_message(
message="sub - rewrite",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,64 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
# print("---VERIFY QUTPUT---")
node_start_time = datetime.datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
llm, fast_llm = get_default_llms()
response = list(
llm.stream(
prompt=msg,
# structured_response_format=BinaryDecision.model_json_schema(),
)
)
response_string = merge_message_runs(response, chunk_separator="")[0].content
# Convert string response to proper dictionary format
decision_dict = {"decision": response_string.lower()}
formatted_response = BinaryDecision.model_validate(decision_dict)
print(f"Verification end time: {datetime.datetime.now()}")
return {
"sub_question_verified_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"sub - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,90 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
class SubQuestionRetrieverState(TypedDict):
# The state for the parallel Retrievers. They each need to see only one query
sub_question_rewritten_query: str
class SubQuestionVerifierState(TypedDict):
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
sub_question_document: InferenceSection
sub_question: str
class CoreQAInputState(TypedDict):
sub_question_str: str
original_question: str
class BaseQAState(TypedDict):
# The 'core SubQuestion' state.
original_question: str
graph_start_time: datetime
# start time for parallel initial sub-questionn thread
sub_query_start_time: datetime
sub_question_rewritten_queries: list[str]
sub_question_str: str
sub_question_search_queries: RewrittenQueries
sub_question_nr: int
sub_chunk_ids: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_qas: Annotated[Sequence[dict], operator.add]
# Answers sent back to core
initial_sub_qas: Annotated[Sequence[dict], operator.add]
primary_llm: LLM
fast_llm: LLM
class BaseQAOutputState(TypedDict):
# The 'SubQuestion' output state. Removes all the intermediate states
sub_question_rewritten_queries: list[str]
sub_question_str: str
sub_question_search_queries: list[str]
sub_question_nr: int
# Answers sent back to core
sub_qas: Annotated[Sequence[dict], operator.add]
# Answers sent back to core
initial_sub_qas: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]

View File

@@ -0,0 +1,46 @@
from collections.abc import Hashable
from typing import Union
from langgraph.types import Send
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState
def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
# Routes each de-douped retrieved doc to the verifier step - in parallel
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_verifier",
VerifierState(
document=doc,
question=state["sub_question"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for doc in state["sub_question_base_retrieval_docs"]
]
def sub_continue_to_retrieval(
state: ResearchQAState,
) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_custom_retrieve",
RetrieverState(
rewritten_query=query,
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for query in state["sub_question_rewritten_queries"]
]

View File

@@ -0,0 +1,93 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import (
sub_combine_retrieved_docs,
)
from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve
from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format
from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
def build_deep_qa_graph() -> StateGraph:
# Define the nodes we will cycle between
sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)
### Add Nodes ###
# Dummy node for initial processing
sub_answers.add_node(node="sub_dummy", action=sub_dummy)
# The retrieval step
sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)
# The dedupe step
sub_answers.add_node(
node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
)
# Verifying retrieved information
sub_answers.add_node(node="sub_verifier", action=sub_verifier)
# Generating the response
sub_answers.add_node(node="sub_generate", action=sub_generate)
# Checking the quality of the answer
sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)
# Final formatting of the response
sub_answers.add_node(node="sub_final_format", action=sub_final_format)
### Add Edges ###
# Generate multiple sub-questions
sub_answers.add_edge(start_key=START, end_key="sub_rewrite")
# For each sub-question, perform a retrieval in parallel
sub_answers.add_conditional_edges(
source="sub_rewrite",
path=sub_continue_to_retrieval,
path_map=["sub_custom_retrieve"],
)
# Combine the retrieved docs for each sub-question from the parallel retrievals
sub_answers.add_edge(
start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
)
# Go over all of the combined retrieved docs and verify them against the original question
sub_answers.add_conditional_edges(
source="sub_combine_retrieved_docs",
path=sub_continue_to_verifier,
path_map=["sub_verifier"],
)
# Generate an answer for each verified retrieved doc
sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")
# Check the quality of the answer
sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")
sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")
sub_answers.add_edge(start_key="sub_final_format", end_key=END)
return sub_answers
if __name__ == "__main__":
# TODO: add the actual question
inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
sub_answers_graph = build_deep_qa_graph()
compiled_sub_answers = sub_answers_graph.compile()
output = compiled_sub_answers.invoke(inputs)
print("\nOUTPUT:")
print(output)

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
dedupe_docs = []
for base_retrieval_doc in sub_question_base_retrieval_docs:
if base_retrieval_doc not in dedupe_docs:
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"sub_question_deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="sub - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE SUB---")
node_start_time = datetime.now()
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
documents: list[InferenceSection] = []
return {
"sub_question_base_retrieval_docs": documents,
"log_messages": generate_log_message(
message="sub - custom_retrieve",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,21 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
"""
Dummy step
"""
print("---Sub Dummy---")
return {
"log_messages": generate_log_message(
message="sub - dummy",
node_start_time=datetime.now(),
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
"""
Create the final output for the QA subgraph
"""
print("---SUB FINAL FORMAT---")
node_start_time = datetime.now()
return {
# TODO: Type this
"sub_qas": [
{
"sub_question": state["sub_question"],
"sub_answer": state["sub_question_answer"],
"sub_question_nr": state["sub_question_nr"],
"sub_answer_check": state["sub_question_answer_check"],
}
],
"log_messages": generate_log_message(
message="sub - final format",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_generate(state: ResearchQAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB GENERATE---")
node_start_time = datetime.now()
question = state["sub_question"]
docs = state["sub_question_verified_retrieval_docs"]
print(f"Number of verified retrieval docs for sub-question: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
)
]
# Grader
if len(docs) > 0:
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
else:
response_str = ""
return {
"sub_question_answer": response_str,
"log_messages": generate_log_message(
message="sub - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,57 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
"""
Check whether the final output satisfies the original user question
Args:
state (messages): The current state
Returns:
dict: The updated state with the final decision
"""
print("---CHECK SUB QUTPUT---")
node_start_time = datetime.now()
sub_answer = state["sub_question_answer"]
sub_question = state["sub_question"]
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
sub_question=sub_question, sub_answer=sub_answer
)
)
]
# Grader
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"sub_question_answer_check": formatted_response.decision,
"log_messages": generate_log_message(
message=f"sub - qa check: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,64 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.interfaces import LLM
def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB TRANSFORM QUERY---")
node_start_time = datetime.now()
question = state["sub_question"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
fast_llm: LLM = state["fast_llm"]
llm_response = list(
fast_llm.stream(
prompt=msg,
structured_response_format=RewrittenQueries.model_json_schema(),
)
)
# Get the rewritten queries in a defined format
rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
print(f"rewritten_queries: {rewritten_queries}")
rewritten_queries = RewrittenQueries(
rewritten_queries=[
"music hard to listen to",
"Music that is not fun or pleasant",
]
)
print(f"hardcoded rewritten_queries: {rewritten_queries}")
return {
"sub_question_rewritten_queries": rewritten_queries,
"log_messages": generate_log_message(
message="sub - rewrite",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,59 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print("---SUB VERIFY QUTPUT---")
node_start_time = datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"deduped_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"core - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,13 @@
SUB_CHECK_PROMPT = """ \n
Please check whether the suggested answer seems to address the original question.
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""

View File

@@ -0,0 +1,64 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
class ResearchQAState(TypedDict):
# The 'core SubQuestion' state.
original_question: str
graph_start_time: datetime
sub_question_rewritten_queries: list[str]
sub_question: str
sub_question_nr: int
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_qas: Annotated[Sequence[dict], operator.add]
primary_llm: LLM
fast_llm: LLM
class ResearchQAOutputState(TypedDict):
# The 'SubQuestion' output state. Removes all the intermediate states
sub_question_rewritten_queries: list[str]
sub_question: str
sub_question_nr: int
# Answers sent back to core
sub_qas: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]

View File

@@ -0,0 +1,75 @@
from collections.abc import Hashable
from typing import Union
from langchain_core.messages import HumanMessage
from langgraph.types import Send
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
def continue_to_initial_sub_questions(
state: QAState,
) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_answers_graph_initial",
BaseQAState(
sub_question_str=initial_sub_question["sub_question_str"],
sub_question_search_queries=initial_sub_question[
"sub_question_search_queries"
],
sub_question_nr=initial_sub_question["sub_question_nr"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for initial_sub_question in state["initial_sub_questions"]
]
def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_answers_graph",
ResearchQAState(
sub_question=sub_question["sub_question_str"],
sub_question_nr=sub_question["sub_question_nr"],
graph_start_time=state["graph_start_time"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
),
)
for sub_question in state["sub_questions"]
]
def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
print("---GO TO DEEP ANSWER OR END---")
base_answer = state["base_answer"]
question = state["original_question"]
BASE_CHECK_MESSAGE = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
)
]
model = state["fast_llm"]
response = model.invoke(BASE_CHECK_MESSAGE)
print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
if response.pretty_repr() == "no":
return "decompose"
else:
return "end"

View File

@@ -0,0 +1,171 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
combine_retrieved_docs,
)
from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
from danswer.agent_search.primary_graph.nodes.decompose import decompose
from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
deep_answer_generation,
)
from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
entity_term_extraction,
)
from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
sub_qa_level_aggregator,
)
from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
from danswer.agent_search.primary_graph.nodes.verifier import verifier
from danswer.agent_search.primary_graph.states import QAState
def build_core_graph() -> StateGraph:
# Define the nodes we will cycle between
core_answer_graph = StateGraph(state_schema=QAState)
### Add Nodes ###
core_answer_graph.add_node(node="dummy_start",
action=dummy_start)
# Re-writing the question
core_answer_graph.add_node(node="rewrite",
action=rewrite)
# The retrieval step
core_answer_graph.add_node(node="custom_retrieve",
action=custom_retrieve)
# Combine and dedupe retrieved docs.
core_answer_graph.add_node(
node="combine_retrieved_docs",
action=combine_retrieved_docs
)
# Extract entities, terms and relationships
core_answer_graph.add_node(
node="entity_term_extraction",
action=entity_term_extraction
)
# Verifying that a retrieved doc is relevant
core_answer_graph.add_node(node="verifier",
action=verifier)
# Initial question decomposition
core_answer_graph.add_node(node="main_decomp_base",
action=main_decomp_base)
# Build the base QA sub-graph and compile it
compiled_core_qa_graph = build_core_qa_graph().compile()
# Add the compiled base QA sub-graph as a node to the core graph
core_answer_graph.add_node(
node="sub_answers_graph_initial",
action=compiled_core_qa_graph
)
# Checking whether the initial answer is in the ballpark
core_answer_graph.add_node(node="base_wait",
action=base_wait)
# Decompose the question into sub-questions
core_answer_graph.add_node(node="decompose",
action=decompose)
# Manage the sub-questions
core_answer_graph.add_node(node="sub_qa_manager",
action=sub_qa_manager)
# Build the research QA sub-graph and compile it
compiled_deep_qa_graph = build_deep_qa_graph().compile()
# Add the compiled research QA sub-graph as a node to the core graph
core_answer_graph.add_node(node="sub_answers_graph",
action=compiled_deep_qa_graph)
# Aggregate the sub-questions
core_answer_graph.add_node(
node="sub_qa_level_aggregator",
action=sub_qa_level_aggregator
)
# aggregate sub questions and answers
core_answer_graph.add_node(
node="deep_answer_generation",
action=deep_answer_generation
)
# A final clean-up step
core_answer_graph.add_node(node="final_stuff",
action=final_stuff)
# Generating a response after we know the documents are relevant
core_answer_graph.add_node(node="generate_initial",
action=generate_initial)
### Add Edges ###
# start the initial sub-question decomposition
core_answer_graph.add_edge(start_key=START,
end_key="main_decomp_base")
core_answer_graph.add_conditional_edges(
source="main_decomp_base",
path=continue_to_initial_sub_questions,
)
# use the retrieved information to generate the answer
core_answer_graph.add_edge(
start_key=["verifier", "sub_answers_graph_initial"],
end_key="generate_initial"
)
core_answer_graph.add_edge(start_key="generate_initial",
end_key="base_wait")
core_answer_graph.add_conditional_edges(
source="base_wait",
path=continue_to_deep_answer,
path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
)
core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")
core_answer_graph.add_edge(start_key="decompose",
end_key="sub_qa_manager")
core_answer_graph.add_conditional_edges(
source="sub_qa_manager",
path=continue_to_answer_sub_questions,
)
core_answer_graph.add_edge(
start_key="sub_answers_graph",
end_key="sub_qa_level_aggregator"
)
core_answer_graph.add_edge(
start_key="sub_qa_level_aggregator",
end_key="deep_answer_generation"
)
core_answer_graph.add_edge(
start_key="deep_answer_generation",
end_key="final_stuff"
)
core_answer_graph.add_edge(start_key="final_stuff",
end_key=END)
core_answer_graph.compile()
return core_answer_graph

View File

@@ -0,0 +1,27 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def base_wait(state: QAState) -> dict[str, Any]:
"""
Ensures that all required steps are completed before proceeding to the next step
Args:
state (messages): The current state
Returns:
dict: {} (no operation, just logging)
"""
print("---Base Wait ---")
node_start_time = datetime.now()
return {
"log_messages": generate_log_message(
message="core - base_wait",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,36 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]
print(f"Number of docs from steps: {len(base_retrieval_docs)}")
dedupe_docs: list[InferenceSection] = []
for base_retrieval_doc in base_retrieval_docs:
if not any(
base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
for doc in dedupe_docs
):
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="core - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,52 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
retriever_state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
node_start_time = datetime.now()
query = state["rewritten_query"]
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
top_sections = SearchPipeline(
search_request=SearchRequest(
query=query,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
).reranked_sections
print(len(top_sections))
documents: list[InferenceSection] = []
return {
"base_retrieval_docs": documents,
"log_messages": generate_log_message(
message="core - custom_retrieve",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,78 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def decompose(state: QAState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
question = state["original_question"]
base_answer = state["base_answer"]
# get the entity term extraction dict and properly format it
entity_term_extraction_dict = state["retrieved_entities_relationships"][
"retrieved_entities_relationships"
]
entity_term_extraction_str = format_entity_term_extraction(
entity_term_extraction_dict
)
initial_question_answers = state["initial_sub_qas"]
addressed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "yes"
]
failed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "no"
]
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT.format(
question=question,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
parsed_response = json.loads(cleaned_response)
sub_questions_dict = {}
for sub_question_nr, sub_question_dict in enumerate(
parsed_response["sub_questions"]
):
sub_question_dict["answered"] = False
sub_question_dict["verified"] = False
sub_questions_dict[sub_question_nr] = sub_question_dict
return {
"decomposed_sub_questions_dict": sub_questions_dict,
"log_messages": generate_log_message(
message="deep - decompose",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,61 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
# aggregate sub questions and answers
def deep_answer_generation(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---DEEP GENERATE---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
deep_answer_context = state["core_answer_dynamic_context"]
print(f"Number of verified retrieval docs - deep: {len(docs)}")
combined_context = normalize_whitespace(
COMBINED_CONTEXT.format(
deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
)
)
msg = [
HumanMessage(
content=MODIFIED_RAG_PROMPT.format(
question=question, combined_context=combined_context
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"deep_answer": response.content,
"log_messages": generate_log_message(
message="deep - deep answer generation",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,11 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
def dummy_start(state: QAState) -> dict[str, Any]:
"""
Dummy node to set the start time
"""
return {"start_time": datetime.now()}

View File

@@ -0,0 +1,51 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def entity_term_extraction(state: QAState) -> dict[str, Any]:
"""Extract entities and terms from the question and context"""
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
doc_context = format_docs(docs)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
_, fast_llm = get_default_llms()
# Grader
llm_response_list = list(
fast_llm.stream(
prompt=msg,
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
# structured_response_format=RewrittenQueries.model_json_schema(),
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
parsed_response = json.loads(cleaned_response)
return {
"retrieved_entities_relationships": parsed_response,
"log_messages": generate_log_message(
message="deep - entity term extraction",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,85 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def final_stuff(state: QAState) -> dict[str, Any]:
"""
Invokes the agent model to generate a response based on the current state. Given
the question, it will decide to retrieve using the retriever tool, or simply end.
Args:
state (messages): The current state
Returns:
dict: The updated state with the agent response appended to messages
"""
print("---FINAL---")
node_start_time = datetime.now()
messages = state["log_messages"]
time_ordered_messages = [x.pretty_repr() for x in messages]
time_ordered_messages.sort()
print("Message Log:")
print("\n".join(time_ordered_messages))
initial_sub_qas = state["initial_sub_qas"]
initial_sub_qa_list = []
for initial_sub_qa in initial_sub_qas:
if initial_sub_qa["sub_answer_check"] == "yes":
initial_sub_qa_list.append(
f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
)
initial_sub_qa_context = "\n".join(initial_sub_qa_list)
log_message = generate_log_message(
message="all - final_stuff",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
)
print(log_message)
print("--------------------------------")
base_answer = state["base_answer"]
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
print("--------------------------------")
if not state.get("deep_answer"):
print("No Deep Answer was required")
return {
"log_messages": log_message,
}
deep_answer = state["deep_answer"]
sub_qas = state["sub_qas"]
sub_qa_list = []
for sub_qa in sub_qas:
if sub_qa["sub_answer_check"] == "yes":
sub_qa_list.append(
f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
)
sub_qa_context = "\n".join(sub_qa_list)
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Final Deep Answer:\n{deep_answer}")
print("--------------------------------")
print("Sub Questions and Answers:")
print(sub_qa_context)
return {
"log_messages": generate_log_message(
message="all - final_stuff",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,52 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def generate(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
print(f"Number of verified retrieval docs: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
)
]
# Grader
llm = state["fast_llm"]
response = list(
llm.stream(
prompt=msg,
structured_response_format=None,
)
)
return {
"base_answer": response[0].pretty_repr(),
"log_messages": generate_log_message(
message="core - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,72 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def generate_initial(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE INITIAL---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
print(f"Number of verified retrieval docs - base: {len(docs)}")
sub_question_answers = state["initial_sub_qas"]
sub_question_answers_list = []
_SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""
for sub_question_answer_dict in sub_question_answers:
if (
sub_question_answer_dict["sub_answer_check"] == "yes"
and len(sub_question_answer_dict["sub_answer"]) > 0
and sub_question_answer_dict["sub_answer"] != "I don't know"
):
sub_question_answers_list.append(
_SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_dict["sub_question"],
sub_answer=sub_question_answer_dict["sub_answer"],
)
)
sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)
msg = [
HumanMessage(
content=INITIAL_RAG_PROMPT.format(
question=question,
context=format_docs(docs),
answered_sub_questions=sub_question_answer_str,
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"base_answer": response.pretty_repr(),
"log_messages": generate_log_message(
message="core - generate initial",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,64 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def main_decomp_base(state: QAState) -> dict[str, Any]:
"""
Perform an initial question decomposition, incl. one search term
Args:
state (messages): The current state
Returns:
dict: The updated state with initial decomposition
"""
print("---INITIAL DECOMP---")
node_start_time = datetime.now()
question = state["original_question"]
msg = [
HumanMessage(
content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
)
]
# Get the rewritten queries in a defined format
model = state["fast_llm"]
response = model.invoke(msg)
content = response.pretty_repr()
list_of_subquestions = clean_and_parse_list_string(content)
decomp_list = []
for sub_question_nr, sub_question in enumerate(list_of_subquestions):
sub_question_str = sub_question["sub_question"].strip()
# temporarily
sub_question_search_queries = [sub_question["search_term"]]
decomp_list.append(
{
"sub_question_str": sub_question_str,
"sub_question_search_queries": sub_question_search_queries,
"sub_question_nr": sub_question_nr,
}
)
return {
"initial_sub_questions": decomp_list,
"sub_query_start_time": node_start_time,
"log_messages": generate_log_message(
message="core - initial decomp",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,55 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def rewrite(state: QAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
qa_state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---STARTING GRAPH---")
graph_start_time = datetime.now()
print("---TRANSFORM QUERY---")
node_start_time = datetime.now()
question = state["original_question"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
# Get the rewritten queries in a defined format
fast_llm = state["fast_llm"]
llm_response = list(
fast_llm.stream(
prompt=msg,
structured_response_format=RewrittenQueries.model_json_schema(),
)
)
formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
return {
"rewritten_queries": formatted_response.rewritten_queries,
"log_messages": generate_log_message(
message="core - rewrite",
node_start_time=node_start_time,
graph_start_time=graph_start_time,
),
}

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
# aggregate sub questions and answers
def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
sub_qas = state["sub_qas"]
node_start_time = datetime.now()
dynamic_context_list = [
"Below you will find useful information to answer the original question:"
]
checked_sub_qas = []
for core_answer_sub_qa in sub_qas:
question = core_answer_sub_qa["sub_question"]
answer = core_answer_sub_qa["sub_answer"]
verified = core_answer_sub_qa["sub_answer_check"]
if verified == "yes":
dynamic_context_list.append(
f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
)
checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
dynamic_context = "\n".join(dynamic_context_list)
return {
"core_answer_dynamic_context": dynamic_context,
"checked_sub_qas": checked_sub_qas,
"log_messages": generate_log_message(
message="deep - sub qa level aggregator",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,28 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_qa_manager(state: QAState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
sub_questions_dict = state["decomposed_sub_questions_dict"]
sub_questions = {}
for sub_question_nr, sub_question_dict in sub_questions_dict.items():
sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
return {
"sub_questions": sub_questions,
"num_new_question_iterations": 0,
"log_messages": generate_log_message(
message="deep - sub qa manager",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,59 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print("---VERIFY QUTPUT---")
node_start_time = datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
llm = state["fast_llm"]
response = list(
llm.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"deduped_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"core - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,86 @@
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of json objects with the following format:
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
Answer:
"""
INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.
If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!
Again, only use the provided informationand do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!
Try to keep your answer concise.
And here is the question and the provided information:
\n
\nQuestion:\n {question}
\nAnswered Sub-questions:\n {answered_sub_questions}
\nContext:\n {context} \n\n
\n\n
Answer:"""
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n
Please format your answer as a json object in the following format:
{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"name": <assign a name for the relationship>,
"type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"entities": [<related entity name 1>, <related entity name 2>]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""

View File

@@ -0,0 +1,73 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
class QAState(TypedDict):
# The 'main' state of the answer graph
original_question: str
graph_start_time: datetime
# start time for parallel initial sub-questionn thread
sub_query_start_time: datetime
log_messages: Annotated[Sequence[BaseMessage], add_messages]
rewritten_queries: RewrittenQueries
sub_questions: list[dict]
initial_sub_questions: list[dict]
ranked_subquestion_ids: list[int]
decomposed_sub_questions_dict: dict
rejected_sub_questions: Annotated[list[str], operator.add]
rejected_sub_questions_handled: bool
sub_qas: Annotated[Sequence[dict], operator.add]
initial_sub_qas: Annotated[Sequence[dict], operator.add]
checked_sub_qas: Annotated[Sequence[dict], operator.add]
base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
retrieved_entities_relationships: dict
questions_context: list[dict]
qa_level: int
top_chunks: list[InferenceSection]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
num_new_question_iterations: int
core_answer_dynamic_context: str
dynamic_context: str
initial_base_answer: str
base_answer: str
deep_answer: str
class QAOuputState(TypedDict):
# The 'main' output state of the answer graph. Removes all the intermediate states
original_question: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_questions: list[dict]
sub_qas: Annotated[Sequence[dict], operator.add]
initial_sub_qas: Annotated[Sequence[dict], operator.add]
checked_sub_qas: Annotated[Sequence[dict], operator.add]
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
retrieved_entities_relationships: dict
top_chunks: list[InferenceSection]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
base_answer: str
deep_answer: str
class RetrieverState(TypedDict):
# The state for the parallel Retrievers. They each need to see only one query
rewritten_query: str
graph_start_time: datetime
class VerifierState(TypedDict):
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
document: InferenceSection
question: str
graph_start_time: datetime

View File

@@ -0,0 +1,22 @@
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool
def run_graph(
query: str,
llm: LLM,
tools: list[Tool],
) -> AnswerStream:
graph = build_core_graph()
inputs = {
"original_question": query,
"messages": [],
"tools": tools,
"llm": llm,
}
compiled_graph = graph.compile()
output = compiled_graph.invoke(input=inputs)
yield from output

View File

@@ -0,0 +1,16 @@
from typing import Literal
from pydantic import BaseModel
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]
class SubQuestions(BaseModel):
sub_questions: list[str]

View File

@@ -0,0 +1,342 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
BASE_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the context provided below - and only the
provided context - to answer the question. If you don't know the answer or if the provided context is
empty, just say "I don't know". Do not use your internal knowledge!
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Use three sentences maximum and keep the answer concise.
answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n
\n\n
Answer:"""
BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""
VERIFIER_PROMPT = """ \n
Please check whether the document seems to be relevant for the answer of the original question. Please
only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the document text:
\n ------- \n
{document_content}
\n ------- \n
Please answer with yes or no:"""
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of subquestions:
Answer:
"""
REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrievel from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the query: """
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
\nQuestion: {question}
\nContext: {combined_context} \n
Answer:"""
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Generate the list of json dictionaries with the following format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets
Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focussed and
somewhat more specific to the original question as it tries to contextualized facts.
After that will see the documents that were considered to be relevant to answer the original question.
Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)
Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.
Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------
Please think step by step and then generate the ranked list of motivations.
Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""

View File

@@ -0,0 +1,91 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from danswer.context.search.models import InferenceSection
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
return "\n\n".join(doc.combined_content for doc in docs)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return ast.literal_eval(cleaned_string)
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
entities = entity_term_extraction_dict["entities"]
terms = entity_term_extraction_dict["terms"]
relationships = entity_term_extraction_dict["relationships"]
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"

View File

@@ -23,7 +23,9 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(chosen_assistants=None, default_model=None)
return UserPreferences(
chosen_assistants=None, default_model=None, auto_scroll=True
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:

View File

@@ -49,7 +49,7 @@ from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
@@ -80,13 +80,14 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.server.utils import BasicAuthenticationError
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -99,11 +100,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -609,7 +605,7 @@ optional_fastapi_current_user = fastapi_users.current_user(active=True, optional
async def optional_user_(
request: Request,
user: User | None,
db_session: Session,
async_db_session: AsyncSession,
) -> User | None:
"""NOTE: `request` and `db_session` are not used here, but are included
for the EE version of this function."""
@@ -618,13 +614,21 @@ async def optional_user_(
async def optional_user(
request: Request,
db_session: Session = Depends(get_session),
async_db_session: AsyncSession = Depends(get_async_session),
user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
versioned_fetch_user = fetch_versioned_implementation(
"danswer.auth.users", "optional_user_"
)
return await versioned_fetch_user(request, user, db_session)
user = await versioned_fetch_user(request, user, async_db_session)
# check if an API key is present
if user is None:
hashed_api_key = get_hashed_api_key_from_request(request)
if hashed_api_key:
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
return user
async def double_check_user(
@@ -910,8 +914,8 @@ def get_oauth_router(
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
async def api_key_dep(
request: Request, async_db_session: AsyncSession = Depends(get_async_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
@@ -921,7 +925,7 @@ def api_key_dep(
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")

View File

@@ -11,6 +11,7 @@ from celery.exceptions import WorkerShutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from redis.lock import Lock as RedisLock
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
@@ -332,16 +333,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
return
logger.info("Releasing primary worker lock.")
lock = sender.primary_worker_lock
lock: RedisLock = sender.primary_worker_lock
try:
if lock.owned():
try:
lock.release()
sender.primary_worker_lock = None
except Exception as e:
logger.error(f"Failed to release primary worker lock: {e}")
except Exception as e:
logger.error(f"Failed to check if primary worker lock is owned: {e}")
except Exception:
logger.exception("Failed to release primary worker lock")
except Exception:
logger.exception("Failed to check if primary worker lock is owned")
def on_setup_logging(

View File

@@ -1,5 +1,6 @@
import multiprocessing
from typing import Any
from typing import cast
from celery import bootsteps # type: ignore
from celery import Celery
@@ -10,18 +11,21 @@ from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from redis.lock import Lock as RedisLock
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.background.celery.tasks.indexing.tasks import (
get_unfenced_index_attempt_ids,
)
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -35,7 +39,6 @@ from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -95,6 +98,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
role: str = cast(str, info.get("role"))
connected_slaves: int = info.get("connected_slaves", 0)
logger.info(
f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}"
)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
@@ -104,9 +116,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
# set thread_local=False since we don't control what thread the periodic task might
# reacquire the lock with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
thread_local=False,
)
logger.info("Primary worker lock: Acquire starting.")
@@ -153,13 +169,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
continue
failure_reason = (
f"Orphaned index attempt found on startup: "
f"Canceling leftover index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)
mark_attempt_canceled(attempt.id, db_session, failure_reason)
@worker_ready.connect
@@ -215,7 +231,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
if not hasattr(worker, "primary_worker_lock"):
return
lock = worker.primary_worker_lock
lock: RedisLock = worker.primary_worker_lock
r = get_redis_client(tenant_id=None)

View File

@@ -4,7 +4,6 @@ from typing import Any
from sqlalchemy.orm import Session
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
@@ -17,6 +16,7 @@ from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.redis.redis_connector import RedisConnector
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
@@ -78,7 +78,7 @@ def document_batch_to_ids(
def extract_ids_from_runnable_connector(
runnable_connector: BaseConnector,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> set[str]:
"""
If the SlimConnector hasnt been implemented for the given connector, just pull
@@ -111,10 +111,15 @@ def extract_ids_from_runnable_connector(
for doc_batch in doc_batch_generator:
if callback:
if callback.should_stop():
raise RuntimeError("Stop signal received")
callback.progress(len(doc_batch))
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
return all_connector_doc_ids

View File

@@ -2,54 +2,55 @@ from datetime import timedelta
from typing import Any
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryTask
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": "check_for_vespa_sync_task",
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-connector-deletion",
"task": "check_for_connector_deletion_task",
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "kombu-message-cleanup",
"task": "kombu_message_cleanup_task",
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
"schedule": timedelta(seconds=3600),
"options": {"priority": DanswerCeleryPriority.LOWEST},
},
{
"name": "monitor-vespa-sync",
"task": "monitor_vespa_sync",
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-doc-permissions-sync",
"task": "check_for_doc_permissions_sync",
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-external-group-sync",
"task": "check_for_external_group_sync",
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {"priority": DanswerCeleryPriority.HIGH},
},

View File

@@ -5,13 +5,13 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -19,7 +19,7 @@ from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.search_settings import get_all_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_delete import RedisConnectorDeletionFenceData
from danswer.redis.redis_connector_delete import RedisConnectorDeletePayload
from danswer.redis.redis_pool import get_redis_client
@@ -29,7 +29,7 @@ class TaskDependencyError(RuntimeError):
@shared_task(
name="check_for_connector_deletion_task",
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -37,7 +37,7 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -60,7 +60,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
self.app, cc_pair_id, db_session, r, lock_beat, tenant_id
self.app, cc_pair_id, db_session, lock_beat, tenant_id
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
@@ -86,7 +86,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
@@ -118,7 +117,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
fence_payload = RedisConnectorDeletionFenceData(
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
)

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.access.models import DocExternalAccess
from danswer.background.celery.apps.app_base import task_logger
@@ -17,9 +18,11 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@@ -27,7 +30,7 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import doc_permission_sync_ctx
@@ -81,7 +84,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
@shared_task(
name="check_for_doc_permissions_sync",
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -138,7 +141,7 @@ def try_creating_permissions_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
)
@@ -162,8 +165,8 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
app.send_task(
"connector_permission_sync_generator_task",
result = app.send_task(
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
@@ -174,8 +177,8 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncData(
started=None,
payload = RedisConnectorPermissionSyncPayload(
started=None, celery_task_id=result.id
)
redis_connector.permissions.set_fence(payload)
@@ -190,7 +193,7 @@ def try_creating_permissions_sync_task(
@shared_task(
name="connector_permission_sync_generator_task",
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -241,13 +244,17 @@ def connector_permission_sync_generator_task(
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
if doc_sync_func is None:
raise ValueError(f"No doc sync func found for {source_type}")
raise ValueError(
f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}")
payload = RedisConnectorPermissionSyncData(
started=datetime.now(timezone.utc),
)
payload = redis_connector.permissions.payload
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
@@ -256,7 +263,12 @@ def connector_permission_sync_generator_task(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
self.app, lock, document_external_accesses, source_type
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
@@ -281,7 +293,7 @@ def connector_permission_sync_generator_task(
@shared_task(
name="update_external_document_permissions_task",
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
@@ -292,6 +304,8 @@ def update_external_document_permissions_task(
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
credential_id: int,
) -> bool:
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
@@ -300,18 +314,28 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
# Then we build the update requests to update vespa
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
)
upsert_document_external_perms(
# Then we upsert the document's external permissions in postgres
created_new_doc = upsert_document_external_perms(
db_session=db_session,
doc_id=doc_id,
external_access=external_access,
source_type=DocumentSource(source_string),
)
if created_new_doc:
# If a new document was created, we associate it with the cc_pair
upsert_document_by_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
)

View File

@@ -8,6 +8,7 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
@@ -16,6 +17,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import mark_cc_pair_as_external_group_synced
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -24,13 +26,20 @@ from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIOD
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
from ee.danswer.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
logger = setup_logger()
@@ -49,7 +58,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if cc_pair.access_type != AccessType.SYNC:
return False
# skip pruning if not active
# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
@@ -66,9 +75,9 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
if last_ext_group_sync is None:
return True
source_sync_period = EXTERNAL_GROUP_SYNC_PERIOD
source_sync_period = EXTERNAL_GROUP_SYNC_PERIODS.get(cc_pair.connector.source)
# If EXTERNAL_GROUP_SYNC_PERIOD is None, we always run the sync.
# If EXTERNAL_GROUP_SYNC_PERIODS is None, we always run the sync.
if not source_sync_period:
return True
@@ -81,7 +90,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name="check_for_external_group_sync",
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -102,12 +111,28 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
cc_pairs = [
cc_pair
for cc_pair in cc_pairs
if cc_pair.id != cc_pair_to_remove.id
]
for cc_pair in cc_pairs:
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
@@ -125,7 +150,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
def try_creating_permissions_sync_task(
def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
@@ -156,8 +181,8 @@ def try_creating_permissions_sync_task(
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
_ = app.send_task(
"connector_external_group_sync_generator_task",
result = app.send_task(
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair_id,
tenant_id=tenant_id,
@@ -166,8 +191,13 @@ def try_creating_permissions_sync_task(
task_id=custom_task_id,
priority=DanswerCeleryPriority.HIGH,
)
# set a basic fence to start
redis_connector.external_group_sync.set_fence(True)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
@@ -182,7 +212,7 @@ def try_creating_permissions_sync_task(
@shared_task(
name="connector_external_group_sync_generator_task",
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -195,7 +225,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
Permission sync task that handles external group syncing for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -203,7 +233,7 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
lock: RedisLock = r.lock(
DanswerRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
@@ -228,9 +258,13 @@ def connector_external_group_sync_generator_task(
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(f"No external group sync func found for {source_type}")
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
logger.info(f"Syncing docs for {source_type}")
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
@@ -249,7 +283,6 @@ def connector_external_group_sync_generator_task(
)
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
except Exception as e:
task_logger.exception(
f"Failed to run external group sync: cc_pair={cc_pair_id}"
@@ -260,6 +293,6 @@ def connector_external_group_sync_generator_task(
raise e
finally:
# we always want to clear the fence after the task is done or failed so it doesn't get stuck
redis_connector.external_group_sync.set_fence(False)
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()

View File

@@ -3,6 +3,7 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
@@ -16,35 +17,42 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.background.indexing.run_indexing import RunIndexingCallbackInterface
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector import mark_ccpair_with_indexing_trigger
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingMode
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import delete_index_attempt
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_active_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
@@ -57,7 +65,7 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class RunIndexingCallback(RunIndexingCallbackInterface):
class IndexingCallback(IndexingHeartbeatInterface):
def __init__(
self,
stop_key: str,
@@ -73,6 +81,7 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
@@ -80,15 +89,17 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
return True
return False
def progress(self, amount: int) -> None:
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"IndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
@@ -97,17 +108,65 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self.redis_client.incrby(self.generator_progress_key, amount)
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(
name="check_for_indexing",
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
tasks_created = 0
locked = False
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -117,6 +176,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
if not lock_beat.acquire(blocking=False):
return None
locked = True
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
@@ -135,26 +197,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
embedding_model=embedding_model,
)
# gather cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
# kick off index attempts
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
for search_settings_instance in search_settings:
search_settings_list: list[SearchSettings] = get_active_search_settings(
db_session
)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -170,22 +230,46 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
task_logger.info(
f"Connector indexing manual trigger detected: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"indexing_mode={cc_pair.indexing_trigger}"
)
mark_ccpair_with_indexing_trigger(
cc_pair.id, None, db_session
)
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
self.app,
cc_pair,
search_settings_instance,
False,
reindex,
db_session,
r,
tenant_id,
@@ -195,9 +279,31 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
f"Connector indexing queued: "
f"index_attempt={attempt_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings_instance.id} "
f"search_settings={search_settings_instance.id}"
)
tasks_created += 1
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
lock_beat.reacquire()
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.error(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -205,8 +311,14 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
if locked:
if lock_beat.owned():
lock_beat.release()
else:
task_logger.error(
"check_for_indexing - Lock not owned on completion: "
f"tenant={tenant_id}"
)
return tasks_created
@@ -215,6 +327,7 @@ def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -279,6 +392,11 @@ def _should_index(
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
@@ -311,10 +429,11 @@ def try_creating_indexing_task(
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
@@ -365,8 +484,10 @@ def try_creating_indexing_task(
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
"connector_indexing_proxy_task",
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
@@ -385,13 +506,16 @@ def try_creating_indexing_task(
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(None)
task_logger.exception(
f"Unexpected exception: "
f"try_creating_indexing_task - Unexpected exception: "
f"tenant={tenant_id} "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
@@ -400,8 +524,14 @@ def try_creating_indexing_task(
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
@shared_task(
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
bind=True,
acks_late=False,
track_started=True,
)
def connector_indexing_proxy_task(
self: Task,
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
@@ -409,15 +539,19 @@ def connector_indexing_proxy_task(
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
f"Indexing proxy - starting: attempt={index_attempt_id} "
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
@@ -428,7 +562,7 @@ def connector_indexing_proxy_task(
if not job:
task_logger.info(
f"Indexing proxy - spawn failed: attempt={index_attempt_id} "
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -436,31 +570,78 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
while True:
sleep(10)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
while True:
sleep(5)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
finally:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if not index_attempt:
continue
job.cancel()
if not index_attempt.is_finished():
continue
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
task_logger.error(
f"Indexing proxy - spawned task exceptioned: "
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
@@ -472,7 +653,7 @@ def connector_indexing_proxy_task(
break
task_logger.info(
f"Indexing proxy - finished: attempt={index_attempt_id} "
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -480,6 +661,38 @@ def connector_indexing_proxy_task(
return
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
return result
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
@@ -534,6 +747,7 @@ def connector_indexing_task(
if redis_connector.delete.fenced:
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
)
@@ -541,18 +755,18 @@ def connector_indexing_task(
if redis_connector.stop.fenced:
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
)
while True:
# wait for the fence to come up
if not redis_connector_index.fenced:
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
@@ -575,16 +789,19 @@ def connector_indexing_task(
)
break
lock = r.lock(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
redis_connector_index.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
)
return None
@@ -619,7 +836,7 @@ def connector_indexing_task(
)
# define a callback class
callback = RunIndexingCallback(
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,

View File

@@ -13,12 +13,13 @@ from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
@shared_task(
name="kombu_message_cleanup_task",
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
soft_time_limit=JOB_TIMEOUT,
bind=True,
base=AbortableTask,

View File

@@ -8,11 +8,12 @@ from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.tasks.indexing.tasks import RunIndexingCallback
from danswer.background.celery.tasks.indexing.tasks import IndexingCallback
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
@@ -20,6 +21,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
@@ -39,7 +41,14 @@ logger = setup_logger()
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if pruning is due."""
"""Returns boolean indicating if pruning is due.
Next pruning time is calculated as a delta from the last successful prune, or the
last successful indexing if pruning has never succeeded.
TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled
if pruning fails (which is what it does now). A backoff could be reasonable.
"""
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
@@ -68,7 +77,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
@shared_task(
name="check_for_pruning",
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
@@ -177,7 +186,7 @@ def try_creating_prune_generator_task(
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
@@ -202,7 +211,7 @@ def try_creating_prune_generator_task(
@shared_task(
name="connector_pruning_generator_task",
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
@@ -225,13 +234,18 @@ def connector_pruning_generator_task(
pruning_ctx_dict["request_id"] = self.request.id
pruning_ctx.set(pruning_ctx_dict)
task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
lock = r.lock(
# set thread_local=False since we don't control what thread the indexing/pruning
# might run our callback with
lock: RedisLock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
thread_local=False,
)
acquired = lock.acquire(blocking=False)
@@ -255,6 +269,11 @@ def connector_pruning_generator_task(
)
return
task_logger.info(
f"Pruning generator running connector: "
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -263,12 +282,13 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
callback = RunIndexingCallback(
callback = IndexingCallback(
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,
r,
)
# a list of docs in the source
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
runnable_connector, callback
@@ -290,8 +310,8 @@ def connector_pruning_generator_task(
task_logger.info(
f"Pruning set collected: "
f"cc_pair={cc_pair_id} "
f"docs_to_remove={len(doc_ids_to_remove)} "
f"doc_source={cc_pair.connector.source}"
f"connector_source={cc_pair.connector.source} "
f"docs_to_remove={len(doc_ids_to_remove)}"
)
task_logger.info(
@@ -314,10 +334,10 @@ def connector_pruning_generator_task(
f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
)
redis_connector.prune.generator_clear()
redis_connector.prune.taskset_clear()
redis_connector.prune.set_fence(False)
redis_connector.prune.reset()
raise e
finally:
if lock.owned():
lock.release()
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")

View File

@@ -9,6 +9,7 @@ from tenacity import RetryError
from danswer.access.access import get_access_for_document
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from danswer.configs.constants import DanswerCeleryTask
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
@@ -31,7 +32,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
@shared_task(
name="document_by_cc_pair_cleanup_task",
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,

View File

@@ -5,7 +5,6 @@ from http import HTTPStatus
from typing import cast
import httpx
import redis
from celery import Celery
from celery import shared_task
from celery import Task
@@ -26,6 +25,7 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerCeleryTask
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import mark_cc_pair_as_permissions_synced
@@ -49,11 +49,9 @@ from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
@@ -62,7 +60,7 @@ from danswer.redis.redis_connector_credential_pair import RedisConnectorCredenti
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from danswer.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncData,
RedisConnectorPermissionSyncPayload,
)
from danswer.redis.redis_connector_index import RedisConnectorIndex
from danswer.redis.redis_connector_prune import RedisConnectorPrune
@@ -83,7 +81,7 @@ logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@shared_task(
name="check_for_vespa_sync_task",
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
@@ -592,7 +590,7 @@ def monitor_ccpair_permissions_taskset(
if remaining > 0:
return
payload: RedisConnectorPermissionSyncData | None = (
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
)
start_time: datetime | None = payload.started if payload else None
@@ -600,9 +598,7 @@ def monitor_ccpair_permissions_taskset(
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
redis_connector.permissions.taskset_clear()
redis_connector.permissions.generator_clear()
redis_connector.permissions.set_fence(None)
redis_connector.permissions.reset()
def monitor_ccpair_indexing_taskset(
@@ -649,38 +645,52 @@ def monitor_ccpair_indexing_taskset(
# the task is still setting up
return
# Read result state BEFORE generator_complete_key to avoid a race condition
# never use any blocking methods on the result from inside a task!
result: AsyncResult = AsyncResult(payload.celery_task_id)
result_state = result.state
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = get_completion / generator_complete not signaled
# outer = result.state in READY state
status_int = redis_connector_index.get_completion()
if status_int is None: # completion signal not set ... check for errors
# If we get here, and then the task both sets the completion signal and finishes,
# we will incorrectly abort the task. We must check result state, then check
# get_completion again to avoid the race condition.
if result_state in READY_STATES:
if status_int is None: # inner signal not set ... possible error
task_state = result.state
if (
task_state in READY_STATES
): # outer signal in terminal state ... possible error
# Now double check!
if redis_connector_index.get_completion() is None:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
# inner signal still not set (and cannot change when outer result_state is READY)
# Task is finished but generator complete isn't set.
# We have a problem! Worker may have crashed.
task_result = str(result.result)
task_traceback = str(result.traceback)
msg = (
f"Connector indexing aborted or exceptioned: "
f"attempt={payload.index_attempt_id} "
f"celery_task={payload.celery_task_id} "
f"result_state={result_state} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"result.state={task_state} "
f"result.result={task_result} "
f"result.traceback={task_traceback}"
)
task_logger.warning(msg)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
redis_connector_index.reset()
return
@@ -690,6 +700,7 @@ def monitor_ccpair_indexing_taskset(
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -697,38 +708,7 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if r.exists(fence_key):
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
@@ -753,7 +733,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_celery = celery_get_queue_length("celery", r_celery)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
@@ -779,25 +759,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"permissions_sync={n_permissions_sync} "
)
# Fail any index attempts in the DB that don't have fences
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
@@ -858,7 +819,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
@shared_task(
name="vespa_metadata_sync_task",
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
bind=True,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
time_limit=LIGHT_TIME_LIMIT,

View File

@@ -1,6 +1,8 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.background.celery.apps.beat import celery_app
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = celery_app
app: Celery = celery_app

View File

@@ -1,8 +1,10 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
app = fetch_versioned_implementation(
app: Celery = fetch_versioned_implementation(
"danswer.background.celery.apps.primary", "celery_app"
)

View File

@@ -1,7 +1,5 @@
import time
import traceback
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -21,6 +19,7 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import mark_attempt_canceled
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
@@ -31,7 +30,7 @@ from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
from danswer.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from danswer.indexing.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from danswer.utils.logger import TaskAttemptSingleton
@@ -42,19 +41,6 @@ logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
class RunIndexingCallbackInterface(ABC):
"""Defines a callback interface to be passed to
to run_indexing_entrypoint."""
@abstractmethod
def should_stop(self) -> bool:
"""Signal to stop the looping function in flight."""
@abstractmethod
def progress(self, amount: int) -> None:
"""Send progress updates to the caller."""
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
@@ -102,11 +88,15 @@ def _get_connector_runner(
)
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -138,13 +128,7 @@ def _run_indexing(
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
heartbeat=IndexingHeartbeat(
index_attempt_id=index_attempt.id,
db_session=db_session,
# let the world know we're still making progress after
# every 10 batches
freq=10,
),
callback=callback,
)
indexing_pipeline = build_indexing_pipeline(
@@ -157,6 +141,7 @@ def _run_indexing(
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -228,7 +213,7 @@ def _run_indexing(
# contents still need to be initially pulled.
if callback:
if callback.should_stop():
raise RuntimeError("Connector stop signal detected")
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
@@ -289,7 +274,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress(len(doc_batch))
callback.progress("_run_indexing", len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
@@ -322,26 +307,16 @@ def _run_indexing(
)
except Exception as e:
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -353,6 +328,37 @@ def _run_indexing(
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
@@ -419,7 +425,7 @@ def run_indexing_entrypoint(
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: RunIndexingCallbackInterface | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
try:
if is_ee:

View File

@@ -2,20 +2,79 @@ import re
from typing import cast
from uuid import UUID
from fastapi import HTTPException
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.auth.users import is_user_admin
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.models import PersonaOverrideConfig
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import DEFAULT_PERSONA_ID
from danswer.configs.constants import MessageType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RerankingDetails
from danswer.context.search.models import RetrievalDetails
from danswer.db.chat import create_chat_session
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.llm import fetch_existing_doc_sets
from danswer.db.llm import fetch_existing_tools
from danswer.db.models import ChatMessage
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.answering.models import PreviousMessage
from danswer.search.models import InferenceSection
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.utils.logger import setup_logger
logger = setup_logger()
def prepare_chat_message_request(
message_text: str,
user: User | None,
persona_id: int | None,
# Does the question need to have a persona override
persona_override_config: PersonaOverrideConfig | None,
prompt: Prompt | None,
message_ts_to_respond_to: str | None,
retrieval_details: RetrievalDetails | None,
rerank_settings: RerankingDetails | None,
db_session: Session,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
db_session=db_session,
description=None,
user_id=user.id if user else None,
# If using an override, this id will be ignored later on
persona_id=persona_id or DEFAULT_PERSONA_ID,
danswerbot_flow=True,
slack_thread_id=message_ts_to_respond_to,
)
return CreateChatMessageRequest(
chat_session_id=new_chat_session.id,
parent_message_id=None, # It's a standalone chat session each time
message=message_text,
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
prompt_id=prompt.id if prompt else None,
# Can always override the persona for the single query, if it's a normal persona
# then it will be treated the same
persona_override_config=persona_override_config,
search_doc_ids=None,
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
)
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inference_section.center_chunk.document_id,
@@ -31,9 +90,49 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
if inference_section.center_chunk.source_links
else None,
source_links=inference_section.center_chunk.source_links,
match_highlights=inference_section.center_chunk.match_highlights,
)
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
for message in reversed(messages):
if message.role == MessageType.USER:
role_str = message.role.value.upper()
if message.sender:
role_str += " " + message.sender
else:
# Since other messages might have the user identifying information
# better to use Unknown for symmetry
role_str += " Unknown"
else:
role_str = message.role.value.upper()
msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer.encode(msg_str))
if (
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
message_strs.insert(0, msg_str)
total_token_count += message_token_count
return "\n\n".join(message_strs)
def create_chat_chain(
chat_session_id: UUID,
db_session: Session,
@@ -196,3 +295,71 @@ def extract_headers(
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
def create_temporary_persona(
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
) -> Persona:
if not is_user_admin(user):
raise HTTPException(
status_code=403,
detail="User is not authorized to create a persona in one shot queries",
)
"""Create a temporary Persona object from the provided configuration."""
persona = Persona(
name=persona_config.name,
description=persona_config.description,
num_chunks=persona_config.num_chunks,
llm_relevance_filter=persona_config.llm_relevance_filter,
llm_filter_extraction=persona_config.llm_filter_extraction,
recency_bias=persona_config.recency_bias,
llm_model_provider_override=persona_config.llm_model_provider_override,
llm_model_version_override=persona_config.llm_model_version_override,
)
if persona_config.prompts:
persona.prompts = [
Prompt(
name=p.name,
description=p.description,
system_prompt=p.system_prompt,
task_prompt=p.task_prompt,
include_citations=p.include_citations,
datetime_aware=p.datetime_aware,
)
for p in persona_config.prompts
]
elif persona_config.prompt_ids:
persona.prompts = get_prompts_by_ids(
db_session=db_session, prompt_ids=persona_config.prompt_ids
)
persona.tools = []
if persona_config.custom_tools_openapi:
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)
if persona_config.tools:
tool_ids = [tool.id for tool in persona_config.tools]
persona.tools.extend(
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
)
if persona_config.tool_ids:
persona.tools.extend(
fetch_existing_tools(
db_session=db_session, tool_ids=persona_config.tool_ids
)
)
fetched_docs = fetch_existing_doc_sets(
db_session=db_session, doc_ids=persona_config.document_set_ids
)
persona.document_sets = fetched_docs
return persona

View File

@@ -4,12 +4,14 @@ from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from danswer.configs.constants import DocumentSource
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.configs.constants import MessageType
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
@@ -25,6 +27,7 @@ class LlmDoc(BaseModel):
updated_at: datetime | None
link: str | None
source_links: dict[int, str] | None
match_highlights: list[str] | None
# First chunk of info for streaming QA
@@ -117,20 +120,6 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class DanswerQuote(BaseModel):
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class DanswerQuotes(BaseModel):
quotes: list[DanswerQuote]
class DanswerContext(BaseModel):
content: str
document_id: str
@@ -146,14 +135,20 @@ class DanswerAnswer(BaseModel):
answer: str | None
class QAResponse(SearchResponse, DanswerAnswer):
quotes: list[DanswerQuote] | None
contexts: list[DanswerContexts] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
class ThreadMessage(BaseModel):
message: str
sender: str | None = None
role: MessageType = MessageType.USER
class ChatDanswerBotResponse(BaseModel):
answer: str | None = None
citations: list[CitationInfo] | None = None
docs: QADocsResponse | None = None
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
class FileChatDisplay(BaseModel):
@@ -165,9 +160,41 @@ class CustomToolResponse(BaseModel):
tool_name: str
class ToolConfig(BaseModel):
id: int
class PromptOverrideConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class PersonaOverrideConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| FileChatDisplay

View File

@@ -7,10 +7,13 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import create_temporary_persona
from danswer.chat.models import AllCitations
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LLMRelevanceFilterResponse
@@ -23,6 +26,16 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.context.search.enums import OptionalSearchSetting
from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import SearchType
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import RetrievalDetails
from danswer.context.search.retrieval.search_runner import inference_sections_from_ids
from danswer.context.search.utils import chunks_or_sections_to_search_docs
from danswer.context.search.utils import dedupe_documents
from danswer.context.search.utils import drop_llm_indices
from danswer.context.search.utils import relevant_sections_to_indices
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -56,16 +69,6 @@ from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@@ -102,6 +105,7 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
from danswer.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -113,8 +117,10 @@ from danswer.tools.tool_implementations.search.search_tool import (
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
@@ -256,6 +262,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| DanswerContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -286,6 +293,8 @@ def stream_chat_message_objects(
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -322,17 +331,31 @@ def stream_chat_message_objects(
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
)
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
# Allows users to specify a temporary persona (assistant) in the chat session
# this takes highest priority since it's user specified
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
elif new_msg_req.persona_override_config:
# Certain endpoints allow users to specify arbitrary persona settings
# this should never conflict with the alternate_assistant_id
persona = persona = create_temporary_persona(
db_session=db_session,
persona_config=new_msg_req.persona_override_config,
user=user,
)
else:
persona = chat_session.persona
if not persona:
raise RuntimeError("No persona specified or found for chat session")
# If a prompt override is specified via the API, use that with highest priority
# but for saving it, we are just mapping it to an existing prompt
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
@@ -555,19 +578,34 @@ def stream_chat_message_objects(
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
if new_msg_req.persona_override_config:
prompt_config = PromptConfig(
system_prompt=new_msg_req.persona_override_config.prompts[
0
].system_prompt,
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
datetime_aware=new_msg_req.persona_override_config.prompts[
0
].datetime_aware,
include_citations=new_msg_req.persona_override_config.prompts[
0
].include_citations,
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
elif prompt_override:
if not final_msg.prompt:
raise ValueError(
"Prompt override cannot be applied, no base prompt found."
)
prompt_config = PromptConfig.from_model(
final_msg.prompt,
prompt_override=prompt_override,
)
elif final_msg.prompt:
prompt_config = PromptConfig.from_model(final_msg.prompt)
else:
prompt_config = PromptConfig.from_model(persona.prompts[0])
answer_style_config = AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
@@ -587,11 +625,13 @@ def stream_chat_message_objects(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
rerank_settings=new_msg_req.rerank_settings,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
bypass_acl=bypass_acl,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
@@ -605,6 +645,7 @@ def stream_chat_message_objects(
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
@@ -736,6 +777,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(DanswerContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
pass
@@ -844,3 +887,30 @@ def stream_chat_message(
)
for obj in objects:
yield get_json_line(obj.model_dump())
@log_function_time()
def gather_stream_for_slack(
packets: ChatPacketStream,
) -> ChatDanswerBotResponse:
response = ChatDanswerBotResponse()
answer = ""
for packet in packets:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.docs = packet
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.chat_message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
if answer:
response.answer = answer
return response

View File

@@ -1,115 +0,0 @@
from typing_extensions import TypedDict # noreorder
from pydantic import BaseModel
from danswer.prompts.chat_tools import DANSWER_TOOL_DESCRIPTION
from danswer.prompts.chat_tools import DANSWER_TOOL_NAME
from danswer.prompts.chat_tools import TOOL_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_FOLLOWUP
from danswer.prompts.chat_tools import TOOL_LESS_PROMPT
from danswer.prompts.chat_tools import TOOL_TEMPLATE
from danswer.prompts.chat_tools import USER_INPUT
class ToolInfo(TypedDict):
name: str
description: str
class DanswerChatModelOut(BaseModel):
model_raw: str
action: str
action_input: str
def call_tool(
model_actions: DanswerChatModelOut,
) -> str:
raise NotImplementedError("There are no additional tool integrations right now")
def form_user_prompt_text(
query: str,
tool_text: str | None,
hint_text: str | None,
user_input_prompt: str = USER_INPUT,
tool_less_prompt: str = TOOL_LESS_PROMPT,
) -> str:
user_prompt = tool_text or tool_less_prompt
user_prompt += user_input_prompt.format(user_input=query)
if hint_text:
if user_prompt[-1] != "\n":
user_prompt += "\n"
user_prompt += "\nHint: " + hint_text
return user_prompt.strip()
def form_tool_section_text(
tools: list[ToolInfo] | None, retrieval_enabled: bool, template: str = TOOL_TEMPLATE
) -> str | None:
if not tools and not retrieval_enabled:
return None
if retrieval_enabled and tools:
tools.append(
{"name": DANSWER_TOOL_NAME, "description": DANSWER_TOOL_DESCRIPTION}
)
tools_intro = []
if tools:
num_tools = len(tools)
for tool in tools:
description_formatted = tool["description"].replace("\n", " ")
tools_intro.append(f"> {tool['name']}: {description_formatted}")
prefix = "Must be one of " if num_tools > 1 else "Must be "
tools_intro_text = "\n".join(tools_intro)
tool_names_text = prefix + ", ".join([tool["name"] for tool in tools])
else:
return None
return template.format(
tool_overviews=tools_intro_text, tool_names=tool_names_text
).strip()
def form_tool_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_FOLLOWUP,
ignore_hint: bool = False,
) -> str:
# If multi-line query, it likely confuses the model more than helps
if "\n" not in query:
optional_reminder = f"\nAs a reminder, my query was: {query}\n"
else:
optional_reminder = ""
if not ignore_hint and hint_text:
hint_text_spaced = f"\nHint: {hint_text}\n"
else:
hint_text_spaced = ""
return tool_followup_prompt.format(
tool_output=tool_output,
optional_reminder=optional_reminder,
hint=hint_text_spaced,
).strip()
def form_tool_less_followup_text(
tool_output: str,
query: str,
hint_text: str | None,
tool_followup_prompt: str = TOOL_LESS_FOLLOWUP,
) -> str:
hint = f"Hint: {hint_text}" if hint_text else ""
return tool_followup_prompt.format(
context_str=tool_output, user_query=query, hint_text=hint
).strip()

View File

@@ -234,7 +234,7 @@ except ValueError:
CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
)
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 1
CELERY_WORKER_INDEXING_CONCURRENCY_DEFAULT = 3
try:
env_value = os.environ.get("CELERY_WORKER_INDEXING_CONCURRENCY")
if not env_value:
@@ -308,6 +308,22 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)
# Due to breakages in the confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.
# The current state of affairs:
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
# no API retrieves the user's timezone
# All data is returned in UTC, so we can't derive the user's timezone from that
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector as some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -422,6 +438,9 @@ LOG_ALL_MODEL_INTERACTIONS = (
LOG_DANSWER_MODEL_INTERACTIONS = (
os.environ.get("LOG_DANSWER_MODEL_INTERACTIONS", "").lower() == "true"
)
LOG_INDIVIDUAL_MODEL_TOKENS = (
os.environ.get("LOG_INDIVIDUAL_MODEL_TOKENS", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
@@ -490,10 +509,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
@@ -507,3 +522,6 @@ API_KEY_HASH_ROUNDS = (
POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

View File

@@ -1,9 +1,9 @@
import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -17,9 +17,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
@@ -27,8 +24,6 @@ DOC_TIME_DECAY = float(
)
BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
# Currently only applies to search flow not chat

View File

@@ -31,6 +31,8 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
DEFAULT_PERSONA_ID = 0
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -60,7 +62,6 @@ KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
@@ -260,6 +261,32 @@ class DanswerCeleryPriority(int, Enum):
LOWEST = auto()
class DanswerCeleryTask:
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
CHECK_FOR_PRUNING = "check_for_pruning"
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
)
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
"update_external_document_permissions_task"
)
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
"connector_external_group_sync_generator_task"
)
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3

View File

@@ -4,11 +4,8 @@ import os
# Danswer Slack Bot Configs
#####
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
# How much of the available input context can be used for thread context
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
@@ -47,17 +44,6 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Add a second LLM call post Answer to verify if the Answer is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
# Set/unset by "Hide Non Answers"
ENABLE_DANSWERBOT_REFLEXION = (
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
)
# Currently not support chain of thought, probably will add back later
DANSWER_BOT_DISABLE_COT = True
# if set, will default DanswerBot to use quotes and reference documents
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
# Maximum Questions Per Minute, Default Uncapped
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None

View File

@@ -70,7 +70,9 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
)
# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible

View File

@@ -11,11 +11,16 @@ Connectors come in 3 different flows:
- Load Connector:
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
documents via a connector's API or loads the documents from some sort of a dump file.
- Poll connector:
- Poll Connector:
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
changes and additions since the last round of polling. This connector helps keep the document index up to date
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
documents.
- Slim Connector:
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
- This is used by our pruning job which removes old documents from the index.
- The optional start and end datetimes can be ignored.
- Event Based connectors:
- Connectors that listen to events and update documents accordingly.
- Currently not used by the background job, this exists for future design purposes.
@@ -26,8 +31,14 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
For implementing a Slim Connector, refer to the comments in this PR:
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
#### Implementing the new Connector
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of

View File

@@ -5,9 +5,9 @@ from io import BytesIO
from typing import Any
from typing import Optional
import boto3
from botocore.client import Config
from mypy_boto3_s3 import S3Client
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from mypy_boto3_s3 import S3Client # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import BlobType

View File

@@ -1,15 +1,17 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from urllib.parse import quote
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.confluence.onyx_confluence import build_confluence_client
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence
from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_client
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
@@ -51,6 +53,8 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
"restrictions.read.restrictions.group",
]
_SLIM_DOC_BATCH_SIZE = 5000
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -67,10 +71,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.confluence_client: OnyxConfluence | None = None
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -97,39 +102,46 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.cql_label_filter = ""
if labels_to_skip:
labels_to_skip = list(set(labels_to_skip))
comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
comma_separated_labels = ",".join(
f"'{quote(label)}'" for label in labels_to_skip
)
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = build_confluence_client(
credentials_json=credentials,
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
return None
def _get_comment_string_for_page_id(self, page_id: str) -> str:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comments in self.confluence_client.paginated_cql_page_retrieval(
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
):
for comment in comments:
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
)
comment_string += "\nComment:\n"
comment_string += extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=comment,
fetched_titles=set(),
)
return comment_string
@@ -141,9 +153,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
"""
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
@@ -153,16 +162,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
self.confluence_client, confluence_object
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
self.confluence_client, confluence_object
confluence_client=self.confluence_client, attachment=confluence_object
)
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
@@ -193,44 +205,41 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page_batch in self.confluence_client.paginated_cql_page_retrieval(
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
for page in page_batch:
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
logger.debug(f"_fetch_document_batches: {page['id']}")
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well?
for attachments in self.confluence_client.paginated_cql_page_retrieval(
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
for attachment in attachments:
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -240,10 +249,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
@@ -255,52 +264,52 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
for pages in self.confluence_client.cql_paginate_all_expansions(
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
for page in pages:
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
perm_sync_data = {
"restrictions": page.get("restrictions", {}),
"space_key": page.get("space", {}).get("key"),
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
page["_links"]["webui"],
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
for attachments in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
expand=restrictions_expand,
):
for attachment in attachments:
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base,
attachment["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=perm_sync_data,
)
)
yield doc_metadata_list
doc_metadata_list = []
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
yield doc_metadata_list

View File

@@ -20,6 +20,10 @@ F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
class ConfluenceRateLimitError(Exception):
pass
@@ -80,7 +84,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
@@ -95,6 +99,10 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
@@ -112,7 +120,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
_DEFAULT_PAGINATION_LIMIT = 100
_DEFAULT_PAGINATION_LIMIT = 1000
class OnyxConfluence(Confluence):
@@ -126,6 +134,32 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -141,7 +175,7 @@ class OnyxConfluence(Confluence):
def _paginate_url(
self, url_suffix: str, limit: int | None = None
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
"""
@@ -153,46 +187,43 @@ class OnyxConfluence(Confluence):
while url_suffix:
try:
logger.debug(f"Making confluence call to {url_suffix}")
next_response = self.get(url_suffix)
except Exception as e:
logger.exception("Error in danswer_cql: \n")
raise e
yield next_response.get("results", [])
logger.warning(f"Error in confluence call to {url_suffix}")
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(f"Error in confluence call to {url_suffix}")
raise e
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
# yield the results individually
yield from next_response.get("results", [])
url_suffix = next_response.get("_links", {}).get("next")
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
return self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
group_name = quote(group_name)
return self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def paginated_cql_user_retrieval(
def paginated_cql_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
The content/search endpoint can be used to fetch pages, attachments, and comments.
"""
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
f"rest/api/search/user?cql={cql}{expand_string}", limit
)
def paginated_cql_page_retrieval(
self,
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
expand_string = f"&expand={expand}" if expand else ""
return self._paginate_url(
yield from self._paginate_url(
f"rest/api/content/search?cql={cql}{expand_string}", limit
)
@@ -201,7 +232,7 @@ class OnyxConfluence(Confluence):
cql: str,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[list[dict[str, Any]]]:
) -> Iterator[dict[str, Any]]:
"""
This function will paginate through the top level query first, then
paginate through all of the expansions.
@@ -221,6 +252,120 @@ class OnyxConfluence(Confluence):
for item in data:
_traverse_and_update(item)
for results in self.paginated_cql_page_retrieval(cql, expand, limit):
_traverse_and_update(results)
yield results
for confluence_object in self.paginated_cql_retrieval(cql, expand, limit):
_traverse_and_update(confluence_object)
yield confluence_object
def paginated_cql_user_retrieval(
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
url = f"rest/api/user/memberof?{user_query}"
yield from self._paginate_url(url, limit)
def paginated_groups_retrieval(
self,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch groups.
"""
yield from self._paginate_url("rest/api/group", limit)
def paginated_group_members_retrieval(
self,
group_name: str,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
This is not an SQL like query.
It's a confluence specific endpoint that can be used to fetch the members of a group.
THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name.
E.g. neither "test/group" nor "test%2Fgroup" works for confluence.
"""
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfornately, all data is returned in UTC regardless of the user's time zone
# even tho CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
)

View File

@@ -2,6 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from urllib.parse import quote
import bs4
@@ -31,7 +32,11 @@ def get_user_email_from_username__server(
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
email = None
# For now, we'll just return a string that indicates failure
# We may want to revert to returning None in the future
# email = None
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
logger.warning(f"failed to get confluence email for {user_name}")
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
@@ -71,7 +76,9 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence, confluence_object: dict[str, Any]
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -79,7 +86,7 @@ def extract_text_from_confluence_html(
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formated Confluence page
"""
@@ -101,38 +108,72 @@ def extract_text_from_confluence_html(
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ri:page"):
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_title = html_page_reference.attrs["ri:content-title"]
if not page_title:
continue
page_query = f"type=page and title='{page_title}'"
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page_batch in confluence_client.paginated_cql_page_retrieval(
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page_batch[0]
page_contents = page
break
except Exception:
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}"
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client, page_contents
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
@@ -232,20 +273,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def build_confluence_client(
credentials_json: dict[str, Any], is_cloud: bool, wiki_base: str
) -> OnyxConfluence:
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials_json["confluence_username"] if is_cloud else None,
password=credentials_json["confluence_access_token"] if is_cloud else None,
token=credentials_json["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)

Some files were not shown because too many files have changed in this diff Show More