Compare commits

...

204 Commits

Author SHA1 Message Date
pablodanswer
8a4e47781b remove history sidebar on mouse exiting window (#2173) 2024-08-19 23:15:54 +00:00
Chris Weaver
af647959f6 Performance Improvements (#2162) 2024-08-19 11:07:00 -07:00
pablodanswer
ea53977617 prevent empty doc link click (#2170) 2024-08-19 18:03:36 +00:00
Weves
c44c22a009 Fix model server 2024-08-19 07:23:24 -07:00
Yuhong Sun
5ab4d94d94 Logging Level Update (#2165) 2024-08-18 21:53:40 -07:00
Yuhong Sun
119aefba88 Add log files to containers (#2164) 2024-08-18 19:18:28 -07:00
pablodanswer
12fccfeffd Add stop generating functionality (#2100)
* functional types + sidebar

* remove commits

* remove logs

* functional rework of temporary user/assistant ID

* robustify switching

* remove logs

* typing

* robustify frontend handling

* cleaner loop + data persistence

* migrate to streaming response

* formatting

* add new loading state to prevent collisions

* add `ChatState` for more robust handling

* remove logs

* robustify typing

* unnecessary list removed

* robustify

* remove log

* remove false comment

* slightly more robust chat state

* update utility + copy

* improve clarity + new SSE handling utility function

* remove comments

* clearer

* add back stack trace detail

* cleaner messages

* clean final message handling

* tiny formatting (remove newline)

* add synchronous wrapper to avoid hampering main event loop

* update typing

* include logs

* slightly more specific logs

* add `critical` error just in case
2024-08-18 22:15:55 +00:00
Yuhong Sun
8a7bc4e411 Log Level Default (#2163) 2024-08-18 14:35:32 -07:00
rkuo-danswer
492797c9f3 Feature/indexing errors (#2148)
* backend changes to handle partial completion of index attempts

* typo fix

* Display partial success in UI

* make log timing more readable by limiting printed precision to milliseconds

* forgot alembic

* initial cut at "completed with errors" indexing

* remove and reorganize unused imports

* show view errors while indexing is in progress

* code review fixes
2024-08-18 19:14:32 +00:00
Yuhong Sun
739058aacc Logging updates (#2159) 2024-08-17 22:05:09 -07:00
Chris Weaver
17570038bb Add PG query logging (#2156) 2024-08-16 21:53:54 -07:00
Yuhong Sun
c0edfb50df k 2024-08-16 21:43:14 -07:00
pablodanswer
22573aba2a Improve Search (#2105) 2024-08-16 21:29:15 -07:00
Chris Weaver
efae24acd0 improve model seeding (#2155) 2024-08-17 01:30:13 +00:00
pablodanswer
f8e0e6f015 Extremely robustified Index Attempt migration (#2151)
* account for connector_id edge case

* robustified
2024-08-17 01:12:18 +00:00
pablodanswer
3cbc341b60 Enable persistence / removal of assistant icons + remove accidental regression (#2153)
* enable persistence / removal of assistant icons + remove accidental regression

* simpler env seeding for web building
2024-08-17 01:11:04 +00:00
pablodanswer
46c7089328 Enable seeding of analytics via file path (#2146)
* enable seeding of analytics via file path

* remove log
2024-08-16 03:14:56 +00:00
pablodanswer
3ffbe659e3 add handling for poorly formatting model names (#2143) 2024-08-15 22:01:57 +00:00
pablodanswer
33fed955d9 Add verbose error messages + robustify assistant switching (#2144)
* add verbose error messages + robustify assistant switching and chat sessions

* fix typing

* cleaner errors + add stack trace
2024-08-15 21:05:04 +00:00
rkuo-danswer
9fa4280f96 add configurable support for memory tracing during indexing (#2140) 2024-08-15 20:40:17 +00:00
Yuhong Sun
4d194bc86a Cohere No Large Chunks (#2145) 2024-08-15 10:18:54 -07:00
Weves
0853d1a8f1 Update force deletion script 2024-08-14 23:29:26 -07:00
Weves
f6547a64a0 More logging for SAML endpoints 2024-08-14 23:25:42 -07:00
hagen-danswer
61b5bd569b Reworked chunking to support mega chunks (#2032) 2024-08-14 22:18:53 -07:00
pablodanswer
680388537b UX clarity + minor new features (#2136) 2024-08-14 15:23:36 -07:00
pablodanswer
d9bcacfae7 validate messages (#2139) 2024-08-14 22:06:48 +00:00
hagen-danswer
2ab192933b Added import statement to fix typescript error (#2138) 2024-08-14 20:10:08 +00:00
Yuhong Sun
1c10f54294 GPU Model Server (#2135) 2024-08-14 11:04:28 -07:00
josvdw
0530f4283e updating readme for widget (#2132)
Co-authored-by: Jos Van der westhuizen <jos@danser.ai>
2024-08-14 16:55:59 +00:00
pablodanswer
3540aa579b Add ux improvements (#2130)
* add ux improvements

* add danswer version display

* show version properly

* improve copy + add web version to settings context

* update copy + danswer version
2024-08-14 16:43:52 +00:00
josvdw
54732a83c9 stopgap: clarify text on standard answer page for improved UX (#2122)
* stopgap: clarify text on standard answer page for improved UX

* replce apostrophe

* using tailwind:

---------

Co-authored-by: Jos Van der westhuizen <jos@danser.ai>
2024-08-14 01:28:49 +00:00
pablodanswer
5e6365c449 Minor update to clarify user adding (#2126)
* minor update to clarify user adding

* Update page.tsx

* run pretty
2024-08-13 21:09:51 +00:00
rkuo-danswer
20369fc451 Refactor/default indexing embedder (#2073)
* refactor embedding model instantiation

* remove unused UNCERTAINTY_PAT constant

* typo fixes

* fix mypy typing issues

* more typing fixes

* log attempt.id on dispatch

* unnecessary check removed after fixing type
2024-08-13 21:01:34 +00:00
rkuo-danswer
f15d6d2b59 allow admin role api keys (#2124)
* allow admin role api keys

* bump to rerun deployment

* types needs explicit export now for APIKey

* remove api_key.role, use User.role instead

* fix formatting

* formatting

* formatting

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-08-13 21:00:57 +00:00
pablodanswer
5dda047999 Always show search filters (#2128) 2024-08-13 13:36:46 -07:00
pablodanswer
ffd9b0180b Fix overflow for quotes in search section (#2123)
* fix overflow for quotes in search section

* proper overflow check
2024-08-13 20:32:11 +00:00
Yuhong Sun
5ad54fec87 Inference to handle no link docs (#2129) 2024-08-13 12:40:11 -07:00
hagen-danswer
d636181aa5 Added catch for empty link (#2037) 2024-08-12 20:08:56 -07:00
pablodanswer
e12ed7750a Add scrollbar to search / chat (#2121)
* add scrollbar to search / chat

* show overflow for lists
2024-08-13 03:07:37 +00:00
hagen-danswer
bbb8c5ff0b Speed up docker launch (#2099)
* use move instead of copy

* added logging

* fix overwrites

* tested throughly

* fixes

* clearer commenting
2024-08-13 00:45:05 +00:00
pablodanswer
83e945ba57 add cleaner / consolidate no docs found message (#2119) 2024-08-12 16:04:59 -07:00
rkuo-danswer
26df869b91 Feature/harden memory limits (#2118)
* log warning in indexer when size exceeds INDEXING_SIZE_WARNING_THRESHOLD

* add configurable attachment size limit for confluence

* specify "attachments"
2024-08-12 15:12:34 -07:00
Weves
1a4df1d65e Remove unnecessary LLM settings 2024-08-12 11:33:49 -07:00
Chris Weaver
0a165aae0b Slack improvements (#2113) 2024-08-11 21:27:37 -07:00
rkuo-danswer
e517f47a89 add send-message-simple-with-history endpoint to avoid… (#2101)
* add send-message-simple-with-history endpoint to support ramp. avoids bad json output in models and allows client to pass history in instead of maintaining it in our own session

* slightly better error checking

* addressing code review

* reject on any empty message

* update test naming
2024-08-12 03:33:52 +00:00
Nathan Schwerdfeger
c7e5b11c63 EE Connector Deletion Bugfix + Refactor (#2042)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-08-11 20:33:07 -07:00
Yuhong Sun
79523f2e0a Warm up reranker (#2111) 2024-08-11 15:20:51 -07:00
pablodanswer
7fae66b766 provider type default to none (#2110) 2024-08-11 14:51:12 -07:00
Yuhong Sun
386b229ed3 Cohere Rerank (#2109) 2024-08-11 14:22:42 -07:00
Yuhong Sun
ce666f3320 Propagate Embedding Enum (#2108) 2024-08-11 12:17:54 -07:00
Yuhong Sun
d60fb15ad3 Allowing users to set Search Settings (#2106) 2024-08-10 20:48:58 -07:00
pablodanswer
7358ece008 enable assistant editing 2024-08-10 14:38:34 -07:00
josvdw
9c5d33e198 open chatdocument links in a new tab instead of overriding danswer (#2090)
Co-authored-by: Jos Van der westhuizen <jos@danser.ai>
2024-08-10 21:37:59 +00:00
pablodanswer
7d5cfd2fa3 Add user specific model defaults (#2043) 2024-08-10 14:37:33 -07:00
Yuhong Sun
a4caf66a35 User Notification Backend (#2104) 2024-08-10 11:39:21 -07:00
pablodanswer
0a8d44b44c quote processing for lengthy intros (#2103) 2024-08-10 11:09:45 -07:00
pablodanswer
cc8a6da8e3 improve llm-generated citations (account for edge case) (#2096)
* improve llm-generated citations (account for edge case)

* additional test case
2024-08-10 02:06:39 +00:00
pablodanswer
54d4526b73 (Minor) Add cleaner search, feedback model, and connector view (#2098)
* add cleaner search, feedback model, and connector view

* Update ChatPage.tsx
2024-08-10 01:54:31 +00:00
Yuhong Sun
c8ead6a0dc Need Reindexing Flag Setup (#2102) 2024-08-09 17:44:57 -07:00
pablodanswer
7bfa99766d Add support for google slides (#2083)
* add support for google slides

* remove log + account for dead code

* squash
2024-08-09 17:12:51 +00:00
hagen-danswer
b230082891 Openai encoding temp hotfix (#2094) 2024-08-09 08:17:31 -07:00
Yuhong Sun
8cd1eda8b1 Rework Rerankers (#2093) 2024-08-08 21:33:49 -07:00
Yuhong Sun
7dcc42aa95 Intent Model Update (#2069) 2024-08-08 20:45:53 -07:00
pablodanswer
e59d1a0294 fix edge case with simpler code block + python formatting (#2092) 2024-08-08 20:44:32 -07:00
pablodanswer
384e61f4b0 add new gpt-4o model 2024-08-08 16:32:57 -07:00
pablodanswer
f28b930475 Image -> img (#2087) 2024-08-08 21:46:42 +00:00
pablodanswer
1d989f5343 Fix model override for persisting default assistant (#2081)
* fix model override for persisting default assistant

* run pretty

* don't modify

* Update ChatPage.tsx
2024-08-08 21:22:19 +00:00
pablodanswer
c1e3a1b3e7 Select proper assistant override (#2068)
* encode images properly

* proper assistant default model updates

* remove now unneeded image encoding update

* update naming of persona llm option gathering
2024-08-08 21:02:11 +00:00
rkuo-danswer
be9ed319d5 add unit test for quotes (#2085)
* add unit test for quotes

* test answer and quotes together
2024-08-08 18:20:07 +00:00
pablodanswer
c630fcffee Improve code block formatting (#2084)
* initial update to styling

* fix chat input bar padding

* improve color choices
2024-08-08 17:12:35 +00:00
josvdw
f411b9cb55 quality of life improvements for the launch.json template (#2082)
Co-authored-by: Jos Van der westhuizen <jos@danser.ai>
2024-08-08 06:39:30 +00:00
Richard Kuo (Danswer)
bdaaebe955 use re.search instead of re.match (which searches from start of string only) 2024-08-07 20:55:18 -07:00
pablodanswer
9eb48ca2c3 account for empty links + fix quote processing 2024-08-07 20:55:18 -07:00
rkuo-danswer
509fa3a994 add postgres configuration (#2076) 2024-08-08 00:13:59 +00:00
pablodanswer
5097c7f284 Handle saved search docs in eval flow (#2075) 2024-08-07 16:18:34 -07:00
pablodanswer
c4e1c62c00 Admin UX updates (#2057) 2024-08-07 14:55:16 -07:00
pablodanswer
eab82782ca Add proper delay for assistant switching (#2070)
* add proper delay for assistant switching

* persist input if possible
2024-08-07 14:46:15 -07:00
pablodanswer
53d976234a proper new chat button redirects (#2074) 2024-08-07 14:44:42 -07:00
pablodanswer
44d8e34b5a Improve seeding (includes all enterprise features) (#2065) 2024-08-07 10:44:33 -07:00
pablodanswer
d2e16a599d Improve shared chat page (#2066)
* improve look of shared chat page

* remove log

* cleaner display

* add initializing loader to shared chat page

* updated danswer loaders (for prism)

* remove default share
2024-08-07 16:13:55 +00:00
pablodanswer
291e6c4198 somewhat clearer API errors (#2064) 2024-08-07 03:04:26 +00:00
Chris Weaver
bb7e1d6e55 Add integration tests for document set syncing (#1904) 2024-08-06 18:00:19 -07:00
rkuo-danswer
fcc4c30ead don't skip the start of the json answer value (#2067) 2024-08-06 23:59:13 +00:00
pablodanswer
f20984ea1d Don't persist error perennially (#2061)
* don't persist error perennially

* proper functionality

* remove logs

* remove another log

* add comments for clarity + reverse conditional

* add comment back

* remove comment
2024-08-06 23:09:25 +00:00
pablodanswer
e0f0cfd92e Ensure relevance functions for selected docs (#2063)
* ensure relevance functions for selected docs

* remove logs

* remove log
2024-08-06 21:06:44 +00:00
pablodanswer
57aec7d02a doc sidebar width fix 2024-08-06 13:48:47 -07:00
pablodanswer
6350219143 Add proper default temperature + overrides (#2059)
* add proper default temperature + overrides

* remove unclear commment

* ammend defaults + include internet serach
2024-08-06 19:57:14 +00:00
pablodanswer
3bc2cf9946 update tool display bubbles to have cursor-dfeault 2024-08-06 12:49:42 -07:00
pablodanswer
7f7452dc98 Whitelabelling consistency (#2058)
* add white labelling to admin sidebar

* even more consistency
2024-08-06 19:45:38 +00:00
pablodanswer
dc2a50034d Clean chat banner (#2056)
* fully functional

* formatting

* ensure consistency with large logos

* ensure mobile support
2024-08-06 19:44:14 +00:00
pablodanswer
ab564a9ec8 Add cleaner loading / streaming for image loading (#2055)
* add image loading

* clean

* add loading skeleton

* clean up

* clearer comments
2024-08-06 19:28:48 +00:00
rkuo-danswer
cc3856ef6d enforce index attempt deduping on secondary indexing. (#2054)
* enforce index attempt deduping on secondary indexing.

* black fix

* typo fixes

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-08-06 17:45:16 +00:00
Yuhong Sun
a8a4ad9546 Chunk Filter Metadata Format (#2053) 2024-08-05 15:12:36 -07:00
pablodanswer
5bfdecacad fix assistant drag transform effect (#2052) 2024-08-05 14:53:38 -07:00
pablodanswer
0bde66a888 remove "quotes" section (#2049) 2024-08-05 18:51:43 +00:00
pablodanswer
5825d01d53 Better assistant interactions + UI (#2029)
* add assistnat re-ordering, selections, etc.

* squash

* remove unnecessary comment

* squash

* adapt dragging for all IDs + smoother animation + consistency

* fix minor typing issue

* fix minor typing issue

* remove logs
2024-08-05 18:22:57 +00:00
pablodanswer
cd22cca4e8 remove non-EE public connector options 2024-08-05 11:14:20 -07:00
pablodanswer
a3ea217f40 ensure consistency of answers + update llm relevance prompting (#2045) 2024-08-05 08:27:15 -07:00
pablodanswer
66e4dded91 Add properly random icons to assistant creation page (#2044) 2024-08-04 23:30:17 -07:00
pablodanswer
6d67d472cd Add answers to search (#2020) 2024-08-04 23:02:55 -07:00
Weves
76b7792e69 Harden embedding calls 2024-08-04 15:11:45 -07:00
Chris Weaver
9d7100a287 Fix secondary index attempts showing up as the primary index status + scheduling while in-progress (#2039) 2024-08-04 13:29:44 -07:00
pablodanswer
876feecd6f Fix code pasting formatting (#2033)
* fix pasting formatting

* add back small comments
2024-08-04 09:56:48 -07:00
pablodanswer
0261d689dc Various Admin Page + User Flow Improvements (#1987) 2024-08-03 18:09:46 -07:00
pablodanswer
aa4a00cbc2 fix minor html error (#2034) 2024-08-03 12:40:07 -07:00
Nathan Schwerdfeger
52c505c210 Remove partially implemented reply cancellation (#2031)
* fix: remove partially implemented response cancellation

* feat: notify user when unsupported chat cancellation is requested

* fix: correct ChatInputBar streaming detection logic
2024-08-03 18:12:04 +00:00
pablodanswer
ed455394fc detect foreign key composition sessions (#2024) 2024-08-02 17:26:57 +00:00
hagen-danswer
57cc53ab94 Added content tags to zendesk connector (#2017) 2024-08-02 10:09:53 -07:00
rkuo-danswer
6a61331cba Feature/log despam (#2022)
* move a lot of log spam to debug level. Consolidate some info level logging

* reformat more indexing logging
2024-08-02 15:28:53 +00:00
Weves
51731ad0dd Fix issue where large docs/batches break openai embedding 2024-08-02 01:07:09 -07:00
rkuo-danswer
f280586e68 pass function to Process correctly instead of running it inline (#2018)
* pass function to Process correctly instead of running it inline

* mypy fixes and pass back return result (even tho we don't use it right now)
2024-08-02 00:06:35 +00:00
hagen-danswer
e31d6be4ce Switched build to use a larger runner (#2019) 2024-08-01 14:29:45 -07:00
hagen-danswer
e6a92aa936 support confluence single page only indexing (#2008)
* added index recursively checkbox

* mypy fixes

* added migration to not break existing connectors
2024-08-01 20:32:46 +00:00
pablodanswer
a54ea9f9fa Fix cartesian issue with index attempts (#2015) 2024-08-01 10:25:25 -07:00
Yuhong Sun
73a92c046d Fix chunker (#2014) 2024-08-01 10:18:02 -07:00
pablodanswer
459bd46846 Add Prompt library (#1990) 2024-08-01 08:40:35 -07:00
Chris Weaver
445f7e70ba Fix image generation (#2009) 2024-08-01 00:27:02 -07:00
Yuhong Sun
ca893f9918 Rerank Handle Null (#2010) 2024-07-31 22:59:02 -07:00
hagen-danswer
1be1959d80 Changed default local model to nomic (#1943) 2024-07-31 18:54:02 -07:00
Chris Weaver
1654378850 Fix user dropdown font (#2007) 2024-08-01 00:29:14 +00:00
Chris Weaver
d6d391d244 Fix not_applicable (#2003) 2024-07-31 21:30:07 +00:00
rkuo-danswer
7c283b090d Feature/postgres connection names (#1998)
* avoid reindexing secondary indexes after they succeed

* use postgres application names to facilitate connection debugging

* centralize all postgres application_name constants in the constants file

* missed a couple of files

* mypy fixes

* update dev background script
2024-07-31 20:36:30 +00:00
pablodanswer
40226678af Add proper default values for assistant editing / creation (#2001) 2024-07-31 13:34:42 -07:00
rkuo-danswer
288e6fa606 Bugfix/pg connections (#2002)
* increase max_connections to 150 in all docker files

* lower celery worker concurrency to 6
2024-07-31 19:49:20 +00:00
hagen-danswer
5307d38472 Fixed tokenizer logic (#1986) 2024-07-31 09:59:45 -07:00
Yuhong Sun
d619602a6f Skip shortcut docs (#1999) 2024-07-31 09:51:01 -07:00
Yuhong Sun
348a2176f0 Fix Dropped Documents (#1997) 2024-07-31 09:33:36 -07:00
pablodanswer
89b6da36a6 process files with null title (#1989) 2024-07-31 08:18:50 -07:00
Yuhong Sun
036d5c737e No Null Embeddings (#1982) 2024-07-30 19:54:49 -07:00
pablodanswer
60a87d9472 Add back modals on chat page (#1983) 2024-07-30 17:42:59 -07:00
pablodanswer
eb9bb56829 Add initial mobile support (#1962) 2024-07-30 17:13:50 -07:00
hagen-danswer
d151082871 Moved warmup_encoders into scope (#1978) 2024-07-30 16:37:32 +00:00
pablodanswer
e4b1f5b963 fix index attempt migration where no credential ID 2024-07-30 08:57:57 -07:00
hagen-danswer
3938a053aa Rework tokenizer (#1957) 2024-07-29 23:01:49 -07:00
pablodanswer
7932e764d6 Make chat page layout cleaner + fix updating assistant images (#1973)
* ux updates for clarity
- [x] 'folders' -> 'chat folders'
- [x] sidebar to bottom left and smaller
- [x] Sidebar -> smaller logo
- [x] Align things properly
- [x] Expliti Pin: immediate + "Pin / Unpin"
- [x] Logo size smaller
- [x] Align things properly
- [x] Optionally fix gradient in sidebar
- [x] Upload logo to existing assistants

* remove unneeded logs

* run pretty

* actually run pretty!

* fix web file type

* fix very minor typo

* clean type for buildPersonaAPIBody

* fix span formatting

* HUGE ui change
2024-07-30 03:44:35 +00:00
Chris Weaver
fb6695a983 Fix flow where oidc_expiry is different from token expiry (#1974) 2024-07-30 03:17:08 +00:00
rkuo-danswer
015f415b71 avoid reindexing secondary indexes after they succeed (#1971) 2024-07-30 03:12:58 +00:00
rkuo-danswer
96b582070b authorized users and groups only have read access (#1960)
* authorized users and groups only have read access

* slightly better variable naming
2024-07-29 19:53:42 +00:00
rkuo-danswer
4a0a927a64 fix removed parameter in MediaWikiConnector (#1970) 2024-07-29 18:47:30 +00:00
hagen-danswer
ea9a9cb553 Fix typing for previous message 2024-07-29 10:01:38 -07:00
pablodanswer
38af12ab97 remove unnecessary index drop (#1968) 2024-07-29 09:51:53 -07:00
hagen-danswer
1b3154188d Fixed default indexing frequency (#1965)
* Fixed default indexing frequency

* fixed more defaults
2024-07-29 08:14:49 -07:00
Weves
1f321826ad Bigger images 2024-07-28 23:47:06 -07:00
Weves
cbfbe4e5d8 Fix image generation follow up q 2024-07-28 23:47:06 -07:00
pablodanswer
3aa0e0124b Add new admin page (#1947)
* add admin page

* credential + typing fix

* rebase fix

* on add, cleaner buttons

* functional G + Ddrive

* organized auth sections

* update types and remove logs

* ccs -> connectors

* validated formik

* update styling + connector-handling logic

* udpate colors

* separate out hooks + util functions

* update to adhere to rest standards

* remove "todos"

* rebase

* copy + formatting + sidebar

* update statuses + configuration possibilities

* update interfaces to be clearer

* update indexing status page

* formatting

* address backend security + comments

* update font

* fix form routing

* fix hydration error

* add statuses, fix bugs, etc. (squash)

* fix color (squash)

* squash

* add functionality to sidebar

* disblae buttons if deleting

* add color

* minor copy + formatting updates
- on modify credential, close
- update copy for deletion of connectors

* fix build error

* copy

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-28 20:57:43 -07:00
Yuhong Sun
f2f60c9cc0 Fix EE Import backoff Logic (#1959) 2024-07-27 11:06:11 -07:00
Emerson Gomes
6c32821ad4 Allow removal of max_output_tokens by setting GEN_AI_MAX_OUTPUT_TOKENS=0 (#1958)
Co-authored-by: Emerson Gomes <emerson.gomes@thalesgroup.com>
2024-07-27 09:07:29 -07:00
Weves
d839595330 Add query override 2024-07-26 17:40:21 -07:00
Yuhong Sun
e422f96dff Pull Request Template (#1956) 2024-07-26 17:34:05 -07:00
Weves
d28f460330 Fix black 2024-07-26 16:43:15 -07:00
Eugene Astroner
8e441d975d Issue fix 2024-07-26 16:40:31 -07:00
pablodanswer
5c78af1f07 Deduplicate model names (#1950) 2024-07-26 16:30:49 -07:00
rkuo-danswer
e325e063ed Bugfix/persona access (#1951)
* also allow access to a persona if the user is in the list of authorized users or groups

* add comment on potential performance improvements

* work around for mypy typing
2024-07-26 22:05:57 +00:00
pablodanswer
c81b45300b Configurable models + updated assistants bar (#1942) 2024-07-26 11:00:49 -07:00
pablodanswer
26a1e963d1 Update personas.yaml (#1948) 2024-07-25 20:35:49 -07:00
pablodanswer
2a983263c7 Small update- Danswer update icons as well (#1945) 2024-07-25 20:31:41 -07:00
Yuhong Sun
2a37c95a5e Types for Migrations (#1944) 2024-07-25 18:18:48 -07:00
pablodanswer
c277a74f82 Add icons to assistants! (#1930) 2024-07-25 18:02:39 -07:00
rkuo-danswer
e4b31cd0d9 allow setting secondary worker count via environment variable. default to primary worker count if unset. (#1941) 2024-07-25 20:25:43 +00:00
hagen-danswer
a40d2a1e2e Change the way we get sqlalchemy session (#1940)
* changed default fast model to gpt-4o-mini

* Changed the way we get the sqlalchemy session
2024-07-25 18:36:14 +00:00
hagen-danswer
c9fb99d719 changed default fast model to gpt-4o-mini (#1939) 2024-07-25 10:50:02 -07:00
hagen-danswer
a4d71e08aa Added check for unknown tool names (#1924)
* answer.py

* Let it continue if broken
2024-07-25 00:19:08 +00:00
rkuo-danswer
546bfbd24b autoscale with pool=thread crashes celery. remove and use concurrency… (#1929)
* autoscale with pool=thread crashes celery. remove and use concurrency instead (to be improved later)

* update dev background script as well
2024-07-25 00:15:27 +00:00
hagen-danswer
27824d6cc6 Fixed login issue (#1920)
* included check for existing emails

* cleaned up logic
2024-07-25 00:03:29 +00:00
Weves
9d5c4ad634 Small fix for non tool calling LLMs 2024-07-24 15:41:43 -07:00
Shukant Pal
9b32003816 Handle SSL error tracebacks in site indexing connector (#1911)
My website (https://shukantpal.com) uses Let's Encrypt certificates, which aren't accepted by the Python urllib certificate verifier for some reason. My website is set up correctly otherwise (https://www.sslshopper.com/ssl-checker.html#hostname=www.shukantpal.com)

This change adds a fix so the correct traceback is shown in Danswer, instead of a generic "unable to connect, check your Internet connection".
2024-07-24 22:36:29 +00:00
pablodanswer
8bc4123ed7 add modern health check banner + expiration tracking (#1730)
---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
2024-07-24 15:34:22 -07:00
pablodanswer
d58aaf7a59 add href 2024-07-24 14:33:56 -07:00
pablodanswer
a0056a1b3c add files (images) (#1926) 2024-07-24 21:26:01 +00:00
pablodanswer
d2584c773a slightly clearer description of model settings in assistants creation tab (#1925) 2024-07-24 21:25:30 +00:00
pablodanswer
807bef8ada Add environment variable for defaulted sidebar toggling (#1923)
* add env variable for defaulted sidebar toggling

* formatting

* update naming
2024-07-24 21:23:37 +00:00
rkuo-danswer
5afddacbb2 order list of new attempts from oldest to newest to prevent connector starvation (#1918) 2024-07-24 21:02:20 +00:00
hagen-danswer
4fb6a88f1e Quick fix (#1919) 2024-07-24 11:56:14 -07:00
rkuo-danswer
7057be6a88 Bugfix/indexing progress (#1916)
* mark in progress should always be committed

* no_commit version of mark_attempt is not needed
2024-07-24 11:39:44 -07:00
Yuhong Sun
91be8e7bfb Skip Null Docs (#1917) 2024-07-24 11:31:33 -07:00
Yuhong Sun
9651ea828b Handling Metadata by Vector and Keyword (#1909) 2024-07-24 11:05:56 -07:00
rkuo-danswer
6ee74bd0d1 fix pointers to various background tasks and scripts (#1914) 2024-07-24 10:12:51 -07:00
pablodanswer
48a0d29a5c Fix empty / reverted embeddings (#1910) 2024-07-23 22:41:31 -07:00
hagen-danswer
6ff8e6c0ea Improve eval pipeline qol (#1908) 2024-07-23 17:16:34 -07:00
Yuhong Sun
2470c68506 Don't rephrase first chat query (#1907) 2024-07-23 16:20:11 -07:00
hagen-danswer
866bc803b1 Implemented LLM disabling for api call (#1905) 2024-07-23 16:12:51 -07:00
pablodanswer
9c6084bd0d Embeddings- Clean up modal + "Important" call out (#1903) 2024-07-22 21:29:22 -07:00
hagen-danswer
a0b46c60c6 Switched eval api target back to oneshotqa (#1902) 2024-07-22 20:55:18 -07:00
pablodanswer
4029233df0 hide incomplete sources for non-admins (#1901) 2024-07-22 13:40:11 -07:00
hagen-danswer
6c88c0156c Added file upload retry logic (#1889) 2024-07-22 13:13:22 -07:00
pablodanswer
33332d08f2 fix citation title (#1900)
* fix citation title

* remove title function
2024-07-22 17:37:04 +00:00
hagen-danswer
17005fb705 switched default pruning behavior and removed some logging (#1898) 2024-07-22 17:36:26 +00:00
hagen-danswer
48a7fe80b1 Committed LLM updates to db (#1899) 2024-07-22 10:30:24 -07:00
pablodanswer
1276732409 Misc bug fixes (#1895) 2024-07-22 10:22:43 -07:00
Weves
f91b92a898 Make is_public default true for LLMProvider 2024-07-21 22:22:37 -07:00
Weves
6222f533be Update force delete script to handle user groups 2024-07-21 22:22:37 -07:00
hagen-danswer
1b49d17239 Added ability to control LLM access based on group (#1870)
* Added ability to control LLM access based on group

* completed relationship deletion

* cleaned up function

* added comments

* fixed frontend strings

* mypy fixes

* added case handling for deletion of user groups

* hidden advanced options now

* removed unnecessary code
2024-07-22 04:31:44 +00:00
Yuhong Sun
2f5f19642e Double Check Max Tokens for Indexing (#1893) 2024-07-21 21:12:39 -07:00
Yuhong Sun
6db4634871 Token Truncation (#1892) 2024-07-21 16:26:32 -07:00
Yuhong Sun
5cfed45cef Handle Empty Titles (#1891) 2024-07-21 14:59:23 -07:00
Weves
581ffde35a Fix jira connector failures for server deployments 2024-07-21 14:44:25 -07:00
pablodanswer
6313e6d91d Remove visit api when unneded (#1885)
* quick fix to test on ec2

* quick cleanup

* modify a name

* address full doc as well

* additional timing info + handling

* clean up

* squash

* Print only
2024-07-21 20:57:24 +00:00
Weves
c09c94bf32 Fix assistant swap 2024-07-21 13:57:36 -07:00
Yuhong Sun
0e8ba111c8 Model Touchups (#1887) 2024-07-21 12:31:00 -07:00
Yuhong Sun
2ba24b1734 Reenable Search Pipeline (#1886) 2024-07-21 10:33:29 -07:00
Yuhong Sun
44820b4909 k 2024-07-21 10:27:57 -07:00
hagen-danswer
eb3e7610fc Added retries and multithreading for cloud embedding (#1879)
* added retries and multithreading for cloud embedding

* refactored a bit

* cleaned up code

* got the errors to bubble up to the ui correctly

* added exceptin printing

* added requirements

* touchups

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-07-20 22:10:18 -07:00
pablodanswer
7fbbb174bb minor fixes (#1882)
- Assistants tab size
- Fixed logo -> absolute
2024-07-20 21:02:57 -07:00
pablodanswer
3854ca11af add newlines for message content 2024-07-20 18:57:29 -07:00
562 changed files with 27406 additions and 21039 deletions

25
.github/pull_request_template.md vendored Normal file
View File

@@ -0,0 +1,25 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Accepted Risk
[Any know risks or failure modes to point out to reviewers]
## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]
## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge

View File

@@ -7,7 +7,8 @@ on:
jobs:
build-and-push:
runs-on: ubuntu-latest
runs-on:
group: amd64-image-builders
steps:
- name: Checkout code

View File

@@ -15,7 +15,7 @@ LOG_LEVEL=debug
# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True
DISABLE_LLM_DOC_RELEVANCE=True
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)

View File

@@ -39,7 +39,8 @@
"--reload",
"--port",
"9000"
]
],
"consoleTitle": "Model Server"
},
{
"name": "API Server",
@@ -58,7 +59,8 @@
"--reload",
"--port",
"8080"
]
],
"consoleTitle": "API Server"
},
{
"name": "Indexing",
@@ -68,11 +70,12 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MINI_CHUNK": "false",
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
"consoleTitle": "Indexing"
},
// Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
{
@@ -90,7 +93,8 @@
},
"args": [
"--no-indexing"
]
],
"consoleTitle": "Background Jobs"
},
// For the listner to access the Slack API,
// DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
@@ -125,5 +129,17 @@
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
],
"compounds": [
{
"name": "Run Danswer",
"configurations": [
"Web Server",
"Model Server",
"API Server",
"Indexing",
"Background Jobs",
]
}
]
}

View File

@@ -68,7 +68,9 @@ RUN apt-get update && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \

View File

@@ -18,14 +18,18 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y
# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
# Download tokenizers, distilbert for the Danswer model
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased', cache_folder='/root/.cache/temp_huggingface/hub/'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
from huggingface_hub import snapshot_download; \
AutoTokenizer.from_pretrained('danswer/intent-model'); \
AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
snapshot_download('danswer/intent-model'); \
snapshot_download('intfloat/e5-base-v2'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"
WORKDIR /app

View File

@@ -17,15 +17,11 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_session",
sa.Column("current_alternate_model", sa.String(), nullable=True),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("chat_session", "current_alternate_model")
# ### end Alembic commands ###

View File

@@ -0,0 +1,26 @@
"""add_indexing_start_to_connector
Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
)
def downgrade() -> None:
op.drop_column("connector", "indexing_start")

View File

@@ -0,0 +1,44 @@
"""notifications
Revision ID: 213fd978c6d8
Revises: 5fc1f54cc252
Create Date: 2024-08-10 11:13:36.070790
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "213fd978c6d8"
down_revision = "5fc1f54cc252"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"notification",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"notif_type",
sa.String(),
nullable=False,
),
sa.Column(
"user_id",
sa.UUID(),
nullable=True,
),
sa.Column("dismissed", sa.Boolean(), nullable=False),
sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
def downgrade() -> None:
op.drop_table("notification")

View File

@@ -79,7 +79,7 @@ def downgrade() -> None:
)
op.create_foreign_key(
"document_retrieval_feedback__chat_message_fk",
"document_retrieval",
"document_retrieval_feedback",
"chat_message",
["chat_message_id"],
["id"],

View File

@@ -160,12 +160,28 @@ def downgrade() -> None:
nullable=False,
),
)
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
# Check if the constraint exists before dropping
conn = op.get_bind()
inspector = sa.inspect(conn)
constraints = inspector.get_foreign_keys("index_attempt")
if any(
constraint["name"] == "fk_index_attempt_credential_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
)
if any(
constraint["name"] == "fk_index_attempt_connector_id"
for constraint in constraints
):
op.drop_constraint(
"fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
)
op.drop_column("index_attempt", "credential_id")
op.drop_column("index_attempt", "connector_id")
op.drop_table("connector_credential_pair")

View File

@@ -0,0 +1,70 @@
"""Add icon_color and icon_shape to Persona
Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562
"""
import random
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None
colorOptions = [
"#FF6FBF",
"#6FB1FF",
"#B76FFF",
"#FFB56F",
"#6FFF8D",
"#FF6F6F",
"#6FFFFF",
]
# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
center_fill = random.choice(center_squares)
remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
random.shuffle(remaining_squares)
for i in range(10 - bin(center_fill).count("1")):
center_fill |= 1 << remaining_squares[i]
return center_fill
def upgrade() -> None:
op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))
persona = table(
"persona",
column("id", sa.Integer),
column("icon_color", sa.String),
column("icon_shape", sa.Integer),
)
conn = op.get_bind()
personas = conn.execute(select(persona.c.id))
for persona_id in personas:
random_color = random.choice(colorOptions)
random_shape = generate_random_shape()
conn.execute(
persona.update()
.where(persona.c.id == persona_id[0])
.values(icon_color=random_color, icon_shape=random_shape)
)
def downgrade() -> None:
op.drop_column("persona", "icon_shape")
op.drop_column("persona", "uploaded_image_id")
op.drop_column("persona", "icon_color")

View File

@@ -18,7 +18,6 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
)
@@ -29,10 +28,8 @@ def upgrade() -> None:
["alternate_assistant_id"],
["id"],
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "alternate_assistant_id")

View File

@@ -0,0 +1,42 @@
"""Rename index_origin to index_recursively
Revision ID: 1d6ad76d1f37
Revises: e1392f05e840
Create Date: 2024-08-01 12:38:54.466081
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1d6ad76d1f37"
down_revision = "e1392f05e840"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_recursively}',
'true'::jsonb
) - 'index_origin'
WHERE connector_specific_config ? 'index_origin'
"""
)
def downgrade() -> None:
op.execute(
"""
UPDATE connector
SET connector_specific_config = jsonb_set(
connector_specific_config,
'{index_origin}',
connector_specific_config->'index_recursively'
) - 'index_recursively'
WHERE connector_specific_config ? 'index_recursively'
"""
)

View File

@@ -0,0 +1,49 @@
"""Add display_model_names to llm_provider
Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels: None = None
depends_on: None = None
default_models_by_provider = {
"openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
"bedrock": [
"meta.llama3-1-70b-instruct-v1:0",
"meta.llama3-1-8b-instruct-v1:0",
"anthropic.claude-3-opus-20240229-v1:0",
"mistral.mistral-large-2402-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
],
"anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
)
connection = op.get_bind()
for provider, models in default_models_by_provider.items():
connection.execute(
sa.text(
"UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
),
{"models": models, "provider": provider},
)
def downgrade() -> None:
op.drop_column("llm_provider", "display_model_names")

View File

@@ -0,0 +1,80 @@
"""Moved status to connector credential pair
Revision ID: 4a951134c801
Revises: 7477a5f5d728
Create Date: 2024-08-10 19:20:34.527559
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4a951134c801"
down_revision = "7477a5f5d728"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column(
"status",
sa.Enum(
"ACTIVE",
"PAUSED",
"DELETING",
name="connectorcredentialpairstatus",
native_enum=False,
),
nullable=True,
),
)
# Update status of connector_credential_pair based on connector's disabled status
op.execute(
"""
UPDATE connector_credential_pair
SET status = CASE
WHEN (
SELECT disabled
FROM connector
WHERE connector.id = connector_credential_pair.connector_id
) = FALSE THEN 'ACTIVE'
ELSE 'PAUSED'
END
"""
)
# Make the status column not nullable after setting values
op.alter_column("connector_credential_pair", "status", nullable=False)
op.drop_column("connector", "disabled")
def downgrade() -> None:
op.add_column(
"connector",
sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True),
)
# Update disabled status of connector based on connector_credential_pair's status
op.execute(
"""
UPDATE connector
SET disabled = CASE
WHEN EXISTS (
SELECT 1
FROM connector_credential_pair
WHERE connector_credential_pair.connector_id = connector.id
AND connector_credential_pair.status = 'ACTIVE'
) THEN FALSE
ELSE TRUE
END
"""
)
# Make the disabled column not nullable after setting values
op.alter_column("connector", "disabled", nullable=False)
op.drop_column("connector_credential_pair", "status")

View File

@@ -0,0 +1,72 @@
"""Add type to credentials
Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new 'source' column to the 'credential' table
op.add_column(
"credential",
sa.Column(
"source",
sa.String(length=100), # Use String instead of Enum
nullable=True, # Initially allow NULL values
),
)
op.add_column(
"credential",
sa.Column(
"name",
sa.String(),
nullable=True,
),
)
# Create a temporary table that maps each credential to a single connector source.
# This is needed because a credential can be associated with multiple connectors,
# but we want to assign a single source to each credential.
# We use DISTINCT ON to ensure we only get one row per credential_id.
op.execute(
"""
CREATE TEMPORARY TABLE temp_connector_credential AS
SELECT DISTINCT ON (cc.credential_id)
cc.credential_id,
c.source AS connector_source
FROM connector_credential_pair cc
JOIN connector c ON cc.connector_id = c.id
"""
)
# Update the 'source' column in the 'credential' table
op.execute(
"""
UPDATE credential cred
SET source = COALESCE(
(SELECT connector_source
FROM temp_connector_credential temp
WHERE cred.id = temp.credential_id),
'NOT_APPLICABLE'
)
"""
)
# If no exception was raised, alter the column
op.alter_column("credential", "source", nullable=True) # TODO modify
# # ### end Alembic commands ###
def downgrade() -> None:
op.drop_column("credential", "source")
op.drop_column("credential", "name")

View File

@@ -0,0 +1,25 @@
"""hybrid-enum
Revision ID: 5fc1f54cc252
Revises: 1d6ad76d1f37
Create Date: 2024-08-06 15:35:40.278485
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5fc1f54cc252"
down_revision = "1d6ad76d1f37"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.drop_column("persona", "search_type")
def downgrade() -> None:
op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
op.alter_column("persona", "search_type", nullable=False)

View File

@@ -0,0 +1,24 @@
"""Added model defaults for users
Revision ID: 7477a5f5d728
Revises: 213fd978c6d8
Create Date: 2024-08-04 19:00:04.512634
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7477a5f5d728"
down_revision = "213fd978c6d8"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "default_model")

View File

@@ -28,5 +28,9 @@ def upgrade() -> None:
def downgrade() -> None:
# This wasn't really required by the code either, no good reason to make it unique again
pass
op.create_unique_constraint(
"connector_credential_pair__name__key", "connector_credential_pair", ["name"]
)
op.alter_column(
"connector_credential_pair", "name", existing_type=sa.String(), nullable=True
)

View File

@@ -0,0 +1,41 @@
"""add_llm_group_permissions_control
Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558
"""
from alembic import op
import sqlalchemy as sa
revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"llm_provider__user_group",
sa.Column("llm_provider_id", sa.Integer(), nullable=False),
sa.Column("user_group_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["llm_provider_id"],
["llm_provider.id"],
),
sa.ForeignKeyConstraint(
["user_group_id"],
["user_group.id"],
),
sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
)
op.add_column(
"llm_provider",
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
)
def downgrade() -> None:
op.drop_table("llm_provider__user_group")
op.drop_column("llm_provider", "is_public")

View File

@@ -0,0 +1,107 @@
"""associate index attempts with ccpair
Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
# Add the new connector_credential_pair_id column
op.add_column(
"index_attempt",
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
)
# Create a foreign key constraint to the connector_credential_pair table
op.create_foreign_key(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
"connector_credential_pair",
["connector_credential_pair_id"],
["id"],
)
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_credential_pair_id = (
SELECT id FROM connector_credential_pair ccp
WHERE
(ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
LIMIT 1
)
WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
"""
)
# For good measure
op.execute(
"""
DELETE FROM index_attempt
WHERE connector_credential_pair_id IS NULL
"""
)
# Make the new connector_credential_pair_id column non-nullable
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
# Drop the old connector_id and credential_id columns
op.drop_column("index_attempt", "connector_id")
op.drop_column("index_attempt", "credential_id")
# Update the index to use connector_credential_pair_id
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_credential_pair_id", "time_created"],
)
def downgrade() -> None:
# Add back the old connector_id and credential_id columns
op.add_column(
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
)
op.add_column(
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
)
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
op.execute(
"""
UPDATE index_attempt ia
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
FROM connector_credential_pair ccp
WHERE ia.connector_credential_pair_id = ccp.id
"""
)
# Make the old connector_id and credential_id columns non-nullable
op.alter_column("index_attempt", "connector_id", nullable=False)
op.alter_column("index_attempt", "credential_id", nullable=False)
# Drop the new connector_credential_pair_id column
op.drop_constraint(
"fk_index_attempt_connector_credential_pair_id",
"index_attempt",
type_="foreignkey",
)
op.drop_column("index_attempt", "connector_credential_pair_id")
op.create_index(
"ix_index_attempt_latest_for_connector_credential_pair",
"index_attempt",
["connector_id", "credential_id", "time_created"],
)

View File

@@ -0,0 +1,26 @@
"""add expiry time
Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
)
def downgrade() -> None:
op.drop_column("user", "oidc_expiry")

View File

@@ -16,7 +16,6 @@ depends_on: None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -29,11 +28,9 @@ def upgrade() -> None:
),
nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"connector_credential_pair",
"last_attempt_status",
@@ -46,4 +43,3 @@ def downgrade() -> None:
),
nullable=False,
)
# ### end Alembic commands ###

View File

@@ -0,0 +1,57 @@
"""Add index_attempt_errors table
Revision ID: c5b692fa265c
Revises: 4a951134c801
Create Date: 2024-08-08 14:06:39.581972
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "c5b692fa265c"
down_revision = "4a951134c801"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column(
"doc_summaries",
postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# ### end Alembic commands ###

View File

@@ -19,6 +19,9 @@ depends_on: None = None
def upgrade() -> None:
op.drop_table("deletion_attempt")
# Remove the DeletionStatus enum
op.execute("DROP TYPE IF EXISTS deletionstatus;")
def downgrade() -> None:
op.create_table(

View File

@@ -0,0 +1,65 @@
"""chosen_assistants changed to jsonb
Revision ID: da4c21c69164
Revises: c5b692fa265c
Create Date: 2024-08-18 19:06:47.291491
"""
import json
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "da4c21c69164"
down_revision = "c5b692fa265c"
branch_labels = None
depends_on = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
)
op.drop_column(
"user",
"chosen_assistants",
)
op.add_column(
"user",
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -136,4 +136,4 @@ def downgrade() -> None:
)
op.drop_column("index_attempt", "embedding_model_id")
op.drop_table("embedding_model")
op.execute("DROP TYPE indexmodelstatus;")
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")

View File

@@ -0,0 +1,58 @@
"""Added input prompts
Revision ID: e1392f05e840
Revises: 08a1eda20fe1
Create Date: 2024-07-13 19:09:22.556224
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e1392f05e840"
down_revision = "08a1eda20fe1"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -5,19 +5,16 @@ from danswer.access.utils import prefix_user
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.models import User
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.variable_functionality import fetch_versioned_implementation
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(user_ids, [], is_public)
@@ -28,14 +25,13 @@ def _get_access_for_documents(
def get_access_for_documents(
document_ids: list[str],
db_session: Session,
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> dict[str, DocumentAccess]:
"""Fetches all access information for the given documents."""
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
"danswer.access.access", "_get_access_for_documents"
)
return versioned_get_access_for_documents_fn(
document_ids, db_session, cc_pair_to_delete
document_ids, db_session
) # type: ignore

View File

@@ -1,21 +1,20 @@
from typing import cast
from danswer.configs.constants import KV_USER_STORE_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import JSON_ro
USER_STORE_KEY = "INVITED_USERS"
def get_invited_users() -> list[str]:
try:
store = get_dynamic_config_store()
return cast(list, store.load(USER_STORE_KEY))
return cast(list, store.load(KV_USER_STORE_KEY))
except ConfigNotFoundError:
return list()
def write_invited_users(emails: list[str]) -> int:
store = get_dynamic_config_store()
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -3,29 +3,27 @@ from typing import Any
from typing import cast
from danswer.auth.schemas import UserRole
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from danswer.dynamic_configs.store import ConfigNotFoundError
from danswer.dynamic_configs.store import DynamicConfigStore
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
def set_no_auth_user_preferences(
store: DynamicConfigStore, preferences: UserPreferences
) -> None:
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
try:
preferences_data = cast(
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
return UserPreferences(**preferences_data)
except ConfigNotFoundError:
return UserPreferences(chosen_assistants=None)
return UserPreferences(chosen_assistants=None, default_model=None)
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:

View File

@@ -1,6 +1,8 @@
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Optional
@@ -50,26 +52,34 @@ from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import (
fetch_versioned_implementation,
)
from danswer.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
if user and user.role == UserRole.ADMIN:
return True
return False
def verify_auth_setting() -> None:
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
raise ValueError(
"User must choose a valid user authentication method: "
"disabled, basic, or google_oauth"
)
logger.info(f"Using Auth Type: {AUTH_TYPE.value}")
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
def get_display_email(email: str | None, space_less: bool = False) -> str:
@@ -92,12 +102,18 @@ def user_needs_to_be_verified() -> bool:
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
def verify_email_in_whitelist(email: str) -> None:
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if (whitelist and email not in whitelist) or not email:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
def verify_email_domain(email: str) -> None:
if VALID_EMAIL_DOMAINS:
if email.count("@") != 1:
@@ -147,7 +163,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> models.UP:
verify_email_in_whitelist(user_create.email)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
@@ -173,7 +189,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
verify_email_in_whitelist(account_email)
verify_email_domain(account_email)
return await super().oauth_callback( # type: ignore
user = await super().oauth_callback( # type: ignore
oauth_name=oauth_name,
access_token=access_token,
account_id=account_id,
@@ -185,10 +201,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
is_verified_by_default=is_verified_by_default,
)
# NOTE: google oauth expires after 1hr. We don't want to force the user to
# re-authenticate that frequently, so for now we'll just ignore this for
# google oauth users
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
return user
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
logger.info(f"User {user.id} has registered.")
logger.notice(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -198,14 +222,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.info(f"User {user.id} has forgot their password. Reset token: {token}")
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
verify_email_domain(user.email)
logger.info(
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
@@ -227,10 +251,12 @@ cookie_transport = CookieTransport(
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
auth_backend = AuthenticationBackend(
name="database",
@@ -327,6 +353,12 @@ async def double_check_user(
detail="Access denied. User is not verified.",
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
return user
@@ -345,4 +377,5 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not an admin.",
)
return user

View File

@@ -5,6 +5,7 @@ from celery import Celery # type: ignore
from sqlalchemy.orm import Session
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.celery.celery_utils import should_sync_doc_set
from danswer.background.connector_deletion import delete_connector_credential_pair
@@ -14,6 +15,7 @@ from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -38,7 +40,9 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
connection_string = build_connection_string(db_api=SYNC_DB_API)
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
)
celery_broker_url = f"sqla+{connection_string}"
celery_backend_url = f"db+{connection_string}"
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
@@ -100,7 +104,7 @@ def cleanup_connector_credential_pair_task(
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
with Session(get_sqlalchemy_engine()) as db_session:
@@ -267,6 +271,26 @@ def check_for_document_sets_sync_task() -> None:
)
@celery_app.task(
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task() -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any document sets are not synced
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
logger.notice(f"Deleting the {cc_pair.name} connector credential pair")
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
),
)
@celery_app.task(
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
@@ -302,6 +326,12 @@ celery_app.conf.beat_schedule = {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
},
}
celery_app.conf.beat_schedule.update(
{

View File

@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.background.task_utils import name_document_set_sync_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -16,10 +16,14 @@ from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentSet
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
@@ -31,22 +35,52 @@ logger = setup_logger()
def get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
) -> TaskQueueState | None:
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
)
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
if not task_state:
def get_deletion_attempt_snapshot(
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
deletion_task = get_deletion_status(connector_id, credential_id, db_session)
if not deletion_task:
return None
return DeletionAttemptSnapshot(
connector_id=connector_id,
credential_id=credential_id,
status=task_state.status,
status=deletion_task.status,
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
if document_set.is_up_to_date:
return False
@@ -58,7 +92,7 @@ def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
logger.info(f"Document set '{document_set.id}' is already syncing. Skipping.")
return False
logger.info(f"Document set {document_set.id} syncing now!")
logger.info(f"Document set {document_set.id} syncing now.")
return True
@@ -80,7 +114,7 @@ def should_prune_cc_pair(
return True
return False
if PREVENT_SIMULTANEOUS_PRUNING:
if not ALLOW_SIMULTANEOUS_PRUNING:
pruning_type_task_name = name_cc_prune_task()
last_pruning_type_task = get_latest_task_by_type(
pruning_type_task_name, db_session
@@ -89,11 +123,9 @@ def should_prune_cc_pair(
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
last_pruning_type_task, db_session
):
logger.info("Another Connector is already pruning. Skipping.")
return False
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
return False
if not last_pruning_task.start_time:

View File

@@ -10,8 +10,6 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
import time
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
@@ -24,10 +22,8 @@ from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.document_set import (
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
)
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
@@ -35,6 +31,10 @@ from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
@@ -78,25 +78,37 @@ def delete_connector_credential_pair_batch(
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
# maps document id to list of document set names
new_doc_sets_for_documents: dict[str, set[str]] = {
document_id_and_document_set_names_tuple[0]: set(
document_id_and_document_set_names_tuple[1]
)
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
db_session=db_session,
document_ids=document_ids_to_update,
)
}
# determine future ACLs for documents in batch
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
# update Vespa
logger.debug(f"Updating documents: {document_ids_to_update}")
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
document_sets=new_doc_sets_for_documents[document_id],
)
for document_id, access in access_for_documents.items()
]
logger.debug(f"Updating documents: {document_ids_to_update}")
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
@@ -108,48 +120,6 @@ def delete_connector_credential_pair_batch(
db_session.commit()
def cleanup_synced_entities(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> None:
"""Updates the document sets associated with the connector / credential pair,
then relies on the document set sync script to kick off Celery jobs which will
sync these updates to Vespa.
Waits until the document sets are synced before returning."""
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
document_sets_ids_to_sync = list(
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
)
db_session.commit()
# wait till all document sets are synced before continuing
while True:
all_synced = True
document_sets = get_document_sets_by_ids(
db_session=db_session, document_set_ids=document_sets_ids_to_sync
)
for document_set in document_sets:
if not document_set.is_up_to_date:
all_synced = False
if all_synced:
break
# wait for 30 seconds before checking again
db_session.commit() # end transaction
logger.info(
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
)
time.sleep(30)
logger.info(
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
)
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
@@ -177,17 +147,33 @@ def delete_connector_credential_pair(
)
num_docs_deleted += len(documents)
# Clean up document sets / access information from Postgres
# and sync these updates to Vespa
# TODO: add user group cleanup with `fetch_versioned_implementation`
cleanup_synced_entities(cc_pair, db_session)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
@@ -199,11 +185,11 @@ def delete_connector_credential_pair(
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.debug("Found no credentials left for connector, deleting connector")
logger.info("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.info(
logger.notice(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)

View File

@@ -41,6 +41,12 @@ def _initializer(
return func(*args, **kwargs)
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> None:
_initializer(func, args, kwargs)
@dataclass
class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
@@ -113,7 +119,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = Process(target=_initializer(func=func, args=args), daemon=True)
process = Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -7,6 +7,9 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from danswer.configs.app_configs import INDEXING_TRACER_INTERVAL
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import GenerateDocumentsOutput
@@ -14,13 +17,14 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.models import InputType
from danswer.db.connector import disable_connector
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
@@ -35,6 +39,8 @@ from danswer.utils.variable_functionality import global_version
logger = setup_logger()
INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_document_generator(
db_session: Session,
@@ -49,19 +55,26 @@ def _get_document_generator(
are the complete list of existing documents of the connector. If the task
of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
"""
task = attempt.connector.input_type
task = attempt.connector_credential_pair.connector.input_type
try:
runnable_connector = instantiate_connector(
attempt.connector.source,
attempt.connector_credential_pair.connector.source,
task,
attempt.connector.connector_specific_config,
attempt.credential,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
disable_connector(attempt.connector.id, db_session)
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
update_connector_credential_pair(
db_session=db_session,
connector_id=attempt.connector_credential_pair.connector.id,
credential_id=attempt.connector_credential_pair.credential.id,
status=ConnectorCredentialPairStatus.PAUSED,
)
raise e
if task == InputType.LOAD_STATE:
@@ -70,7 +83,10 @@ def _get_document_generator(
elif task == InputType.POLL:
assert isinstance(runnable_connector, PollConnector)
if attempt.connector_id is None or attempt.credential_id is None:
if (
attempt.connector_credential_pair.connector_id is None
or attempt.connector_credential_pair.connector_id is None
):
raise ValueError(
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
f"can't fetch time range."
@@ -98,6 +114,7 @@ def _run_indexing(
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
@@ -110,16 +127,12 @@ def _run_indexing(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
db_embedding_model
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
@@ -127,19 +140,37 @@ def _run_indexing(
db_session=db_session,
)
db_connector = index_attempt.connector
db_credential = index_attempt.credential
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
last_successful_index_time = (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
db_connector.indexing_start.timestamp()
if index_attempt.from_beginning and db_connector.indexing_start is not None
else (
0.0
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
db_session=db_session,
)
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = DanswerTracer()
tracer.start()
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
)
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
@@ -166,6 +197,10 @@ def _run_indexing(
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in doc_batch_generator:
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
@@ -173,7 +208,7 @@ def _run_indexing(
# contents still need to be initially pulled.
db_session.refresh(db_connector)
if (
db_connector.disabled
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and db_embedding_model.status != IndexModelStatus.FUTURE
):
# let the `except` block handle this
@@ -184,17 +219,30 @@ def _run_indexing(
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
batch_description = []
for doc in doc_batch:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
for section in doc.sections:
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(
f"Document size: doc='{doc.to_short_descriptor()}' "
f"size={doc_size} "
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
)
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
)
new_docs, total_batch_chunks = indexing_pipeline(
documents=doc_batch,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
),
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
@@ -216,6 +264,17 @@ def _run_indexing(
docs_removed_from_index=0,
)
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
@@ -226,7 +285,7 @@ def _run_indexing(
run_dt=run_end_dt,
)
except Exception as e:
logger.info(
logger.exception(
f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds"
)
# Only mark the attempt as a complete failure if this is the first indexing window.
@@ -238,7 +297,7 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or db_connector.disabled
or db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
@@ -250,17 +309,66 @@ def _run_indexing(
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector.id,
credential_id=index_attempt.credential.id,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
mark_attempt_succeeded(index_attempt, db_session)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
@@ -269,13 +377,6 @@ def _run_indexing(
run_dt=run_end_dt,
)
logger.info(
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
)
logger.info(
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
)
def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
@@ -299,9 +400,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress__no_commit(attempt)
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
db_session.commit()
mark_attempt_in_progress(attempt, db_session)
return attempt
@@ -324,17 +423,19 @@ def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Running indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
f"Indexing starting: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt)
logger.info(
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
f"Indexing finished: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")

View File

@@ -0,0 +1,77 @@
import tracemalloc
from danswer.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class DanswerTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
DanswerTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -16,24 +16,29 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
from danswer.db.index_attempt import get_last_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.swap_index import check_index_swap
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -41,6 +46,7 @@ from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
# If the indexing dies, it's most likely due to resource constraints,
@@ -53,12 +59,14 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
def _should_create_new_indexing(
connector: Connector,
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
model: EmbeddingModel,
secondary_index_building: bool,
db_session: Session,
) -> bool:
connector = cc_pair.connector
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
@@ -66,28 +74,46 @@ def _should_create_new_indexing(
return False
# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE and not last_index:
if connector.id == 0: # Ingestion API
return False
if model.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if connector.id == 0: # Ingestion API
return False
return True
# If the connector is disabled, don't index
# NOTE: during an embedding model switch over, we ignore this
# and index the disabled connectors as well (which is why this if
# statement is below the first condition above)
if connector.disabled:
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if cc_pair.status == ConnectorCredentialPairStatus.PAUSED or connector.id == 0:
return False
if connector.refresh_freq is None:
return False
if not last_index:
return True
# Only one scheduled job per connector at a time
# Can schedule another one if the current one is already running however
# Because the currently running one will not be until the latest time
# Note, this last index is for the given embedding model
if last_index.status == IndexingStatus.NOT_STARTED:
if connector.refresh_freq is None:
return False
# Only one scheduled/ongoing job per connector at a time
# this prevents cases where
# (1) the "latest" index_attempt is scheduled so we show
# that in the UI despite another index_attempt being in-progress
# (2) multiple scheduled index_attempts at a time
if (
last_index.status == IndexingStatus.NOT_STARTED
or last_index.status == IndexingStatus.IN_PROGRESS
):
return False
current_db_time = get_db_current_time(db_session)
@@ -95,24 +121,14 @@ def _should_create_new_indexing(
return time_since_index.total_seconds() >= connector.refresh_freq
def _is_indexing_job_marked_as_finished(index_attempt: IndexAttempt | None) -> bool:
if index_attempt is None:
return False
return (
index_attempt.status == IndexingStatus.FAILED
or index_attempt.status == IndexingStatus.SUCCESS
)
def _mark_run_failed(
db_session: Session, index_attempt: IndexAttempt, failure_reason: str
) -> None:
"""Marks the `index_attempt` row as failed + updates the `
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,
@@ -131,7 +147,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
ongoing: set[tuple[int | None, int | None, int]] = set()
ongoing: set[tuple[int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
@@ -144,8 +160,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
continue
ongoing.add(
(
attempt.connector_id,
attempt.credential_id,
attempt.connector_credential_pair_id,
attempt.embedding_model_id,
)
)
@@ -155,31 +170,26 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
all_connectors = fetch_connectors(db_session)
for connector in all_connectors:
for association in connector.credentials:
for model in embedding_models:
credential = association.credential
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for model in embedding_models:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, model.id) in ongoing:
continue
# Check if there is an ongoing indexing attempt for this connector + credential pair
if (connector.id, credential.id, model.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, model.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
last_attempt = get_last_attempt(
connector.id, credential.id, model.id, db_session
)
if not _should_create_new_indexing(
connector=connector,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
db_session=db_session,
):
continue
create_index_attempt(
connector.id, credential.id, model.id, db_session
)
create_index_attempt(cc_pair.id, model.id, db_session)
def cleanup_indexing_jobs(
@@ -196,10 +206,12 @@ def cleanup_indexing_jobs(
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done() and not _is_indexing_job_marked_as_finished(
index_attempt
):
continue
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
@@ -271,24 +283,28 @@ def kickoff_indexing_jobs(
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
with Session(engine) as db_session:
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.embedding_model)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
if not new_indexing_attempts:
return existing_jobs
indexing_attempt_count = 0
for attempt, embedding_model in new_indexing_attempts:
use_secondary_index = (
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
else False
)
if attempt.connector is None:
if attempt.connector_credential_pair.connector is None:
logger.warning(
f"Skipping index attempt as Connector has been deleted: {attempt}"
)
@@ -297,7 +313,7 @@ def kickoff_indexing_jobs(
attempt, db_session, failure_reason="Connector is null"
)
continue
if attempt.credential is None:
if attempt.connector_credential_pair.credential is None:
logger.warning(
f"Skipping index attempt as Credential has been deleted: {attempt}"
)
@@ -323,35 +339,53 @@ def kickoff_indexing_jobs(
)
if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
if indexing_attempt_count == 0:
logger.info(
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
)
indexing_attempt_count += 1
secondary_str = " (secondary index)" if use_secondary_index else ""
logger.info(
f"Kicked off {secondary_str}"
f"indexing attempt for connector: '{attempt.connector.name}', "
f"with config: '{attempt.connector.connector_specific_config}', and "
f"with credentials: '{attempt.credential_id}'"
f"Indexing dispatched{secondary_str}: "
f"attempt_id={attempt.id} "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.credential_id}'"
)
existing_jobs_copy[attempt.id] = run
if indexing_attempt_count > 0:
logger.info(
f"Indexing dispatch results: "
f"initial_pending={len(new_indexing_attempts)} "
f"started={indexing_attempt_count} "
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
)
return existing_jobs_copy
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if db_embedding_model.cloud_provider_id is None:
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
if db_embedding_model.cloud_provider_id is None:
logger.notice("Running a first inference to warm up embedding model")
warm_up_bi_encoder(
embedding_model=db_embedding_model,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
@@ -366,7 +400,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_workers,
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
@@ -376,18 +410,18 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Running update, current UTC time: {start_time_utc}")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.info(
logger.debug(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
)
@@ -411,8 +445,9 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
logger.info("Starting Indexing Loop")
logger.notice("Starting indexing service")
update_loop()

View File

@@ -35,14 +35,17 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
all_chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session_id,
user_id=None,
db_session=db_session,
skip_permission_check=True,
prefetch_tool_calls=prefetch_tool_calls,
)
id_to_msg = {msg.id: msg for msg in all_chat_messages}

View File

@@ -0,0 +1,24 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true
- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true
- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true
- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true

View File

@@ -1,13 +1,17 @@
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
from danswer.db.models import Tool as ToolDBModel
from danswer.db.persona import get_prompt_by_name
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
@@ -76,9 +80,31 @@ def load_personas_from_yaml(
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
upsert_persona(
user=None,
# Negative to not conflict with existing personas
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
@@ -88,20 +114,52 @@ def load_personas_from_yaml(
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
llm_model_provider_override=None,
llm_model_version_override=None,
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)

View File

@@ -46,15 +46,22 @@ class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceChunk(BaseModel):
# TODO make this document level. Also slight misnomer here as this is actually
# done at the section level currently rather than the chunk
relevant: bool | None = None
class RelevanceAnalysis(BaseModel):
relevant: bool
content: str | None = None
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
class SectionRelevancePiece(RelevanceAnalysis):
"""LLM analysis mapped to an Inference Section"""
document_id: str
chunk_id: int # ID of the center chunk for a given inference section
class DocumentRelevance(BaseModel):
"""Contains all relevance information for a given search"""
relevance_summaries: dict[str, RelevanceAnalysis]
class DanswerAnswerPiece(BaseModel):
@@ -69,8 +76,14 @@ class CitationInfo(BaseModel):
document_id: str
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
class DanswerQuote(BaseModel):

View File

@@ -5,7 +5,7 @@ personas:
# this is for DanswerBot to use when tagged in a non-configured channel
# Careful setting specific IDs, this won't autoincrement the next ID value for postgres
- id: 0
name: "Danswer"
name: "Knowledge"
description: >
Assistant with access to documents from your Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
@@ -17,7 +17,7 @@ personas:
num_chunks: 10
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overriden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
# This feature can be overriden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
llm_relevance_filter: true
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
llm_filter_extraction: true
@@ -37,12 +37,15 @@ personas:
# - "Engineer Onboarding"
# - "Benefits"
document_sets: []
icon_shape: 23013
icon_color: "#6FB1FF"
display_priority: 1
is_visible: true
- id: 1
name: "GPT"
name: "General"
description: >
Assistant with no access to documents. Chat with just the Language Model.
Assistant with no access to documents. Chat with just the Large Language Model.
prompts:
- "OnlyLLM"
num_chunks: 0
@@ -50,7 +53,10 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 50910
icon_color: "#FF6F6F"
display_priority: 0
is_visible: true
- id: 2
name: "Paraphrase"
@@ -63,3 +69,25 @@ personas:
llm_filter_extraction: true
recency_bias: "auto"
document_sets: []
icon_shape: 45519
icon_color: "#6FFF8D"
display_priority: 2
is_visible: false
- id: 3
name: "Art"
description: >
Assistant for generating images based on descriptions.
prompts:
- "ImageGeneration"
num_chunks: 0
llm_relevance_filter: false
llm_filter_extraction: false
recency_bias: "no_decay"
document_sets: []
icon_shape: 234124
icon_color: "#9B59B6"
image_generation: true
display_priority: 3
is_visible: true

View File

@@ -11,6 +11,7 @@ from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
@@ -27,6 +28,7 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
@@ -51,7 +53,9 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -60,6 +64,7 @@ from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
@@ -178,7 +183,7 @@ def _handle_internet_search_tool_response_summary(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.HYBRID,
predicted_search=SearchType.SEMANTIC,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
@@ -187,37 +192,46 @@ def _handle_internet_search_tool_response_summary(
)
def _check_should_force_search(
new_msg_req: CreateChatMessageRequest,
) -> ForceUseTool | None:
# If files are already provided, don't run the search tool
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
return None
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
if (
new_msg_req.query_override
or (
should_force_search = any(
[
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
)
or new_msg_req.search_doc_ids
or DISABLE_LLM_CHOOSE_SEARCH
):
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override
else None
)
# if we are using selected docs, just put something here so the Tool doesn't need
# to build its own args via an LLM call
if new_msg_req.search_doc_ids:
args = {"query": new_msg_req.message}
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
return ForceUseTool(
tool_name=SearchTool._NAME,
args=args,
)
return None
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
@@ -229,6 +243,7 @@ ChatPacket = (
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageResponseIDInfo
)
ChatPacketStream = Iterator[ChatPacket]
@@ -244,16 +259,15 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None
@@ -274,7 +288,10 @@ def stream_chat_message_objects(
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id, user=user, db_session=db_session
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
persona = chat_session.persona
@@ -297,7 +314,13 @@ def stream_chat_message_objects(
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
llm_tokenizer = get_default_llm_tokenizer()
llm_provider = llm.config.model_provider
llm_model_name = llm.config.model_name
llm_tokenizer = get_tokenizer(
model_name=llm_model_name,
provider_type=llm_provider,
)
llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
@@ -361,6 +384,14 @@ def stream_chat_message_objects(
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worst search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
@@ -420,10 +451,19 @@ def stream_chat_message_objects(
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
reserved_assistant_message_id=reserved_message_id,
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
@@ -476,6 +516,9 @@ def stream_chat_message_objects(
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -544,13 +587,16 @@ def stream_chat_message_objects(
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
@@ -576,11 +622,7 @@ def stream_chat_message_objects(
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=(
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
force_use_tool=_get_force_search_settings(new_msg_req, tools),
)
reference_db_search_docs = None
@@ -588,6 +630,7 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
@@ -606,18 +649,28 @@ def stream_chat_message_objects(
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
relevance_sections = packet.response
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=llm_indices
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=chunk_indices
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
@@ -653,20 +706,15 @@ def stream_chat_message_objects(
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
logger.exception("Failed to process chat message")
# Don't leak the API key
error_msg = str(e)
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
error_msg = (
f"LLM failed to respond. Invalid API "
f"key error from '{llm.config.model_provider}'."
)
logger.exception(f"Failed to process chat message: {error_msg}")
yield StreamingError(error=error_msg)
# Cancel the transaction so that no messages are saved
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
error_msg = error_msg.replace(llm.config.api_key, "[REDACTED_API_KEY]")
yield StreamingError(error=client_error_msg, stack_trace=error_msg)
db_session.rollback()
return
@@ -686,6 +734,7 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -706,6 +755,8 @@ def stream_chat_message_objects(
if tool_result
else [],
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
@@ -714,7 +765,8 @@ def stream_chat_message_objects(
yield msg_detail_response
except Exception as e:
logger.exception(e)
error_msg = str(e)
logger.exception(error_msg)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@@ -726,6 +778,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
@@ -734,6 +787,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.dict())

View File

@@ -30,7 +30,23 @@ prompts:
# Prompts the LLM to include citations in the for [1], [2] etc.
# which get parsed to match the passed in sources
include_citations: true
- name: "ImageGeneration"
description: "Generates images based on user prompts!"
system: >
You are an advanced image generation system capable of creating diverse and detailed images.
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true
include_citations: false
- name: "OnlyLLM"
description: "Chat directly with the LLM!"

View File

@@ -129,6 +129,17 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
# recycle timeout in seconds
POSTGRES_POOL_RECYCLE_DEFAULT = 60 * 20 # 20 minutes
try:
POSTGRES_POOL_RECYCLE = int(
os.environ.get("POSTGRES_POOL_RECYCLE", POSTGRES_POOL_RECYCLE_DEFAULT)
)
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
#####
# Connector Configs
@@ -191,6 +202,11 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
)
# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 50 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -212,10 +228,11 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
)
PRUNING_DISABLED = -1
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)
# This is the maxiumum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -248,18 +265,39 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150
# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
)
# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
#####
# Miscellaneous
@@ -287,6 +325,10 @@ LOG_VESPA_TIMING_INFORMATION = (
os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
LOG_POSTGRES_CONN_COUNTS = (
os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
)
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"

View File

@@ -3,12 +3,13 @@ import os
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_RERANKED_RESULTS = 20
NUM_POSTPROCESSED_RESULTS = 20
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
@@ -32,11 +33,6 @@ DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# Whether the LLM should be used to decide if a search would help given the chat history
DISABLE_LLM_CHOOSE_SEARCH = (
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
@@ -47,22 +43,19 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
if os.environ.get("EDIT_KEYWORD_QUERY"):
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
# if the title is separated out. Title is most of a "boost" than a separate field.
TITLE_CONTENT_RATIO = max(
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10))
)
# A list of languages passed to the LLM to rephase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
@@ -75,16 +68,16 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_DOC_RELEVANCE = (
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
)
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None

View File

@@ -1,26 +1,6 @@
from enum import Enum
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
BLURB = "blurb"
CONTENT = "content"
SOURCE_TYPE = "source_type"
SOURCE_LINKS = "source_links"
SOURCE_LINK = "link"
SEMANTIC_IDENTIFIER = "semantic_identifier"
TITLE = "title"
SKIP_TITLE_EMBEDDING = "skip_title"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
TITLE_EMBEDDING = "title_embedding"
ALLOWED_USERS = "allowed_users"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
TIME_FILTER = "time_filter"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
# are still useful as a search result but not for QA.
@@ -28,23 +8,11 @@ IGNORE_FOR_QA = "ignore_for_qa"
# NOTE: deprecated, only used for porting key from old system
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
PUBLIC_DOCUMENT_SET = "__PUBLIC"
QUOTE = "quote"
BOOST = "boost"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
RECENCY_BIAS = "recency_bias"
HIDDEN = "hidden"
SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
# For chunking/processing chunks
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
@@ -60,12 +28,37 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "danswerapikey.ai"
UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -109,6 +102,10 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
class NotificationType(str, Enum):
REINDEX = "reindex"
class BlobType(str, Enum):
R2 = "r2"
S3 = "s3"

View File

@@ -12,13 +12,15 @@ import os
# The useable models configured as below must be SentenceTransformer compatible
# NOTE: DO NOT CHANGE SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
DOCUMENT_ENCODER_MODEL = (
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
# If the below is changed, Vespa deployment must also be changed
DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768)
# Model should be chosen with 512 context size, ideally don't change this
# If multipass_indexing is enabled, the max context size would be set to
# DOC_EMBEDDING_CONTEXT_SIZE * LARGE_CHUNK_RATIO
DOC_EMBEDDING_CONTEXT_SIZE = 512
NORMALIZE_EMBEDDINGS = (
os.environ.get("NORMALIZE_EMBEDDINGS") or "true"
@@ -34,17 +36,16 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0
# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0
#####
# Generative AI Model Configs

View File

@@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
Raises ValueError for unsupported bucket types.
"""
logger.info(
logger.debug(
f"Loading credentials for {self.bucket_name} or type {self.bucket_type}"
)
@@ -169,7 +169,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
end: datetime,
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blog storage")
raise ConnectorMissingCredentialError("Blob storage")
paginator = self.s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
@@ -220,7 +220,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
yield batch
def load_from_state(self) -> GenerateDocumentsOutput:
logger.info("Loading blob objects")
logger.debug("Loading blob objects")
return self._yield_blob_objects(
start=datetime(1970, 1, 1, tzinfo=timezone.utc),
end=datetime.now(timezone.utc),
@@ -230,7 +230,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.s3_client is None:
raise ConnectorMissingCredentialError("Blog storage")
raise ConnectorMissingCredentialError("Blob storage")
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

View File

@@ -13,6 +13,7 @@ import bs4
from atlassian import Confluence # type:ignore
from requests import HTTPError
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
@@ -217,16 +218,19 @@ class RecursiveIndexer:
self,
batch_size: int,
confluence_client: Confluence,
index_origin: bool,
index_recursively: bool,
origin_page_id: str,
) -> None:
self.batch_size = 1
# batch_size
self.confluence_client = confluence_client
self.index_origin = index_origin
self.index_recursively = index_recursively
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
@@ -282,12 +286,11 @@ class RecursiveIndexer:
current_level_pages = next_level_pages
next_level_pages = []
if self.index_origin:
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
@@ -340,7 +343,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_page_url: str,
index_origin: bool = True,
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
# if a page has one of the labels specified in this list, we will just
@@ -352,7 +355,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_origin = index_origin
self.index_recursively = index_recursively
(
self.wiki_base,
self.space,
@@ -369,7 +372,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -453,10 +456,13 @@ class ConfluenceConnector(LoadConnector, PollConnector):
origin_page_id=self.page_id,
batch_size=self.batch_size,
confluence_client=self.confluence_client,
index_origin=self.index_origin,
index_recursively=self.index_recursively,
)
return self.recursive_indexer.get_pages(start_ind, batch_size)
if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()
pages: list[dict[str, Any]] = []
@@ -555,6 +561,17 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if attachment["title"] not in files_in_used:
continue
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
continue
download_link = confluence_client.url + attachment["_links"]["download"]
response = confluence_client._session.get(download_link)

View File

@@ -56,7 +56,7 @@ class _RateLimitDecorator:
sleep_cnt = 0
while len(self.call_history) == self.max_calls:
sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt)
logger.info(
logger.notice(
f"Rate limit exceeded for function {func.__name__}. "
f"Waiting {sleep_time} seconds before retrying."
)

View File

@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)
try:
return jira_issue.raw["fields"][field]
except Exception:
return None
def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
continue
comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
pass
metadata_dict = {}
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels
doc_batch.append(
Document(

View File

@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None
@rate_limit_builder(max_calls=100, period=60)
@rate_limit_builder(max_calls=50, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")

View File

@@ -38,7 +38,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logger.info(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
time.sleep(sleep_time.seconds)

View File

@@ -11,16 +11,17 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.connectors.gmail.constants import CRED_KEY
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GMAIL_CRED_KEY
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.gmail.constants import GMAIL_CRED_KEY
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.gmail.constants import GMAIL_SERVICE_ACCOUNT_KEY
from danswer.connectors.gmail.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
@@ -49,7 +50,7 @@ def get_gmail_creds_for_authorized_user(
try:
creds.refresh(Request())
if creds.valid:
logger.info("Refreshed Gmail tokens.")
logger.notice("Refreshed Gmail tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh gmail access token due to: {e}")
@@ -71,7 +72,7 @@ def get_gmail_creds_for_service_account(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Gmail Connector callback does not match expected"
@@ -79,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_gmail_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -91,12 +92,14 @@ def get_gmail_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
get_dynamic_config_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -108,7 +111,9 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
get_dynamic_config_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
@@ -146,28 +151,29 @@ def build_service_account_creds(
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)
def get_google_app_gmail_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
)
def delete_google_app_gmail_cred() -> None:
get_dynamic_config_store().delete(GMAIL_CRED_KEY)
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_dynamic_config_store().load(GMAIL_SERVICE_ACCOUNT_KEY))
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
return GoogleServiceAccountKey(**json.loads(creds_str))
@@ -175,19 +181,19 @@ def upsert_gmail_service_account_key(
service_account_key: GoogleServiceAccountKey,
) -> None:
get_dynamic_config_store().store(
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_gmail_service_account_key() -> None:
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)

View File

@@ -1,7 +1,4 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
CRED_KEY = "credential_id_{}"
GMAIL_CRED_KEY = "gmail_app_credential"
GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]

View File

@@ -81,10 +81,10 @@ class GongConnector(LoadConnector, PollConnector):
for workspace in workspace_list:
if workspace:
logger.info(f"Updating workspace: {workspace}")
logger.info(f"Updating Gong workspace: {workspace}")
workspace_id = workspace_map.get(workspace)
if not workspace_id:
logger.error(f"Invalid workspace: {workspace}")
logger.error(f"Invalid Gong workspace: {workspace}")
if not self.continue_on_fail:
raise ValueError(f"Invalid workspace: {workspace}")
continue

View File

@@ -267,7 +267,7 @@ def get_all_files_batched(
yield from batch_generator(
items=found_files,
batch_size=batch_size,
pre_batch_yield=lambda batch_files: logger.info(
pre_batch_yield=lambda batch_files: logger.debug(
f"Parseable Documents in batch: {[file['name'] for file in batch_files]}"
),
)
@@ -306,24 +306,29 @@ def get_all_files_batched(
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return UNSUPPORTED_FILE_TYPE_CONTENT
if mime_type == GDriveMimeType.DOC.value:
return (
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = "text/plain"
if mime_type == GDriveMimeType.SPREADSHEET.value:
export_mime_type = "text/csv"
elif mime_type == GDriveMimeType.PPT.value:
export_mime_type = "text/plain"
response = (
service.files()
.export(fileId=file["id"], mimeType="text/plain")
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
elif mime_type == GDriveMimeType.SPREADSHEET.value:
return (
service.files()
.export(fileId=file["id"], mimeType="text/csv")
.execute()
.decode("utf-8")
)
return response.decode("utf-8")
elif mime_type == GDriveMimeType.WORD_DOC.value:
response = service.files().get_media(fileId=file["id"]).execute()
return docx_to_text(file=io.BytesIO(response))
@@ -333,9 +338,6 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
elif mime_type == GDriveMimeType.POWERPOINT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
elif mime_type == GDriveMimeType.PPT.value:
response = service.files().get_media(fileId=file["id"]).execute()
return pptx_to_text(file=io.BytesIO(response))
return UNSUPPORTED_FILE_TYPE_CONTENT

View File

@@ -11,7 +11,10 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.connectors.google_drive.constants import CRED_KEY
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
@@ -19,8 +22,6 @@ from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_CRED_KEY
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
@@ -49,7 +50,7 @@ def get_google_drive_creds_for_authorized_user(
try:
creds.refresh(Request())
if creds.valid:
logger.info("Refreshed Google Drive tokens.")
logger.notice("Refreshed Google Drive tokens.")
return creds
except Exception as e:
logger.exception(f"Failed to refresh google drive access token due to: {e}")
@@ -71,7 +72,7 @@ def get_google_drive_creds_for_service_account(
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
raise PermissionError(
"State from Google Drive Connector callback does not match expected"
@@ -79,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
def get_auth_url(credential_id: int) -> str:
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
@@ -91,7 +92,9 @@ def get_auth_url(credential_id: int) -> str:
parsed_url = cast(ParseResult, urlparse(auth_url))
params = parse_qs(parsed_url.query)
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
get_dynamic_config_store().store(
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
) # type: ignore
return str(auth_url)
@@ -118,6 +121,7 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
@@ -131,34 +135,37 @@ def build_service_account_creds(
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)
def get_google_app_cred() -> GoogleAppCredentials:
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
return GoogleAppCredentials(**json.loads(creds_str))
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
get_dynamic_config_store().store(
GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
)
def delete_google_app_cred() -> None:
get_dynamic_config_store().delete(GOOGLE_DRIVE_CRED_KEY)
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
def get_service_account_key() -> GoogleServiceAccountKey:
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
creds_str = str(
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
)
return GoogleServiceAccountKey(**json.loads(creds_str))
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
get_dynamic_config_store().store(
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
)
def delete_service_account_key() -> None:
get_dynamic_config_store().delete(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)

View File

@@ -1,9 +1,6 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
CRED_KEY = "credential_id_{}"
GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",

View File

@@ -86,7 +86,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: The categories to include in the index.
pages: The pages to include in the index.
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
connector_name: The name of the connector.
language_code: The language code of the wiki.
batch_size: The batch size for loading documents.
@@ -104,7 +103,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -118,10 +116,8 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
# short names can only have ascii letters and digits
self.connector_name = connector_name
connector_name = "".join(ch for ch in connector_name if ch.isalnum())
self.family = family_class_dispatch(hostname, connector_name)()
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
@@ -210,7 +206,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
HOSTNAME = "fallout.fandom.com"
test_connector = MediaWikiConnector(
connector_name="Fallout",
hostname=HOSTNAME,
categories=["Fallout:_New_Vegas_factions"],
pages=["Fallout: New Vegas"],

View File

@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
title: str | None = None
from_ingestion_api: bool = False
def get_title_for_document_index(self) -> str | None:
def get_title_for_document_index(
self,
) -> str | None:
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None
@@ -164,6 +166,36 @@ class Document(DocumentBase):
)
class DocumentErrorSummary(BaseModel):
id: str
semantic_id: str
section_link: str | None
@classmethod
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
return cls(
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
)
@classmethod
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
return cls(
id=str(data.get("id")),
semantic_id=str(data.get("semantic_id")),
section_link=str(data.get("section_link")),
)
def to_dict(self) -> dict[str, str | None]:
return {
"id": self.id,
"semantic_id": self.semantic_id,
"section_link": self.section_link,
}
class IndexAttemptMetadata(BaseModel):
batch_num: int | None = None
num_exceptions: int = 0
connector_id: int
credential_id: int

View File

@@ -68,12 +68,13 @@ def make_slack_api_call_paginated(
def make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 3
call: Callable[..., SlackResponse], max_retries: int = 7
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call
@@ -85,14 +86,20 @@ def make_slack_api_rate_limited(
return response
except SlackApiError as e:
if e.response["error"] == "ratelimited":
last_exception = e
try:
error = e.response["error"]
except KeyError:
error = "unknown error"
if error == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.info(
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
)
time.sleep(retry_after)
elif e.response["error"] in ["already_reacted", "no_reaction"]:
elif error in ["already_reacted", "no_reaction"]:
# The response isn't used for reactions, this is basically just a pass
return e.response
else:
@@ -100,7 +107,11 @@ def make_slack_api_rate_limited(
raise
# If the code reaches this point, all retries have been exhausted
raise Exception(f"Max retries ({max_retries}) exceeded")
msg = f"Max retries ({max_retries}) exceeded"
if last_exception:
raise Exception(msg) from last_exception
else:
raise Exception(msg)
return rate_limited_call

View File

@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from requests_oauthlib import OAuth2Session # type:ignore
from urllib3.exceptions import MaxRetryError
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
try:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.SSLError as e:
cause = (
e.args[0].reason
if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
else e.args
)
raise Exception(f"SSL error {str(cause)}")
except (requests.RequestException, ValueError):
raise Exception(f"Unable to reach {url} - check your internet connection")

View File

@@ -15,7 +15,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -24,7 +23,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories=categories,
pages=pages,
recurse_depth=recurse_depth,
connector_name=connector_name,
language_code=language_code,
batch_size=batch_size,
)

View File

@@ -1,5 +1,7 @@
from typing import Any
import requests
from retry import retry
from zenpy import Zenpy # type: ignore
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
@@ -19,12 +21,24 @@ from danswer.connectors.models import Section
from danswer.file_processing.html_utils import parse_html_page_basic
def _article_to_document(article: Article) -> Document:
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
author = BasicExpertInfo(
display_name=article.author.name, email=article.author.email
)
update_time = time_str_to_utc(article.updated_at)
labels = [str(label) for label in article.label_names]
# build metadata
metadata: dict[str, str | list[str]] = {
"labels": [str(label) for label in article.label_names if label],
"content_tags": [
content_tags[tag_id]
for tag_id in article.content_tag_ids
if tag_id in content_tags
],
}
# remove empty values
metadata = {k: v for k, v in metadata.items() if v}
return Document(
id=f"article:{article.id}",
@@ -35,7 +49,7 @@ def _article_to_document(article: Article) -> Document:
semantic_identifier=article.title,
doc_updated_at=update_time,
primary_owners=[author],
metadata={"labels": labels} if labels else {},
metadata=metadata,
)
@@ -48,6 +62,42 @@ class ZendeskConnector(LoadConnector, PollConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
self.zendesk_client: Zenpy | None = None
self.content_tags: dict[str, str] = {}
@retry(tries=3, delay=2, backoff=2)
def _set_content_tags(
self, subdomain: str, email: str, token: str, page_size: int = 30
) -> None:
# Construct the base URL
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
# Set up authentication
auth = (f"{email}/token", token)
# Set up pagination parameters
params = {"page[size]": page_size}
try:
while True:
# Make the GET request
response = requests.get(base_url, auth=auth, params=params)
# Check if the request was successful
if response.status_code == 200:
data = response.json()
content_tag_list = data.get("records", [])
for tag in content_tag_list:
self.content_tags[tag["id"]] = tag["name"]
# Check if there are more pages
if data.get("meta", {}).get("has_more", False):
params["page[after]"] = data["meta"]["after_cursor"]
else:
break
else:
raise Exception(f"Error: {response.status_code}\n{response.text}")
except Exception as e:
raise Exception(f"Error fetching content tags: {str(e)}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# Subdomain is actually the whole URL
@@ -62,6 +112,11 @@ class ZendeskConnector(LoadConnector, PollConnector):
email=credentials["zendesk_email"],
token=credentials["zendesk_token"],
)
self._set_content_tags(
subdomain,
credentials["zendesk_email"],
credentials["zendesk_token"],
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
@@ -92,10 +147,30 @@ class ZendeskConnector(LoadConnector, PollConnector):
):
continue
doc_batch.append(_article_to_document(article))
doc_batch.append(_article_to_document(article, self.content_tags))
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
if doc_batch:
yield doc_batch
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()
connector.load_credentials(
{
"zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"],
"zendesk_email": os.environ["ZENDESK_EMAIL"],
"zendesk_token": os.environ["ZENDESK_TOKEN"],
}
)
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
print(next(document_batches))

View File

@@ -70,6 +70,10 @@ def _process_citations_for_slack(text: str) -> str:
def slack_link_format(match: Match) -> str:
link_text = match.group(1)
link_url = match.group(2)
# Account for empty link citations
if link_url == "":
return f"[{link_text}]"
return f"<{link_url}|[{link_text}]>"
# Substitute all matches in the input text
@@ -299,7 +303,9 @@ def build_sources_blocks(
else []
)
+ [
MarkdownTextObject(
MarkdownTextObject(text=f"{document_title}")
if d.link == ""
else MarkdownTextObject(
text=f"*<{d.link}|[{citation_num}] {document_title}>*\n{final_metadata_str}"
),
]

View File

@@ -6,7 +6,6 @@ FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button"
IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button"
FOLLOWUP_BUTTON_ACTION_ID = "followup-button"
FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
SLACK_CHANNEL_ID = "channel_id"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback"
GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button"

View File

@@ -1,4 +1,3 @@
import logging
from typing import Any
from typing import cast
@@ -134,7 +133,7 @@ def handle_generate_answer_button(
receiver_ids=None,
client=client.web_client,
channel=channel_id,
logger=cast(logging.Logger, logger),
logger=logger,
feedback_reminder_id=None,
)

View File

@@ -1,6 +1,4 @@
import datetime
import logging
from typing import cast
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@@ -9,7 +7,6 @@ from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.danswerbot.slack.blocks import get_feedback_reminder_blocks
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_regular_answer import (
handle_regular_answer,
)
@@ -17,7 +14,6 @@ from danswer.danswerbot.slack.handlers.handle_standard_answers import (
handle_standard_answers,
)
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_user_ids_from_emails
from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
@@ -26,6 +22,7 @@ from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.utils.logger import setup_logger
from shared_configs.configs import SLACK_CHANNEL_ID
logger_base = setup_logger()
@@ -53,12 +50,8 @@ def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
def schedule_feedback_reminder(
details: SlackMessageInfo, include_followup: bool, client: WebClient
) -> str | None:
logger = cast(
logging.Logger,
ChannelIdAdapter(
logger_base, extra={SLACK_CHANNEL_ID: details.channel_to_respond}
),
)
logger = setup_logger(extra={SLACK_CHANNEL_ID: details.channel_to_respond})
if not DANSWER_BOT_FEEDBACK_REMINDER:
logger.info("Scheduled feedback reminder disabled...")
return None
@@ -97,10 +90,7 @@ def schedule_feedback_reminder(
def remove_scheduled_feedback_reminder(
client: WebClient, channel: str | None, msg_id: str
) -> None:
logger = cast(
logging.Logger,
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
)
logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
try:
client.chat_deleteScheduledMessage(
@@ -129,10 +119,7 @@ def handle_message(
"""
channel = message_info.channel_to_respond
logger = cast(
logging.Logger,
ChannelIdAdapter(logger_base, extra={SLACK_CHANNEL_ID: channel}),
)
logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
messages = message_info.thread_messages
sender_id = message_info.sender

View File

@@ -1,5 +1,4 @@
import functools
import logging
from collections.abc import Callable
from typing import Any
from typing import cast
@@ -50,7 +49,8 @@ from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.search.search_settings import get_search_settings
from danswer.utils.logger import DanswerLoggingAdapter
srl = SlackRateLimiter()
@@ -83,7 +83,7 @@ def handle_regular_answer(
receiver_ids: list[str] | None,
client: WebClient,
channel: str,
logger: logging.Logger,
logger: DanswerLoggingAdapter,
feedback_reminder_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
@@ -136,7 +136,6 @@ def handle_regular_answer(
tries=num_retries,
delay=0.25,
backoff=2,
logger=logger,
)
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
@@ -223,15 +222,23 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
# Always apply reranking settings if it exists, this is the non-streaming flow
saved_search_settings = get_search_settings()
# This includes throwing out answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
multilingual_query_expansion=saved_search_settings.multilingual_expansion
if saved_search_settings
else None,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
rerank_settings=saved_search_settings.to_reranking_detail()
if saved_search_settings
else None,
)
)
except Exception as e:
@@ -311,7 +318,7 @@ def handle_regular_answer(
)
if answer.answer_valid is False:
logger.info(
logger.notice(
"Answer was evaluated to be invalid, throwing it away without responding."
)
update_emote_react(
@@ -349,7 +356,7 @@ def handle_regular_answer(
return True
if not answer.answer and disable_docs_only_answer:
logger.info(
logger.notice(
"Unable to find answer - not responding since the "
"`DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set"
)

View File

@@ -1,5 +1,3 @@
import logging
from slack_sdk import WebClient
from sqlalchemy.orm import Session
@@ -21,6 +19,7 @@ from danswer.db.models import SlackBotConfig
from danswer.db.standard_answer import fetch_standard_answer_categories_by_names
from danswer.db.standard_answer import find_matching_standard_answers
from danswer.server.manage.models import StandardAnswer
from danswer.utils.logger import DanswerLoggingAdapter
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -61,7 +60,7 @@ def handle_standard_answers(
receiver_ids: list[str] | None,
slack_bot_config: SlackBotConfig | None,
prompt: Prompt | None,
logger: logging.Logger,
logger: DanswerLoggingAdapter,
client: WebClient,
db_session: Session,
) -> bool:

View File

@@ -21,7 +21,6 @@ from danswer.danswerbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_I
from danswer.danswerbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID
from danswer.danswerbot.slack.constants import LIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.constants import VIEW_DOC_FEEDBACK_ID
from danswer.danswerbot.slack.handlers.handle_buttons import handle_doc_feedback_button
from danswer.danswerbot.slack.handlers.handle_buttons import handle_followup_button
@@ -39,7 +38,6 @@ from danswer.danswerbot.slack.handlers.handle_message import (
from danswer.danswerbot.slack.handlers.handle_message import schedule_feedback_reminder
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import decompose_action_id
from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_danswer_bot_app_id
@@ -50,13 +48,14 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
@@ -84,18 +83,18 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
event = cast(dict[str, Any], req.payload.get("event", {}))
msg = cast(str | None, event.get("text"))
channel = cast(str | None, event.get("channel"))
channel_specific_logger = ChannelIdAdapter(
logger, extra={SLACK_CHANNEL_ID: channel}
)
channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
# This should never happen, but we can't continue without a channel since
# we can't send a response without it
if not channel:
channel_specific_logger.error("Found message without channel - skipping")
channel_specific_logger.warning("Found message without channel - skipping")
return False
if not msg:
channel_specific_logger.error("Cannot respond to empty message - skipping")
channel_specific_logger.warning(
"Cannot respond to empty message - skipping"
)
return False
if (
@@ -185,9 +184,8 @@ def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool
if req.type == "slash_commands":
# Verify that there's an associated channel
channel = req.payload.get("channel_id")
channel_specific_logger = ChannelIdAdapter(
logger, extra={SLACK_CHANNEL_ID: channel}
)
channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
if not channel:
channel_specific_logger.error(
"Received DanswerBot command without channel - skipping"
@@ -230,7 +228,7 @@ def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
)
query_event_id, _, _ = decompose_action_id(feedback_id)
logger.info(f"Successfully handled QA feedback for event: {query_event_id}")
logger.notice(f"Successfully handled QA feedback for event: {query_event_id}")
def build_request_details(
@@ -247,15 +245,17 @@ def build_request_details(
msg = remove_danswer_bot_tag(msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.info(f"Rephrasing Slack message. Original message: {msg}")
logger.notice(f"Rephrasing Slack message. Original message: {msg}")
try:
msg = rephrase_slack_message(msg)
logger.info(f"Rephrased message: {msg}")
logger.notice(f"Rephrased message: {msg}")
except Exception as e:
logger.error(f"Error while trying to rephrase the Slack message: {e}")
else:
logger.notice(f"Received Slack message: {msg}")
if tagged:
logger.info("User tagged DanswerBot")
logger.debug("User tagged DanswerBot")
if thread_ts != message_ts and thread_ts is not None:
thread_messages = read_slack_thread(
@@ -437,7 +437,7 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info("Listening for messages from Slack...")
logger.notice("Listening for messages from Slack...")
socket_client.connect()
@@ -454,7 +454,7 @@ if __name__ == "__main__":
slack_bot_tokens: SlackBotTokens | None = None
socket_client: SocketModeClient | None = None
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
while True:
@@ -463,16 +463,15 @@ if __name__ == "__main__":
if latest_slack_bot_tokens != slack_bot_tokens:
if slack_bot_tokens is not None:
logger.info("Slack Bot tokens have changed - reconnecting")
logger.notice("Slack Bot tokens have changed - reconnecting")
else:
# This happens on the very first time the listener process comes up
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
warm_up_bi_encoder(
embedding_model=embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)

View File

@@ -1,13 +1,11 @@
import os
from typing import cast
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.server.manage.models import SlackBotTokens
_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
def fetch_tokens() -> SlackBotTokens:
# first check env variables
app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
@@ -17,7 +15,7 @@ def fetch_tokens() -> SlackBotTokens:
dynamic_config_store = get_dynamic_config_store()
return SlackBotTokens(
**cast(dict, dynamic_config_store.load(key=_SLACK_BOT_TOKENS_CONFIG_KEY))
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
)
@@ -26,5 +24,5 @@ def save_tokens(
) -> None:
dynamic_config_store = get_dynamic_config_store()
dynamic_config_store.store(
key=_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
)

View File

@@ -3,7 +3,6 @@ import random
import re
import string
import time
from collections.abc import MutableMapping
from typing import Any
from typing import cast
from typing import Optional
@@ -25,7 +24,6 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.users import get_user_by_email
@@ -110,20 +108,6 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
class ChannelIdAdapter(logging.LoggerAdapter):
"""This is used to add the channel ID to all log messages
emitted in this file"""
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
) -> tuple[str, MutableMapping[str, Any]]:
channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
if channel_id:
return f"[Channel ID: {channel_id}] {msg}", kwargs
else:
return msg, kwargs
def get_web_client() -> WebClient:
slack_tokens = fetch_tokens()
return WebClient(token=slack_tokens.bot_token)

View File

@@ -16,7 +16,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.chat.models import DocumentRelevance
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
@@ -117,6 +117,7 @@ def get_chat_sessions_by_user(
deleted: bool | None,
db_session: Session,
only_one_shot: bool = False,
limit: int = 50,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
@@ -130,6 +131,9 @@ def get_chat_sessions_by_user(
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
if limit:
stmt = stmt.limit(limit)
result = db_session.execute(stmt)
chat_sessions = result.scalars().all()
@@ -393,6 +397,34 @@ def get_or_create_root_message(
return new_root_message
def reserve_message_id(
db_session: Session,
chat_session_id: int,
parent_message: int,
message_type: MessageType,
) -> int:
# Create an empty chat message
empty_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message,
latest_child_message=None,
message="",
token_count=0,
message_type=message_type,
)
# Add the empty message to the session
db_session.add(empty_message)
# Flush the session to get an ID for the new chat message
db_session.flush()
# Get the ID of the newly created message
new_id = empty_message.id
return new_id
def create_new_chat_message(
chat_session_id: int,
parent_message: ChatMessage,
@@ -410,29 +442,51 @@ def create_new_chat_message(
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
) -> ChatMessage:
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
if reserved_message_id is not None:
# Edit existing message
existing_message = db_session.query(ChatMessage).get(reserved_message_id)
if existing_message is None:
raise ValueError(f"No message found with id {reserved_message_id}")
existing_message.chat_session_id = chat_session_id
existing_message.parent_message = parent_message.id
existing_message.message = message
existing_message.rephrased_query = rephrased_query
existing_message.prompt_id = prompt_id
existing_message.token_count = token_count
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
new_chat_message = existing_message
else:
# Create new message
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
db_session.add(new_chat_message)
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
if reference_docs:
new_chat_message.search_docs = reference_docs
db_session.add(new_chat_message)
# Flush the session to get an ID for the new chat message
db_session.flush()
@@ -541,11 +595,11 @@ def get_doc_query_identifiers_from_model(
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[SearchDoc],
relevance_summary: LLMRelevanceSummaryResponse,
relevance_summary: DocumentRelevance,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
f"{search_doc.document_id}-{search_doc.chunk_ind}"
search_doc.document_id
)
if relevance_data is not None:
db_session.execute(

View File

@@ -1,7 +1,7 @@
from typing import cast
from fastapi import HTTPException
from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import aliased
@@ -11,6 +11,7 @@ from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.server.documents.models import ConnectorBase
from danswer.server.documents.models import ObjectCreationIdResponse
@@ -20,19 +21,24 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_connectors_exist(db_session: Session) -> bool:
# Connector 0 is created on server startup as a default for ingestion
# it will always exist and we don't need to count it for this
stmt = select(exists(Connector).where(Connector.id > 0))
result = db_session.execute(stmt)
return result.scalar() or False
def fetch_connectors(
db_session: Session,
sources: list[DocumentSource] | None = None,
input_types: list[InputType] | None = None,
disabled_status: bool | None = None,
) -> list[Connector]:
stmt = select(Connector)
if sources is not None:
stmt = stmt.where(Connector.source.in_(sources))
if input_types is not None:
stmt = stmt.where(Connector.input_type.in_(input_types))
if disabled_status is not None:
stmt = stmt.where(Connector.disabled == disabled_status)
results = db_session.scalars(stmt)
return list(results.all())
@@ -85,10 +91,8 @@ def create_connector(
input_type=connector_data.input_type,
connector_specific_config=connector_data.connector_specific_config,
refresh_freq=connector_data.refresh_freq,
prune_freq=connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ,
disabled=connector_data.disabled,
indexing_start=connector_data.indexing_start,
prune_freq=connector_data.prune_freq,
)
db_session.add(connector)
db_session.commit()
@@ -122,33 +126,18 @@ def update_connector(
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ
)
connector.disabled = connector_data.disabled
db_session.commit()
return connector
def disable_connector(
connector_id: int,
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
connector.disabled = True
db_session.commit()
return StatusResponse(
success=True, message="Connector deleted successfully", data=connector_id
)
def delete_connector(
connector_id: int,
db_session: Session,
) -> StatusResponse[int]:
"""Currently unused due to foreign key restriction from IndexAttempt
Use disable_connector instead"""
"""Only used in special cases (e.g. a connector is in a bad state and we need to delete it).
Be VERY careful using this, as it could lead to a bad state if not used correctly.
"""
connector = fetch_connector_by_id(connector_id, db_session)
if connector is None:
return StatusResponse(
@@ -179,11 +168,9 @@ def fetch_latest_index_attempt_by_connector(
latest_index_attempts: list[IndexAttempt] = []
if source:
connectors = fetch_connectors(
db_session, sources=[source], disabled_status=False
)
connectors = fetch_connectors(db_session, sources=[source])
else:
connectors = fetch_connectors(db_session, disabled_status=False)
connectors = fetch_connectors(db_session)
if not connectors:
return []
@@ -191,7 +178,8 @@ def fetch_latest_index_attempt_by_connector(
for connector in connectors:
latest_index_attempt = (
db_session.query(IndexAttempt)
.filter(IndexAttempt.connector_id == connector.id)
.join(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.connector_id == connector.id)
.order_by(IndexAttempt.time_updated.desc())
.first()
)
@@ -207,13 +195,11 @@ def fetch_latest_index_attempts_by_status(
) -> list[IndexAttempt]:
subquery = (
db_session.query(
IndexAttempt.connector_id,
IndexAttempt.credential_id,
IndexAttempt.connector_credential_pair_id,
IndexAttempt.status,
func.max(IndexAttempt.time_updated).label("time_updated"),
)
.group_by(IndexAttempt.connector_id)
.group_by(IndexAttempt.credential_id)
.group_by(IndexAttempt.connector_credential_pair_id)
.group_by(IndexAttempt.status)
.subquery()
)
@@ -223,12 +209,13 @@ def fetch_latest_index_attempts_by_status(
query = db_session.query(IndexAttempt).join(
alias,
and_(
IndexAttempt.connector_id == alias.connector_id,
IndexAttempt.credential_id == alias.credential_id,
IndexAttempt.connector_credential_pair_id
== alias.connector_credential_pair_id,
IndexAttempt.status == alias.status,
IndexAttempt.time_updated == alias.time_updated,
),
)
return cast(list[IndexAttempt], query.all())
@@ -247,20 +234,29 @@ def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
def create_initial_default_connector(db_session: Session) -> None:
default_connector_id = 0
default_connector = fetch_connector_by_id(default_connector_id, db_session)
if default_connector is not None:
if (
default_connector.source != DocumentSource.INGESTION_API
or default_connector.input_type != InputType.LOAD_STATE
or default_connector.refresh_freq is not None
or default_connector.disabled
or default_connector.name != "Ingestion API"
or default_connector.connector_specific_config != {}
or default_connector.prune_freq is not None
):
raise ValueError(
"DB is not in a valid initial state. "
"Default connector does not have expected values."
logger.warning(
"Default connector does not have expected values. Updating to proper state."
)
# Ensure default connector has correct valuesg
default_connector.source = DocumentSource.INGESTION_API
default_connector.input_type = InputType.LOAD_STATE
default_connector.refresh_freq = None
default_connector.name = "Ingestion API"
default_connector.connector_specific_config = {}
default_connector.prune_freq = None
db_session.commit()
return
# Create a new default connector if it doesn't exist
connector = Connector(
id=default_connector_id,
name="Ingestion API",

View File

@@ -6,8 +6,10 @@ from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
@@ -25,7 +27,9 @@ def get_connector_credential_pairs(
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair)
if not include_disabled:
stmt = stmt.where(ConnectorCredentialPair.connector.disabled == False) # noqa
stmt = stmt.where(
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
) # noqa
results = db_session.scalars(stmt)
return list(results.all())
@@ -42,6 +46,17 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()
def get_connector_credential_source_from_id(
cc_pair_id: int,
db_session: Session,
) -> DocumentSource | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
cc_pair = result.scalar_one_or_none()
return cc_pair.connector.source if cc_pair else None
def get_connector_credential_pair_from_id(
cc_pair_id: int,
db_session: Session,
@@ -75,26 +90,78 @@ def get_last_successful_attempt_time(
# For Secondary Index we don't keep track of the latest success, so have to calculate it live
attempt = (
db_session.query(IndexAttempt)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.filter(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model.id,
IndexAttempt.status == IndexingStatus.SUCCESS,
)
.order_by(IndexAttempt.time_started.desc())
.first()
)
if not attempt or not attempt.time_started:
connector = fetch_connector_by_id(connector_id, db_session)
if connector and connector.indexing_start:
return connector.indexing_start.timestamp()
return 0.0
return attempt.time_started.timestamp()
"""Updates"""
def _update_connector_credential_pair(
db_session: Session,
cc_pair: ConnectorCredentialPair,
status: ConnectorCredentialPairStatus | None = None,
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
# simply don't update last_successful_index_time if run_dt is not specified
# at worst, this would result in re-indexing documents that were already indexed
if run_dt is not None:
cc_pair.last_successful_index_time = run_dt
if net_docs is not None:
cc_pair.total_docs_indexed += net_docs
if status is not None:
cc_pair.status = status
db_session.commit()
def update_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
status: ConnectorCredentialPairStatus | None = None,
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
logger.warning(
f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}'"
f" but it does not exist"
)
return
_update_connector_credential_pair(
db_session=db_session,
cc_pair=cc_pair,
status=status,
net_docs=net_docs,
run_dt=run_dt,
)
def update_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
status: ConnectorCredentialPairStatus | None = None,
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
@@ -105,13 +172,14 @@ def update_connector_credential_pair(
f"and credential id {credential_id}"
)
return
# simply don't update last_successful_index_time if run_dt is not specified
# at worst, this would result in re-indexing documents that were already indexed
if run_dt is not None:
cc_pair.last_successful_index_time = run_dt
if net_docs is not None:
cc_pair.total_docs_indexed += net_docs
db_session.commit()
_update_connector_credential_pair(
db_session=db_session,
cc_pair=cc_pair,
status=status,
net_docs=net_docs,
run_dt=run_dt,
)
def delete_connector_credential_pair__no_commit(
@@ -142,6 +210,8 @@ def associate_default_cc_pair(db_session: Session) -> None:
connector_id=0,
credential_id=0,
name="DefaultCCPair",
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=True,
)
db_session.add(association)
db_session.commit()
@@ -186,6 +256,7 @@ def add_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
name=cc_pair_name,
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=is_public,
)
db_session.add(association)
@@ -241,6 +312,12 @@ def remove_credential_from_connector(
)
def fetch_connector_credential_pairs(
db_session: Session,
) -> list[ConnectorCredentialPair]:
return db_session.query(ConnectorCredentialPair).all()
def resync_cc_pair(
cc_pair: ConnectorCredentialPair,
db_session: Session,
@@ -253,10 +330,14 @@ def resync_cc_pair(
) -> IndexAttempt | None:
query = (
db_session.query(IndexAttempt)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
.filter(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
EmbeddingModel.status == IndexModelStatus.PRESENT,
)
)

View File

@@ -2,10 +2,13 @@ from typing import Any
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_
from sqlalchemy.sql.expression import or_
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@@ -14,8 +17,10 @@ from danswer.connectors.google_drive.constants import (
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger
@@ -74,6 +79,69 @@ def fetch_credential_by_id(
return credential
def fetch_credentials_by_source(
db_session: Session,
user: User | None,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
base_query = _attach_user_filters(base_query, user)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
def swap_credentials_connector(
new_credential_id: int, connector_id: int, user: User | None, db_session: Session
) -> ConnectorCredentialPair:
# Check if the user has permission to use the new credential
new_credential = fetch_credential_by_id(new_credential_id, user, db_session)
if not new_credential:
raise ValueError(
f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
)
# Existing pair
existing_pair = db_session.execute(
select(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id
)
).scalar_one_or_none()
if not existing_pair:
raise ValueError(
f"No ConnectorCredentialPair found for connector_id {connector_id}"
)
# Check if the new credential is compatible with the connector
if new_credential.source != existing_pair.connector.source:
raise ValueError(
f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}"
)
db_session.execute(
update(DocumentByConnectorCredentialPair)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id
== existing_pair.credential_id,
)
)
.values(credential_id=new_credential_id)
)
# Update the existing pair with the new credential
existing_pair.credential_id = new_credential_id
existing_pair.credential = new_credential
# Commit the changes
db_session.commit()
# Refresh the object to ensure all relationships are up-to-date
db_session.refresh(existing_pair)
return existing_pair
def create_credential(
credential_data: CredentialBase,
user: User | None,
@@ -83,6 +151,8 @@ def create_credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
admin_public=credential_data.admin_public,
source=credential_data.source,
name=credential_data.name,
)
db_session.add(credential)
db_session.commit()
@@ -90,6 +160,28 @@ def create_credential(
return credential
def alter_credential(
credential_id: int,
credential_data: CredentialDataUpdateRequest,
user: User,
db_session: Session,
) -> Credential | None:
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
return None
credential.name = credential_data.name
# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value
credential.user_id = user.id if user is not None else None
db_session.commit()
return credential
def update_credential(
credential_id: int,
credential_data: CredentialBase,
@@ -136,6 +228,7 @@ def delete_credential(
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
@@ -149,11 +242,38 @@ def delete_credential(
.all()
)
if associated_connectors:
raise ValueError(
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
"Please delete all associated connectors first."
)
associated_doc_cc_pairs = (
db_session.query(DocumentByConnectorCredentialPair)
.filter(DocumentByConnectorCredentialPair.credential_id == credential_id)
.all()
)
if associated_connectors or associated_doc_cc_pairs:
if force:
logger.warning(
f"Force deleting credential {credential_id} and its associated records"
)
# Delete DocumentByConnectorCredentialPair records first
for doc_cc_pair in associated_doc_cc_pairs:
db_session.delete(doc_cc_pair)
# Then delete ConnectorCredentialPair records
for connector in associated_connectors:
db_session.delete(connector)
# Commit these deletions before deleting the credential
db_session.flush()
else:
raise ValueError(
f"Cannot delete credential as it is still associated with "
f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s). "
)
if force:
logger.warning(f"Force deleting credential {credential_id}")
else:
logger.notice(f"Deleting credential {credential_id}")
db_session.delete(credential)
db_session.commit()

View File

@@ -1,6 +1,7 @@
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexingStatus
@@ -13,7 +14,7 @@ def check_deletion_attempt_is_allowed(
) -> str | None:
"""
To be deletable:
(1) connector should be disabled
(1) connector should be paused
(2) there should be no in-progress/planned index attempts
Returns an error message if the deletion attempt is not allowed, otherwise None.
@@ -23,7 +24,10 @@ def check_deletion_attempt_is_allowed(
f"'{connector_credential_pair.credential_id}' is not deletable."
)
if not connector_credential_pair.connector.disabled:
if (
connector_credential_pair.status != ConnectorCredentialPairStatus.PAUSED
and connector_credential_pair.status != ConnectorCredentialPairStatus.DELETING
):
return base_error_msg + " Connector must be paused."
connector_id = connector_credential_pair.connector_id

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
@@ -16,6 +17,7 @@ from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
@@ -30,6 +32,12 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_docs_exist(db_session: Session) -> bool:
stmt = select(exists(DbDocument))
result = db_session.execute(stmt)
return result.scalar() or False
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -103,36 +111,19 @@ def get_document_cnts_for_cc_pairs(
def get_acccess_info_for_documents(
db_session: Session,
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> Sequence[tuple[str, list[UUID | None], bool]]:
"""Gets back all relevant access info for the given documents. This includes
the user_ids for cc pairs that the document is associated with + whether any
of the associated cc pairs are intending to make the document globally public.
If `cc_pair_to_delete` is specified, gets the above access info as if that
pair had been deleted. This is needed since we want to delete from the Vespa
before deleting from Postgres to ensure that the state of Postgres never "loses"
documents that still exist in Vespa.
"""
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
# pretend that the specified cc pair doesn't exist
if cc_pair_to_delete:
stmt = stmt.where(
and_(
DocumentByConnectorCredentialPair.connector_id
!= cc_pair_to_delete.connector_id,
DocumentByConnectorCredentialPair.credential_id
!= cc_pair_to_delete.credential_id,
)
)
stmt = (
stmt.join(
select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
)
@@ -145,6 +136,9 @@ def get_acccess_info_for_documents(
== ConnectorCredentialPair.credential_id,
),
)
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
.group_by(DocumentByConnectorCredentialPair.id)
)
return db_session.execute(stmt).all() # type: ignore
@@ -311,7 +305,7 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
_NUM_LOCK_ATTEMPTS = 10
_LOCK_RETRY_DELAY = 30
_LOCK_RETRY_DELAY = 10
@contextlib.contextmanager
@@ -323,7 +317,7 @@ def prepare_to_modify_documents(
called ahead of any modification to Vespa. Locks should be released by the
caller as soon as updates are complete by finishing the transaction.
NOTE: only one commit is allowed within the context manager returned by this funtion.
NOTE: only one commit is allowed within the context manager returned by this function.
Multiple commits will result in a sqlalchemy.exc.InvalidRequestError.
NOTE: this function will commit any existing transaction.
"""
@@ -341,7 +335,9 @@ def prepare_to_modify_documents(
yield transaction
break
except OperationalError as e:
logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
logger.warning(
f"Failed to acquire locks for documents, retrying. Error: {e}"
)
time.sleep(retry_delay)

View File

@@ -9,6 +9,7 @@ from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
@@ -270,37 +271,20 @@ def mark_document_set_as_to_be_deleted(
raise
def mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
cc_pair_id: int, db_session: Session
) -> set[int]:
"""Marks all CC Pair -> Document Set relationships for the specified
`cc_pair_id` as not current and returns the list of all document set IDs
affected.
NOTE: rases a `ValueError` if any of the document sets are currently syncing
to avoid getting into a bad state."""
document_set__cc_pair_relationships = db_session.scalars(
select(DocumentSet__ConnectorCredentialPair).where(
def delete_document_set_cc_pair_relationship__no_commit(
connector_id: int, credential_id: int, db_session: Session
) -> None:
"""Deletes all rows from DocumentSet__ConnectorCredentialPair where the
connector_credential_pair_id matches the given cc_pair_id."""
delete_stmt = delete(DocumentSet__ConnectorCredentialPair).where(
and_(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
== cc_pair_id
== ConnectorCredentialPair.id,
)
).all()
document_set_ids_touched: set[int] = set()
for document_set__cc_pair_relationship in document_set__cc_pair_relationships:
document_set__cc_pair_relationship.is_current = False
if not document_set__cc_pair_relationship.document_set.is_up_to_date:
raise ValueError(
"Cannot delete CC pair while it is attached to a document set "
"that is syncing. Please wait for the document set to finish "
"syncing, and then try again."
)
document_set__cc_pair_relationship.document_set.is_up_to_date = False
document_set_ids_touched.add(document_set__cc_pair_relationship.document_set_id)
return document_set_ids_touched
)
db_session.execute(delete_stmt)
def fetch_document_sets(
@@ -431,8 +415,10 @@ def fetch_documents_for_document_set_paginated(
def fetch_document_sets_for_documents(
document_ids: list[str], db_session: Session
document_ids: list[str],
db_session: Session,
) -> Sequence[tuple[str, list[str]]]:
"""Gives back a list of (document_id, list[document_set_names]) tuples"""
stmt = (
select(Document.id, func.array_agg(DocumentSetDBModel.name))
.join(
@@ -459,6 +445,10 @@ def fetch_document_sets_for_documents(
Document.id == DocumentByConnectorCredentialPair.id,
)
.where(Document.id.in_(document_ids))
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
# as we can assume their document sets are no longer relevant
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
.where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712
.group_by(Document.id)
)

View File

@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

View File

@@ -1,9 +1,11 @@
import contextlib
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import ContextManager
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
@@ -13,11 +15,16 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -25,12 +32,70 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME # helps to diagnose open connections in postgres
)
# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None
SessionFactory: sessionmaker[Session] | None = None
if LOG_POSTGRES_LATENCY:
# Function to log before query execution
@event.listens_for(Engine, "before_cursor_execute")
def before_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
conn.info["query_start_time"] = time.time()
# Function to log after query execution
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute( # type: ignore
conn, cursor, statement, parameters, context, executemany
):
total_time = time.time() - conn.info["query_start_time"]
# don't spam TOO hard
if total_time > 0.1:
logger.debug(
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
)
if LOG_POSTGRES_CONN_COUNTS:
# Global counter for connection checkouts and checkins
checkout_count = 0
checkin_count = 0
@event.listens_for(Engine, "checkout")
def log_checkout(dbapi_connection, connection_record, connection_proxy): # type: ignore
global checkout_count
checkout_count += 1
active_connections = connection_proxy._pool.checkedout()
idle_connections = connection_proxy._pool.checkedin()
pool_size = connection_proxy._pool.size()
logger.debug(
"Connection Checkout\n"
f"Active Connections: {active_connections};\n"
f"Idle: {idle_connections};\n"
f"Pool Size: {pool_size};\n"
f"Total connection checkouts: {checkout_count}"
)
@event.listens_for(Engine, "checkin")
def log_checkin(dbapi_connection, connection_record): # type: ignore
global checkin_count
checkin_count += 1
logger.debug(f"Total connection checkins: {checkin_count}")
"""END DEBUGGING LOGGING"""
def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
@@ -51,24 +116,50 @@ def build_connection_string(
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(db_api=SYNC_DB_API)
_SYNC_ENGINE = create_engine(connection_string, pool_size=40, max_overflow=10)
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _SYNC_ENGINE
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string, pool_size=40, max_overflow=10
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
},
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
return _ASYNC_ENGINE
@@ -93,7 +184,7 @@ async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def warm_up_connections(
sync_connections_to_warm_up: int = 10, async_connections_to_warm_up: int = 10
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
sync_postgres_engine = get_sqlalchemy_engine()
connections = [
@@ -115,4 +206,8 @@ async def warm_up_connections(
await async_conn.close()
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory

View File

@@ -6,6 +6,15 @@ class IndexingStatus(str, PyEnum):
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
COMPLETED_WITH_ERRORS = "completed_with_errors"
def is_terminal(self) -> bool:
terminal_states = {
IndexingStatus.SUCCESS,
IndexingStatus.COMPLETED_WITH_ERRORS,
IndexingStatus.FAILED,
}
return self in terminal_states
# these may differ in the future, which is why we're okay with this duplication
@@ -33,3 +42,9 @@ class IndexModelStatus(str, PyEnum):
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
class ConnectorCredentialPairStatus(str, PyEnum):
ACTIVE = "ACTIVE"
PAUSED = "PAUSED"
DELETING = "DELETING"

View File

@@ -1,20 +1,22 @@
from collections.abc import Sequence
from sqlalchemy import and_
from sqlalchemy import ColumnElement
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.connectors.models import Document
from danswer.connectors.models import DocumentErrorSummary
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexAttemptError
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
@@ -23,6 +25,22 @@ from danswer.utils.telemetry import RecordType
logger = setup_logger()
def get_last_attempt_for_cc_pair(
cc_pair_id: int,
embedding_model_id: int,
db_session: Session,
) -> IndexAttempt | None:
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.embedding_model_id == embedding_model_id,
)
.order_by(IndexAttempt.time_updated.desc())
.first()
)
def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
@@ -31,15 +49,13 @@ def get_index_attempt(
def create_index_attempt(
connector_id: int,
credential_id: int,
connector_credential_pair_id: int,
embedding_model_id: int,
db_session: Session,
from_beginning: bool = False,
) -> int:
new_attempt = IndexAttempt(
connector_id=connector_id,
credential_id=credential_id,
connector_credential_pair_id=connector_credential_pair_id,
embedding_model_id=embedding_model_id,
from_beginning=from_beginning,
status=IndexingStatus.NOT_STARTED,
@@ -56,7 +72,9 @@ def get_inprogress_index_attempts(
) -> list[IndexAttempt]:
stmt = select(IndexAttempt)
if connector_id is not None:
stmt = stmt.where(IndexAttempt.connector_id == connector_id)
stmt = stmt.where(
IndexAttempt.connector_credential_pair.has(connector_id=connector_id)
)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.IN_PROGRESS)
incomplete_attempts = db_session.scalars(stmt)
@@ -65,21 +83,31 @@ def get_inprogress_index_attempts(
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage"""
before running long-living indexing jobs, which causes increasing memory usage.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.connector
),
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.credential
),
)
new_attempts = db_session.scalars(stmt)
return list(new_attempts.all())
def mark_attempt_in_progress__no_commit(
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.IN_PROGRESS
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
def mark_attempt_succeeded(
@@ -91,6 +119,15 @@ def mark_attempt_succeeded(
db_session.commit()
def mark_attempt_partially_succeeded(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.add(index_attempt)
db_session.commit()
def mark_attempt_failed(
index_attempt: IndexAttempt,
db_session: Session,
@@ -103,7 +140,7 @@ def mark_attempt_failed(
db_session.add(index_attempt)
db_session.commit()
source = index_attempt.connector.source
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
@@ -128,11 +165,16 @@ def get_last_attempt(
embedding_model_id: int | None,
db_session: Session,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
)
)
# Note, the below is using time_created instead of time_updated
stmt = stmt.order_by(desc(IndexAttempt.time_created))
@@ -140,14 +182,12 @@ def get_last_attempt(
def get_latest_index_attempts(
connector_credential_pair_identifiers: list[ConnectorCredentialPairIdentifier],
secondary_index: bool,
db_session: Session,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_id,
IndexAttempt.credential_id,
func.max(IndexAttempt.time_created).label("max_time_created"),
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
if secondary_index:
@@ -155,46 +195,87 @@ def get_latest_index_attempts(
else:
ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.PRESENT)
where_stmts: list[ColumnElement] = []
for connector_credential_pair_identifier in connector_credential_pair_identifiers:
where_stmts.append(
and_(
IndexAttempt.connector_id
== connector_credential_pair_identifier.connector_id,
IndexAttempt.credential_id
== connector_credential_pair_identifier.credential_id,
)
)
if where_stmts:
ids_stmt = ids_stmt.where(or_(*where_stmts))
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_id, IndexAttempt.credential_id)
ids_subqery = ids_stmt.subquery()
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
stmt = (
select(IndexAttempt)
.join(
ids_subqery,
and_(
ids_subqery.c.connector_id == IndexAttempt.connector_id,
ids_subqery.c.credential_id == IndexAttempt.credential_id,
),
ids_subquery,
IndexAttempt.connector_credential_pair_id
== ids_subquery.c.connector_credential_pair_id,
)
.where(IndexAttempt.time_created == ids_subqery.c.max_time_created)
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
return db_session.execute(stmt).scalars().all()
def get_index_attempts_for_connector(
db_session: Session,
connector_id: int,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(ConnectorCredentialPair.connector_id == connector_id)
)
if disinclude_finished:
stmt = stmt.where(
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
)
)
if only_current:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(IndexAttempt.time_created.desc())
return db_session.execute(stmt).scalars().all()
def get_latest_finished_index_attempt_for_cc_pair(
connector_credential_pair_id: int,
secondary_index: bool,
db_session: Session,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def get_index_attempts_for_cc_pair(
db_session: Session,
cc_pair_identifier: ConnectorCredentialPairIdentifier,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
stmt = select(IndexAttempt).where(
and_(
IndexAttempt.connector_id == cc_pair_identifier.connector_id,
IndexAttempt.credential_id == cc_pair_identifier.credential_id,
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(
and_(
ConnectorCredentialPair.connector_id == cc_pair_identifier.connector_id,
ConnectorCredentialPair.credential_id
== cc_pair_identifier.credential_id,
)
)
)
if disinclude_finished:
@@ -218,9 +299,11 @@ def delete_index_attempts(
db_session: Session,
) -> None:
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
)
db_session.execute(stmt)
@@ -249,14 +332,15 @@ def expire_index_attempts(
db_session.commit()
def cancel_indexing_attempts_for_connector(
connector_id: int,
def cancel_indexing_attempts_for_ccpair(
cc_pair_id: int,
db_session: Session,
include_secondary_index: bool = False,
) -> None:
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
stmt = (
delete(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
)
if not include_secondary_index:
@@ -273,6 +357,8 @@ def cancel_indexing_attempts_for_connector(
def cancel_indexing_attempts_past_model(
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
any embedding model that not present/future"""
db_session.execute(
update(IndexAttempt)
.where(
@@ -296,7 +382,8 @@ def count_unique_cc_pairs_with_successful_index_attempts(
Then do distinct by connector_id and credential_id which is equivalent to the cc-pair. Finally,
do a count to get the total number of unique cc-pairs with successful attempts"""
unique_pairs_count = (
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
db_session.query(IndexAttempt.connector_credential_pair_id)
.join(ConnectorCredentialPair)
.filter(
IndexAttempt.embedding_model_id == embedding_model_id,
IndexAttempt.status == IndexingStatus.SUCCESS,
@@ -306,3 +393,41 @@ def count_unique_cc_pairs_with_successful_index_attempts(
)
return unique_pairs_count
def create_index_attempt_error(
index_attempt_id: int | None,
batch: int | None,
docs: list[Document],
exception_msg: str,
exception_traceback: str,
db_session: Session,
) -> int:
doc_summaries = []
for doc in docs:
doc_summary = DocumentErrorSummary.from_document(doc)
doc_summaries.append(doc_summary.to_dict())
new_error = IndexAttemptError(
index_attempt_id=index_attempt_id,
batch=batch,
doc_summaries=doc_summaries,
error_msg=exception_msg,
traceback=exception_traceback,
)
db_session.add(new_error)
db_session.commit()
return new_error.id
def get_index_attempt_errors(
index_attempt_id: int,
db_session: Session,
) -> list[IndexAttemptError]:
stmt = select(IndexAttemptError).where(
IndexAttemptError.index_attempt_id == index_attempt_id
)
errors = db_session.scalars(stmt)
return list(errors.all())

View File

@@ -0,0 +1,202 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()
return input_prompt
def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")
input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active
db_session.commit()
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")
db_session.delete(input_prompt)
db_session.commit()
def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")
db_session.delete(input_prompt)
db_session.commit()
def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)
if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa
result = db_session.scalar(query)
if result is None:
raise HTTPException(422, "No input prompt found")
return result
def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())
def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)
if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)
return list(db_session.scalars(query).all())

View File

@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")
# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)
def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,36 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,
if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
existing_llm_provider.display_model_names = llm_provider.display_model_names
if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
db_session.add(llm_provider_model)
db_session.commit()
return FullLLMProvider.from_model(llm_provider_model)
return FullLLMProvider.from_model(existing_llm_provider)
def fetch_existing_embedding_providers(
@@ -74,8 +100,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_( # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
)
),
)
stmt = stmt.where(access_conditions)
return list(db_session.scalars(stmt).all())
def fetch_embedding_provider(
@@ -119,6 +166,13 @@ def remove_embedding_provider(
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)

View File

@@ -11,6 +11,7 @@ from uuid import UUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import Enum
@@ -37,10 +38,12 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
from danswer.connectors.models import InputType
from danswer.db.enums import ChatSessionSharedStatus
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
@@ -50,9 +53,9 @@ from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from shared_configs.enums import EmbeddingProvider
class Base(DeclarativeBase):
@@ -117,9 +120,17 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.ARRAY(Integer), nullable=True
postgresql.JSONB(), nullable=True
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
default_model: Mapped[str] = mapped_column(Text, nullable=True)
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
@@ -132,10 +143,41 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
# Custom tools created by this user
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
# Notifications for the UI
notifications: Mapped[list["Notification"]] = relationship(
"Notification", back_populates="user"
)
class InputPrompt(Base):
__tablename__ = "inputprompt"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@@ -157,6 +199,24 @@ class ApiKey(Base):
DateTime(timezone=True), server_default=func.now()
)
# Add this relationship to access the User object via user_id
user: Mapped["User"] = relationship("User", foreign_keys=[user_id])
class Notification(Base):
__tablename__ = "notification"
id: Mapped[int] = mapped_column(primary_key=True)
notif_type: Mapped[NotificationType] = mapped_column(
Enum(NotificationType, native_enum=False)
)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
user: Mapped[User] = relationship("User", back_populates="notifications")
"""
Association Tables
@@ -184,7 +244,9 @@ class Persona__User(Base):
__tablename__ = "persona__user"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
)
class DocumentSet__User(Base):
@@ -193,7 +255,9 @@ class DocumentSet__User(Base):
document_set_id: Mapped[int] = mapped_column(
ForeignKey("document_set.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
)
class DocumentSet__ConnectorCredentialPair(Base):
@@ -301,6 +365,9 @@ class ConnectorCredentialPair(Base):
nullable=False,
)
name: Mapped[str] = mapped_column(String, nullable=False)
status: Mapped[ConnectorCredentialPairStatus] = mapped_column(
Enum(ConnectorCredentialPairStatus, native_enum=False), nullable=False
)
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
)
@@ -337,6 +404,9 @@ class ConnectorCredentialPair(Base):
back_populates="connector_credential_pairs",
overlaps="document_set",
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="connector_credential_pair"
)
class Document(Base):
@@ -416,6 +486,9 @@ class Connector(Base):
connector_specific_config: Mapped[dict[str, Any]] = mapped_column(
postgresql.JSONB()
)
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
@@ -424,7 +497,6 @@ class Connector(Base):
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
disabled: Mapped[bool] = mapped_column(Boolean, default=False)
credentials: Mapped[list["ConnectorCredentialPair"]] = relationship(
"ConnectorCredentialPair",
@@ -434,14 +506,17 @@ class Connector(Base):
documents_by_connector: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="connector"
)
class Credential(Base):
__tablename__ = "credential"
name: Mapped[str] = mapped_column(String, nullable=True)
source: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
@@ -462,9 +537,7 @@ class Credential(Base):
documents_by_credential: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="credential"
)
user: Mapped[User | None] = relationship("User", back_populates="credentials")
@@ -516,12 +589,16 @@ class EmbeddingModel(Base):
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider else None
def provider_type(self) -> EmbeddingProvider | None:
return (
EmbeddingProvider(self.cloud_provider.name.lower())
if self.cloud_provider is not None
else None
)
@property
def provider_type(self) -> str | None:
return self.cloud_provider.name if self.cloud_provider else None
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
class IndexAttempt(Base):
@@ -534,13 +611,10 @@ class IndexAttempt(Base):
__tablename__ = "index_attempt"
id: Mapped[int] = mapped_column(primary_key=True)
connector_id: Mapped[int | None] = mapped_column(
ForeignKey("connector.id"),
nullable=True,
)
credential_id: Mapped[int | None] = mapped_column(
ForeignKey("credential.id"),
nullable=True,
connector_credential_pair_id: Mapped[int] = mapped_column(
ForeignKey("connector_credential_pair.id"),
nullable=False,
)
# Some index attempts that run from beginning will still have this as False
@@ -578,21 +652,20 @@ class IndexAttempt(Base):
onupdate=func.now(),
)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="index_attempts"
)
credential: Mapped[Credential] = relationship(
"Credential", back_populates="index_attempts"
connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
"ConnectorCredentialPair", back_populates="index_attempts"
)
embedding_model: Mapped[EmbeddingModel] = relationship(
"EmbeddingModel", back_populates="index_attempts"
)
error_rows = relationship("IndexAttemptError", back_populates="index_attempt")
__table_args__ = (
Index(
"ix_index_attempt_latest_for_connector_credential_pair",
"connector_id",
"credential_id",
"connector_credential_pair_id",
"time_created",
),
)
@@ -600,13 +673,59 @@ class IndexAttempt(Base):
def __repr__(self) -> str:
return (
f"<IndexAttempt(id={self.id!r}, "
f"connector_id={self.connector_id!r}, "
f"status={self.status!r}, "
f"error_msg={self.error_msg!r})>"
f"time_created={self.time_created!r}, "
f"time_updated={self.time_updated!r}, "
)
def is_finished(self) -> bool:
return self.status.is_terminal()
class IndexAttemptError(Base):
"""
Represents an error that was encountered during an IndexAttempt.
"""
__tablename__ = "index_attempt_errors"
id: Mapped[int] = mapped_column(primary_key=True)
index_attempt_id: Mapped[int] = mapped_column(
ForeignKey("index_attempt.id"),
nullable=True,
)
# The index of the batch where the error occurred (if looping thru batches)
# Just informational.
batch: Mapped[int | None] = mapped_column(Integer, default=None)
doc_summaries: Mapped[list[Any]] = mapped_column(postgresql.JSONB())
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
traceback: Mapped[str | None] = mapped_column(Text, default=None)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
# This is the reverse side of the relationship
index_attempt = relationship("IndexAttempt", back_populates="error_rows")
__table_args__ = (
Index(
"index_attempt_id",
"time_created",
),
)
def __repr__(self) -> str:
return (
f"<IndexAttempt(id={self.id!r}, "
f"index_attempt_id={self.index_attempt_id!r}, "
f"error_msg={self.error_msg!r})>"
f"time_created={self.time_created!r}, "
)
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential pair"""
@@ -821,6 +940,8 @@ class ChatMessage(Base):
secondary="chat_message__search_doc",
back_populates="chat_messages",
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
@@ -923,6 +1044,11 @@ class LLMProvider(Base):
default_model_name: Mapped[str] = mapped_column(String)
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
# Models to actually disp;aly to users
# If nulled out, we assume in the application logic we should present all
display_model_names: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# The LLMs that are available for this provider. Only required if not a default provider.
# If a default provider, then the LLM options are pulled from the `options.py` file.
# If needed, can be pulled out as a separate table in the future.
@@ -932,6 +1058,13 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="llm_provider__user_group",
viewonly=True,
)
class CloudEmbeddingProvider(Base):
@@ -1071,10 +1204,6 @@ class Persona(Base):
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
# Currently stored but unused, all flows use hybrid
search_type: Mapped[SearchType] = mapped_column(
Enum(SearchType, native_enum=False), default=SearchType.HYBRID
)
# Number of chunks to pass to the LLM for generation.
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
# Pass every chunk through LLM for evaluation, fairly expensive
@@ -1107,9 +1236,14 @@ class Persona(Base):
# controls the ordering of personas in the UI
# higher priority personas are displayed first, ties are resolved by the ID,
# where lower value IDs (e.g. created earlier) are displayed first
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
display_priority: Mapped[int | None] = mapped_column(
Integer, nullable=True, default=None
)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
icon_color: Mapped[str | None] = mapped_column(String, nullable=True)
icon_shape: Mapped[int | None] = mapped_column(Integer, nullable=True)
# These are only defaults, users can select from all if desired
prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1271,7 @@ class Persona(Base):
viewonly=True,
)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
"UserGroup",
secondary="persona__user_group",
@@ -1323,7 +1458,9 @@ class User__UserGroup(Base):
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
)
class UserGroup__ConnectorCredentialPair(Base):
@@ -1360,6 +1497,17 @@ class Persona__UserGroup(Base):
)
class LLMProvider__UserGroup(Base):
__tablename__ = "llm_provider__user_group"
llm_provider_id: Mapped[int] = mapped_column(
ForeignKey("llm_provider.id"), primary_key=True
)
user_group_id: Mapped[int] = mapped_column(
ForeignKey("user_group.id"), primary_key=True
)
class DocumentSet__UserGroup(Base):
__tablename__ = "document_set__user_group"

View File

@@ -0,0 +1,76 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from danswer.configs.constants import NotificationType
from danswer.db.models import Notification
from danswer.db.models import User
def create_notification(
user: User | None,
notif_type: NotificationType,
db_session: Session,
) -> Notification:
notification = Notification(
user_id=user.id if user else None,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),
first_shown=func.now(),
)
db_session.add(notification)
db_session.commit()
return notification
def get_notification_by_id(
notification_id: int, user: User | None, db_session: Session
) -> Notification:
user_id = user.id if user else None
notif = db_session.get(Notification, notification_id)
if not notif:
raise ValueError(f"No notification found with id {notification_id}")
if notif.user_id != user_id:
raise PermissionError(
f"User {user_id} is not authorized to access notification {notification_id}"
)
return notif
def get_notifications(
user: User | None,
db_session: Session,
notif_type: NotificationType | None = None,
include_dismissed: bool = True,
) -> list[Notification]:
query = select(Notification).where(
Notification.user_id == user.id if user else Notification.user_id.is_(None)
)
if not include_dismissed:
query = query.where(Notification.dismissed.is_(False))
if notif_type:
query = query.where(Notification.notif_type == notif_type)
return list(db_session.execute(query).scalars().all())
def dismiss_all_notifications(
notif_type: NotificationType,
db_session: Session,
) -> None:
db_session.query(Notification).filter(Notification.notif_type == notif_type).update(
{"dismissed": True}
)
db_session.commit()
def dismiss_notification(notification: Notification, db_session: Session) -> None:
notification.dismissed = True
db_session.commit()
def update_notification_last_shown(
notification: Notification, db_session: Session
) -> None:
notification.last_shown = func.now()
db_session.commit()

View File

@@ -9,6 +9,8 @@ from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
@@ -24,6 +26,7 @@ from danswer.db.models import StarterMessage
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
@@ -62,6 +65,7 @@ def create_update_persona(
) -> PersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
try:
persona = upsert_persona(
persona_id=persona_id,
@@ -80,6 +84,10 @@ def create_update_persona(
starter_messages=create_persona_request.starter_messages,
is_public=create_persona_request.is_public,
db_session=db_session,
icon_color=create_persona_request.icon_color,
icon_shape=create_persona_request.icon_shape,
uploaded_image_id=create_persona_request.uploaded_image_id,
remove_image=create_persona_request.remove_image,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -162,6 +170,7 @@ def get_personas(
include_default: bool = True,
include_slack_bot_personas: bool = False,
include_deleted: bool = False,
joinedload_all: bool = False,
) -> Sequence[Persona]:
stmt = select(Persona).distinct()
if user_id is not None:
@@ -193,7 +202,16 @@ def get_personas(
if not include_deleted:
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.scalars(stmt).all()
if joinedload_all:
stmt = stmt.options(
joinedload(Persona.prompts),
joinedload(Persona.tools),
joinedload(Persona.document_sets),
joinedload(Persona.groups),
joinedload(Persona.users),
)
return db_session.execute(stmt).unique().scalars().all()
def mark_persona_as_deleted(
@@ -328,6 +346,12 @@ def upsert_persona(
persona_id: int | None = None,
default_persona: bool = False,
commit: bool = True,
icon_color: str | None = None,
icon_shape: int | None = None,
uploaded_image_id: str | None = None,
display_priority: int | None = None,
is_visible: bool = True,
remove_image: bool | None = None,
) -> Persona:
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
@@ -383,6 +407,12 @@ def upsert_persona(
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
persona.icon_color = icon_color
persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
# Do not delete any associations manually added unless
# a new updated list is provided
@@ -415,6 +445,11 @@ def upsert_persona(
llm_model_version_override=llm_model_version_override,
starter_messages=starter_messages,
tools=tools or [],
icon_shape=icon_shape,
icon_color=icon_color,
uploaded_image_id=uploaded_image_id,
display_priority=display_priority,
is_visible=is_visible,
)
db_session.add(persona)
@@ -548,6 +583,8 @@ def get_default_prompt__read_only() -> Prompt:
return _get_default_prompt(db_session)
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
persona_id: int,
# if user is `None` assume the user is an admin or auth is disabled
@@ -556,16 +593,38 @@ def get_persona_by_id(
include_deleted: bool = False,
is_for_edit: bool = True, # NOTE: assume true for safety
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
stmt = (
select(Persona)
.options(selectinload(Persona.users), selectinload(Persona.groups))
.where(Persona.id == persona_id)
)
or_conditions = []
# if user is an admin, they should have access to all Personas
# and will skip the following clause
if user is not None and user.role != UserRole.ADMIN:
or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
# the user is not an admin
isPersonaUnowned = Persona.user_id.is_(
None
) # allow access if persona user id is None
isUserCreator = (
Persona.user_id == user.id
) # allow access if user created the persona
or_conditions.extend([isPersonaUnowned, isUserCreator])
# if we aren't editing, also give access to all public personas
# if we aren't editing, also give access if:
# 1. the user is authorized for this persona
# 2. the user is in an authorized group for this persona
# 3. if the persona is public
if not is_for_edit:
isSharedWithUser = Persona.users.any(
id=user.id
) # allow access if user is in allowed users
isSharedWithGroup = Persona.groups.any(
UserGroup.users.any(id=user.id)
) # allow access if user is in any allowed group
or_conditions.extend([isSharedWithUser, isSharedWithGroup])
or_conditions.append(Persona.is_public.is_(True))
if or_conditions:

View File

@@ -1,5 +1,6 @@
from sqlalchemy.orm import Session
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import get_current_db_embedding_model
@@ -10,6 +11,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -52,6 +54,9 @@ def check_index_swap(db_session: Session) -> None:
)
if cc_pair_count > 0:
kv_store = get_dynamic_config_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)

View File

@@ -50,4 +50,11 @@ def get_uuid_from_chunk(
unique_identifier_string = "_".join(
[doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
)
if chunk.large_chunk_reference_ids:
unique_identifier_string += "_large" + "_".join(
[
str(referenced_chunk_id)
for referenced_chunk_id in chunk.large_chunk_reference_ids
]
)
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)

View File

@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from shared_configs.model_server_models import Embedding
@dataclass(frozen=True)
@@ -15,6 +16,25 @@ class DocumentInsertionRecord:
already_existed: bool
@dataclass(frozen=True)
class VespaChunkRequest:
document_id: str
min_chunk_ind: int | None = None
max_chunk_ind: int | None = None
@property
def is_capped(self) -> bool:
# If the max chunk index is not None, then the chunk request is capped
# If the min chunk index is None, we can assume the min is 0
return self.max_chunk_ind is not None
@property
def range(self) -> int | None:
if self.max_chunk_ind is not None:
return (self.max_chunk_ind - (self.min_chunk_ind or 0)) + 1
return None
@dataclass
class DocumentMetadata:
"""
@@ -182,10 +202,9 @@ class IdRetrievalCapable(abc.ABC):
@abc.abstractmethod
def id_based_retrieval(
self,
document_id: str,
min_chunk_ind: int | None,
max_chunk_ind: int | None,
user_access_control_list: list[str] | None = None,
chunk_requests: list[VespaChunkRequest],
filters: IndexFilters,
batch_retrieval: bool = False,
) -> list[InferenceChunkUncleaned]:
"""
Fetch chunk(s) based on document id
@@ -196,11 +215,9 @@ class IdRetrievalCapable(abc.ABC):
or extended section will have duplicate segments.
Parameters:
- document_id: document id for which to retrieve the chunk(s)
- min_chunk_ind: if None then fetch from the start of doc
- max_chunk_ind:
- filters: standard filters object, in this case only the access filter is applied as a
permission check
- chunk_requests: requests containing the document id and the chunk range to retrieve
- filters: Filters to apply to retrieval
- batch_retrieval: If True, perform a batch retrieval
Returns:
list of chunks for the document id or the specific chunk by the specified chunk index
@@ -209,80 +226,6 @@ class IdRetrievalCapable(abc.ABC):
raise NotImplementedError
class KeywordCapable(abc.ABC):
"""
Class must implement the keyword search functionality
"""
@abc.abstractmethod
def keyword_retrieval(
self,
query: str,
filters: IndexFilters,
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunkUncleaned]:
"""
Run keyword search and return a list of chunks. Inference chunks are chunks with all of the
information required for query time purposes. For example, some details of the document
required at indexing time are no longer needed past this point. At the same time, the
matching keywords need to be highlighted.
NOTE: the query passed in here is the unprocessed plain text query. Preprocessing is
expected to be handled by this function as it may depend on the index implementation.
Things like query expansion, synonym injection, stop word removal, lemmatization, etc. are
done here.
Parameters:
- query: unmodified user query
- filters: standard filter object
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
based on the persona settings, will have this be a 2x or 3x of the default
- num_to_retrieve: number of highest matching chunks to return
- offset: number of highest matching chunks to skip (kind of like pagination)
Returns:
best matching chunks based on keyword matching (should be BM25 algorithm ideally)
"""
raise NotImplementedError
class VectorCapable(abc.ABC):
"""
Class must implement the vector/semantic search functionality
"""
@abc.abstractmethod
def semantic_retrieval(
self,
query: str, # Needed for matching purposes
query_embedding: list[float],
filters: IndexFilters,
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
) -> list[InferenceChunkUncleaned]:
"""
Run vector/semantic search and return a list of inference chunks.
Parameters:
- query: unmodified user query. This is needed for getting the matching highlighted
keywords
- query_embedding: vector representation of the query, must be of the correct
dimensionality for the primary index
- filters: standard filter object
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
based on the persona settings, will have this be a 2x or 3x of the default
- num_to_retrieve: number of highest matching chunks to return
- offset: number of highest matching chunks to skip (kind of like pagination)
Returns:
best matching chunks based on vector similarity
"""
raise NotImplementedError
class HybridCapable(abc.ABC):
"""
Class must implement hybrid (keyword + vector) search functionality
@@ -292,12 +235,13 @@ class HybridCapable(abc.ABC):
def hybrid_retrieval(
self,
query: str,
query_embedding: list[float],
query_embedding: Embedding,
final_keywords: list[str] | None,
filters: IndexFilters,
hybrid_alpha: float,
time_decay_multiplier: float,
num_to_retrieve: int,
offset: int = 0,
hybrid_alpha: float | None = None,
) -> list[InferenceChunkUncleaned]:
"""
Run hybrid search and return a list of inference chunks.
@@ -312,15 +256,16 @@ class HybridCapable(abc.ABC):
keywords
- query_embedding: vector representation of the query, must be of the correct
dimensionality for the primary index
- final_keywords: Final keywords to be used from the query, defaults to query if not set
- filters: standard filter object
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
based on the persona settings, will have this be a 2x or 3x of the default
- num_to_retrieve: number of highest matching chunks to return
- offset: number of highest matching chunks to skip (kind of like pagination)
- hybrid_alpha: weighting between the keyword and vector search results. It is important
that the two scores are normalized to the same range so that a meaningful
comparison can be made. 1 for 100% weighting on vector score, 0 for 100% weighting
on keyword score.
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
based on the persona settings, will have this be a 2x or 3x of the default
- num_to_retrieve: number of highest matching chunks to return
- offset: number of highest matching chunks to skip (kind of like pagination)
Returns:
best matching chunks based on weighted sum of keyword and vector/semantic search scores
@@ -386,7 +331,7 @@ class BaseIndex(
"""
class DocumentIndex(KeywordCapable, VectorCapable, HybridCapable, BaseIndex, abc.ABC):
class DocumentIndex(HybridCapable, BaseIndex, abc.ABC):
"""
A valid document index that can plug into all Danswer flows must implement all of these
functionalities, though "technically" it does not need to be keyword or vector capable as

View File

@@ -20,18 +20,10 @@ schema DANSWER_CHUNK_NAME {
# `semantic_identifier` will be the channel name, but the `title` will be empty
field title type string {
indexing: summary | index | attribute
match {
gram
gram-size: 3
}
index: enable-bm25
}
field content type string {
indexing: summary | index
match {
gram
gram-size: 3
}
index: enable-bm25
}
# duplication of `content` is far from ideal, but is needed for
@@ -88,6 +80,10 @@ schema DANSWER_CHUNK_NAME {
rank:filter
attribute: fast-search
}
# If chunk is a large chunk, this will contain the ids of the smaller chunks
field large_chunk_reference_ids type array<int> {
indexing: summary | attribute
}
field metadata type string {
indexing: summary | attribute
}
@@ -153,43 +149,45 @@ schema DANSWER_CHUNK_NAME {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
# This must be separate function for normalize_linear to work
function vector_score() {
function title_vector_score() {
expression {
# If no title, the full vector score comes from the content embedding
(query(title_content_ratio) * if(attribute(skip_title), closeness(field, embeddings), closeness(field, title_embedding))) +
((1 - query(title_content_ratio)) * closeness(field, embeddings))
}
}
# This must be separate function for normalize_linear to work
function keyword_score() {
expression {
(query(title_content_ratio) * bm25(title)) +
((1 - query(title_content_ratio)) * bm25(content))
# If no good matching titles, then it should use the context embeddings rather than having some
# irrelevant title have a vector score of 1. This way at least it will be the doc with the highest
# matching content score getting the full score
max(closeness(field, embeddings), closeness(field, title_embedding))
}
}
# First phase must be vector to allow hits that have no keyword matches
first-phase {
expression: vector_score
expression: closeness(field, embeddings)
}
# Weighted average between Vector Search and BM-25
# Each is a weighted average between the Title and Content fields
# Finally each doc is boosted by it's user feedback based boost and recency
# If any embedding or index field is missing, it just receives a score of 0
# Assumptions:
# - For a given query + corpus, the BM-25 scores will be relatively similar in distribution
# therefore not normalizing before combining.
# - For documents without title, it gets a score of 0 for that and this is ok as documents
# without any title match should be penalized.
global-phase {
expression {
(
# Weighted Vector Similarity Score
(query(alpha) * normalize_linear(vector_score)) +
(
query(alpha) * (
(query(title_content_ratio) * normalize_linear(title_vector_score))
+
((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings)))
)
)
+
# Weighted Keyword Similarity Score
((1 - query(alpha)) * normalize_linear(keyword_score))
# Note: for the BM25 Title score, it requires decent stopword removal in the query
# This needs to be the case so there aren't irrelevant titles being normalized to a score of 1
(
(1 - query(alpha)) * (
(query(title_content_ratio) * normalize_linear(bm25(title)))
+
((1 - query(title_content_ratio)) * normalize_linear(bm25(content)))
)
)
)
# Boost based on user feedback
* document_boost
@@ -204,8 +202,6 @@ schema DANSWER_CHUNK_NAME {
bm25(content)
closeness(field, title_embedding)
closeness(field, embeddings)
keyword_score
vector_score
document_boost
recency_bias
closest(embeddings)
@@ -219,28 +215,4 @@ schema DANSWER_CHUNK_NAME {
expression: bm25(content) + (5 * bm25(title))
}
}
# THE ONES BELOW ARE OUT OF DATE, DO NOT USE
# THEY MIGHT NOT EVEN WORK AT ALL
rank-profile keyword_search inherits default, default_rank {
first-phase {
expression: bm25(content) * document_boost * recency_bias
}
match-features: recency_bias document_boost bm25(content)
}
rank-profile semantic_searchVARIABLE_DIM inherits default, default_rank {
inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
}
first-phase {
# Cannot do boost with the chosen embedding model because of high default similarity
# This depends on the embedding model chosen
expression: closeness(field, embeddings)
}
match-features: recency_bias document_boost closest(embeddings)
}
}

View File

@@ -0,0 +1,424 @@
import json
import string
from collections.abc import Callable
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import requests
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_id_based_retrieval_yql,
)
from danswer.document_index.vespa_constants import ACCESS_CONTROL_LIST
from danswer.document_index.vespa_constants import BLURB
from danswer.document_index.vespa_constants import BOOST
from danswer.document_index.vespa_constants import CHUNK_ID
from danswer.document_index.vespa_constants import CONTENT
from danswer.document_index.vespa_constants import CONTENT_SUMMARY
from danswer.document_index.vespa_constants import DOC_UPDATED_AT
from danswer.document_index.vespa_constants import DOCUMENT_ID
from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
from danswer.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE
from danswer.document_index.vespa_constants import METADATA
from danswer.document_index.vespa_constants import METADATA_SUFFIX
from danswer.document_index.vespa_constants import PRIMARY_OWNERS
from danswer.document_index.vespa_constants import RECENCY_BIAS
from danswer.document_index.vespa_constants import SEARCH_ENDPOINT
from danswer.document_index.vespa_constants import SECONDARY_OWNERS
from danswer.document_index.vespa_constants import SECTION_CONTINUATION
from danswer.document_index.vespa_constants import SEMANTIC_IDENTIFIER
from danswer.document_index.vespa_constants import SOURCE_LINKS
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TITLE
from danswer.document_index.vespa_constants import YQL_BASE
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def _process_dynamic_summary(
dynamic_summary: str, max_summary_length: int = 400
) -> list[str]:
if not dynamic_summary:
return []
current_length = 0
processed_summary: list[str] = []
for summary_section in dynamic_summary.split("<sep />"):
# if we're past the desired max length, break at the last word
if current_length + len(summary_section) >= max_summary_length:
summary_section = summary_section[: max_summary_length - current_length]
summary_section = summary_section.lstrip() # remove any leading whitespace
# handle the case where the truncated section is either just a
# single (partial) word or if it's empty
first_space = summary_section.find(" ")
if first_space == -1:
# add ``...`` to previous section
if processed_summary:
processed_summary[-1] += "..."
break
# handle the valid truncated section case
summary_section = summary_section.rsplit(" ", 1)[0]
if summary_section[-1] in string.punctuation:
summary_section = summary_section[:-1]
summary_section += "..."
processed_summary.append(summary_section)
break
processed_summary.append(summary_section)
current_length += len(summary_section)
return processed_summary
def _vespa_hit_to_inference_chunk(
hit: dict[str, Any], null_score: bool = False
) -> InferenceChunkUncleaned:
fields = cast(dict[str, Any], hit["fields"])
# parse fields that are stored as strings, but are really json / datetime
metadata = json.loads(fields[METADATA]) if METADATA in fields else {}
updated_at = (
datetime.fromtimestamp(fields[DOC_UPDATED_AT], tz=timezone.utc)
if DOC_UPDATED_AT in fields
else None
)
match_highlights = _process_dynamic_summary(
# fallback to regular `content` if the `content_summary` field
# isn't present
dynamic_summary=hit["fields"].get(CONTENT_SUMMARY, hit["fields"][CONTENT]),
)
semantic_identifier = fields.get(SEMANTIC_IDENTIFIER, "")
if not semantic_identifier:
logger.error(
f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier"
)
source_links = fields.get(SOURCE_LINKS, {})
source_links_dict_unprocessed = (
json.loads(source_links) if isinstance(source_links, str) else source_links
)
source_links_dict = {
int(k): v
for k, v in cast(dict[str, str], source_links_dict_unprocessed).items()
}
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
source_links=source_links_dict or {0: ""},
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
source_type=fields[SOURCE_TYPE],
title=fields.get(TITLE),
semantic_identifier=fields[SEMANTIC_IDENTIFIER],
boost=fields.get(BOOST, 1),
recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
score=None if null_score else hit.get("relevance", 0),
hidden=fields.get(HIDDEN, False),
primary_owners=fields.get(PRIMARY_OWNERS),
secondary_owners=fields.get(SECONDARY_OWNERS),
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
match_highlights=match_highlights,
updated_at=updated_at,
)
def _get_chunks_via_visit_api(
chunk_request: VespaChunkRequest,
index_name: str,
filters: IndexFilters,
field_names: list[str] | None = None,
get_large_chunks: bool = False,
) -> list[dict]:
# Constructing the URL for the Visit API
# NOTE: visit API uses the same URL as the document API, but with different params
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
# build the list of fields to retrieve
field_set_list = (
None
if not field_names
else [f"{index_name}:{field_name}" for field_name in field_names]
)
acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
if (
field_set_list
and filters.access_control_list
and acl_fieldset_entry not in field_set_list
):
field_set_list.append(acl_fieldset_entry)
field_set = ",".join(field_set_list) if field_set_list else None
# build filters
selection = f"{index_name}.document_id=='{chunk_request.document_id}'"
if chunk_request.is_capped:
selection += f" and {index_name}.chunk_id>={chunk_request.min_chunk_ind or 0}"
selection += f" and {index_name}.chunk_id<={chunk_request.max_chunk_ind}"
if not get_large_chunks:
selection += f" and {index_name}.large_chunk_reference_ids == null"
# Setting up the selection criteria in the query parameters
params = {
# NOTE: Document Selector Language doesn't allow `contains`, so we can't check
# for the ACL in the selection. Instead, we have to check as a postfilter
"selection": selection,
"continuation": None,
"wantedDocumentCount": 1_000,
"fieldSet": field_set,
}
document_chunks: list[dict] = []
while True:
response = requests.get(url, params=params)
try:
response.raise_for_status()
except requests.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
error_base = f"Error occurred getting chunk by Document ID {chunk_request.document_id}"
logger.error(
f"{error_base}:\n"
f"{request_info}\n"
f"{response_info}\n"
f"Exception: {e}"
)
raise requests.HTTPError(error_base) from e
# Check if the response contains any documents
response_data = response.json()
if "documents" in response_data:
for document in response_data["documents"]:
if filters.access_control_list:
document_acl = document["fields"].get(ACCESS_CONTROL_LIST)
if not document_acl or not any(
user_acl_entry in document_acl
for user_acl_entry in filters.access_control_list
):
continue
document_chunks.append(document)
# Check for continuation token to handle pagination
if "continuation" in response_data and response_data["continuation"]:
params["continuation"] = response_data["continuation"]
else:
break # Exit loop if no continuation token
return document_chunks
def get_all_vespa_ids_for_document_id(
document_id: str,
index_name: str,
filters: IndexFilters | None = None,
get_large_chunks: bool = False,
) -> list[str]:
document_chunks = _get_chunks_via_visit_api(
chunk_request=VespaChunkRequest(document_id=document_id),
index_name=index_name,
filters=filters or IndexFilters(access_control_list=None),
field_names=[DOCUMENT_ID],
get_large_chunks=get_large_chunks,
)
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
def parallel_visit_api_retrieval(
index_name: str,
chunk_requests: list[VespaChunkRequest],
filters: IndexFilters,
get_large_chunks: bool = False,
) -> list[InferenceChunkUncleaned]:
functions_with_args: list[tuple[Callable, tuple]] = [
(
_get_chunks_via_visit_api,
(chunk_request, index_name, filters, get_large_chunks),
)
for chunk_request in chunk_requests
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
# Any failures to retrieve would give a None, drop the Nones and empty lists
vespa_chunk_sets = [res for res in parallel_results if res]
flattened_vespa_chunks = []
for chunk_set in vespa_chunk_sets:
flattened_vespa_chunks.extend(chunk_set)
inference_chunks = [
_vespa_hit_to_inference_chunk(chunk, null_score=True)
for chunk in flattened_vespa_chunks
]
return inference_chunks
@retry(tries=3, delay=1, backoff=2)
def query_vespa(
query_params: Mapping[str, str | int | float]
) -> list[InferenceChunkUncleaned]:
if "query" in query_params and not cast(str, query_params["query"]).strip():
raise ValueError("No/empty query received")
params = dict(
**query_params,
**{
"presentation.timing": True,
}
if LOG_VESPA_TIMING_INFORMATION
else {},
)
response = requests.post(
SEARCH_ENDPOINT,
json=params,
)
try:
response.raise_for_status()
except requests.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = (
f"Status Code: {response.status_code}\n"
f"Response Content: {response.text}"
)
error_base = "Failed to query Vespa"
logger.error(
f"{error_base}:\n"
f"{request_info}\n"
f"{response_info}\n"
f"Exception: {e}"
)
raise requests.HTTPError(error_base) from e
response_json: dict[str, Any] = response.json()
if LOG_VESPA_TIMING_INFORMATION:
logger.debug("Vespa timing info: %s", response_json.get("timing"))
hits = response_json["root"].get("children", [])
for hit in hits:
if hit["fields"].get(CONTENT) is None:
identifier = hit["fields"].get("documentid") or hit["id"]
logger.error(
f"Vespa Index with Vespa ID {identifier} has no contents. "
f"This is invalid because the vector is not meaningful and keywordsearch cannot "
f"fetch this document"
)
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
# Good Debugging Spot
return inference_chunks
def _get_chunks_via_batch_search(
index_name: str,
chunk_requests: list[VespaChunkRequest],
filters: IndexFilters,
get_large_chunks: bool = False,
) -> list[InferenceChunkUncleaned]:
if not chunk_requests:
return []
filters_str = build_vespa_filters(filters=filters, include_hidden=True)
yql = (
YQL_BASE.format(index_name=index_name)
+ filters_str
+ build_vespa_id_based_retrieval_yql(chunk_requests[0])
)
chunk_requests.pop(0)
for request in chunk_requests:
yql += " or " + build_vespa_id_based_retrieval_yql(request)
params: dict[str, str | int | float] = {
"yql": yql,
"hits": MAX_ID_SEARCH_QUERY_SIZE,
}
inference_chunks = query_vespa(params)
if not get_large_chunks:
inference_chunks = [
chunk for chunk in inference_chunks if not chunk.large_chunk_reference_ids
]
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
return inference_chunks
def batch_search_api_retrieval(
index_name: str,
chunk_requests: list[VespaChunkRequest],
filters: IndexFilters,
get_large_chunks: bool = False,
) -> list[InferenceChunkUncleaned]:
retrieved_chunks: list[InferenceChunkUncleaned] = []
capped_requests: list[VespaChunkRequest] = []
uncapped_requests: list[VespaChunkRequest] = []
chunk_count = 0
for request in chunk_requests:
# All requests without a chunk range are uncapped
# Uncapped requests are retrieved using the Visit API
range = request.range
if range is None:
uncapped_requests.append(request)
continue
# If adding the range to the chunk count is greater than the
# max query size, we need to perform a retrieval to avoid hitting the limit
if chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE:
retrieved_chunks.extend(
_get_chunks_via_batch_search(
index_name=index_name,
chunk_requests=capped_requests,
filters=filters,
get_large_chunks=get_large_chunks,
)
)
capped_requests = []
chunk_count = 0
capped_requests.append(request)
chunk_count += range
if capped_requests:
retrieved_chunks.extend(
_get_chunks_via_batch_search(
index_name=index_name,
chunk_requests=capped_requests,
filters=filters,
get_large_chunks=get_large_chunks,
)
)
if uncapped_requests:
logger.debug(f"Retrieving {len(uncapped_requests)} uncapped requests")
retrieved_chunks.extend(
parallel_visit_api_retrieval(
index_name, uncapped_requests, filters, get_large_chunks
)
)
return retrieved_chunks

View File

@@ -0,0 +1,65 @@
import concurrent.futures
import httpx
from retry import retry
from danswer.document_index.vespa.chunk_retrieval import (
get_all_vespa_ids_for_document_id,
)
from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa_constants import NUM_THREADS
from danswer.utils.logger import setup_logger
logger = setup_logger()
CONTENT_SUMMARY = "content_summary"
@retry(tries=3, delay=1, backoff=2)
def _delete_vespa_doc_chunks(
document_id: str, index_name: str, http_client: httpx.Client
) -> None:
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
get_large_chunks=True,
)
for chunk_id in doc_chunk_ids:
try:
res = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}"
)
res.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"Failed to delete chunk, details: {e.response.text}")
raise
def delete_vespa_docs(
document_ids: list[str],
index_name: str,
http_client: httpx.Client,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
external_executor = True
if not executor:
external_executor = False
executor = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)
try:
doc_deletion_future = {
executor.submit(
_delete_vespa_doc_chunks, doc_id, index_name, http_client
): doc_id
for doc_id in document_ids
}
for future in concurrent.futures.as_completed(doc_deletion_future):
# Will raise exception if the deletion raised an exception
future.result()
finally:
if not external_executor:
executor.shutdown(wait=True)

Some files were not shown because too many files have changed in this diff Show More